Add -pfc for prefix file cache
Modify the prefix logic so that different models can run the same prompt without exiting.
HoiV committed May 3, 2024
1 parent b50cac0 commit f1d6d0c
Showing 3 changed files with 37 additions and 20 deletions.
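
For context, a typical invocation might look like the following sketch. The -m and -f flags are the standard main example options and the file paths are placeholders; only -pfc/--prefix_cache is added by this commit. On a load failure (e.g., the cached state was produced by a different model), the new code falls back to rebuilding the cache instead of exiting.

    # first run evaluates the shared prefix and saves it under shared_prefix/
    ./main -m models/model-a.gguf -f prompts/shared.txt -pfc
    # a second run with a different model finds the same cache file; if the
    # saved state cannot be loaded, it is regenerated and saved again
    ./main -m models/model-b.gguf -f prompts/shared.txt -pfc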
4 changes: 4 additions & 0 deletions common/common.cpp
@@ -434,6 +434,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.custom_p_file = argv[i];
         return true;
     }
+    if (arg == "-pfc" || arg == "--prefix_cache") {
+        params.use_prefix_cache = true;
+        return true;
+    }
     if (arg == "-n" || arg == "--n-predict") {
         if (++i >= argc) {
             invalid_param = true;
1 change: 1 addition & 0 deletions common/common.h
@@ -169,6 +169,7 @@ struct gpt_params {
     bool check_tensors = false; // validate tensor data
     bool scripted = false; // input is scripted
     bool custom_prompts_on = false; // custom prompts are available
+    bool use_prefix_cache = false; // use prefix cache if it exists and create one if it is not present
 
     std::string cache_type_k = "f16"; // KV cache data type for the K
     std::string cache_type_v = "f16"; // KV cache data type for the V
52 changes: 32 additions & 20 deletions examples/main/main.cpp
@@ -98,6 +98,13 @@ bool processCustomPromptsFromFile(const std::string& custom_p_file) {
 
 std::string pfx_file_path(std::string pfx) {
     static std::string dir = "shared_prefix";
+    if (!CreateDirectory(dir.c_str(), NULL)) {
+        if (GetLastError() != ERROR_ALREADY_EXISTS) {
+            fprintf(stderr, "%s: Failed to create directory: %s - use current dir for prefix cache\n",
+                    __func__, dir.c_str());
+            dir = ".";
+        }
+    }
     static std::hash<std::string> hasher;
     return dir + "/" + std::to_string(hasher(pfx));
 }
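
Aside: the cache file name is just a std::hash of the prefix text, so any process that sees the same prefix computes the same path. A minimal standalone sketch of that scheme (the prefix string here is a made-up example):

#include <functional>
#include <iostream>
#include <string>

int main() {
    static std::hash<std::string> hasher;
    std::string pfx = "You are a helpful assistant."; // hypothetical shared prefix
    // identical prefix text -> identical file name, so separate runs (and
    // different models) locate the same cached state on disk; note that
    // std::hash is only guaranteed stable within one standard library build
    std::cout << "shared_prefix/" + std::to_string(hasher(pfx)) << "\n";
    return 0;
}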
@@ -1019,36 +1026,41 @@ int main(int argc, char ** argv) {
         pfx_shared = full_custom_prompt.substr(0, pos);
         std::string pfx_file = pfx_file_path(pfx_shared);
 
-        // note tokenize(a) + tokenize(b) != tokenize(a+b), we tokenize pfx and content seperately
+        // note tokenize(a) + tokenize(b) != tokenize(a+b), we tokenize pfx and content separately
         const auto line_pfx_shared = ::llama_tokenize(ctx, pfx_shared, false, false);
         embd_inp.insert(embd_inp.end(), line_pfx_shared.begin(), line_pfx_shared.end());
 
-        if (file_exists(pfx_file)) {
+        if (params.use_prefix_cache && file_exists(pfx_file)) {
             // The file exists and is not empty
             session_tokens.resize(n_ctx);
             size_t n_token_count_out = 0;
             if (!llama_state_load_file(ctx, pfx_file.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
-                LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, pfx_file.c_str());
-                return 1;
-            }
-            session_tokens.resize(n_token_count_out);
-            llama_set_rng_seed(ctx, params.seed);
-            LOG_TEE("%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size());
-
-            // sanity check
-            GGML_ASSERT(line_pfx_shared.size() <= session_tokens.size());
-            for (size_t i = 0; i < line_pfx_shared.size(); i++) {
-                GGML_ASSERT(line_pfx_shared[i] == session_tokens[i]);
+                LOG_TEE("> %s: error: failed to load session file '%s' - create new file\n",
+                        __func__, pfx_file.c_str());
+                session_tokens.resize(0);
+                need_save_pfx = true;
+            } else {
+                session_tokens.resize(n_token_count_out);
+                llama_set_rng_seed(ctx, params.seed);
+                LOG_TEE("> %s: loaded saved session '%s' with prompt size of (%d) tokens\n",
+                        __func__, pfx_file.c_str(), (int)session_tokens.size());
+
+                // sanity check
+                GGML_ASSERT(line_pfx_shared.size() <= session_tokens.size());
+                for (size_t i = 0; i < line_pfx_shared.size(); i++) {
+                    GGML_ASSERT(line_pfx_shared[i] == session_tokens[i]);
+                }
+                // remove any "future" tokens that we might have inherited from the previous session
+                llama_kv_cache_seq_rm(ctx, -1, line_pfx_shared.size(), -1);
+                n_consumed += line_pfx_shared.size();
+                n_past += line_pfx_shared.size();
             }
-            // remove any "future" tokens that we might have inherited from the previous session
-            llama_kv_cache_seq_rm(ctx, -1, line_pfx_shared.size(), -1);
-            n_consumed += line_pfx_shared.size();
-            n_past += line_pfx_shared.size();
-        }
-        else {
+        } else {
             // todo: shared position for saving
             //buffer = full_custom_prompt;
-            need_save_pfx = true;
+            if (params.use_prefix_cache) {
+                need_save_pfx = true;
+            }
         }
 
         // construct complete prompt
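The hunk above only sets need_save_pfx; the code that consumes it lies outside the shown diff. A sketch of what that save step plausibly looks like, using llama_state_save_file (the existing counterpart of llama_state_load_file); the placement and log text are assumptions, not part of this commit:

        // assumed follow-up (not shown in this diff): after the shared prefix
        // has been decoded into the KV cache, persist it for the next process
        if (need_save_pfx) {
            if (!llama_state_save_file(ctx, pfx_file.c_str(),
                                       session_tokens.data(), session_tokens.size())) {
                LOG_TEE("> %s: warning: failed to save session file '%s'\n",
                        __func__, pfx_file.c_str());
            }
            need_save_pfx = false;
        }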
