server : fix crash when prompt exceeds context size (#3996)
z80maniac authored Nov 11, 2023
1 parent 34b0a08 commit d96ca7d
Showing 1 changed file with 29 additions and 29 deletions.
58 changes: 29 additions & 29 deletions examples/server/server.cpp
@@ -1557,6 +1557,35 @@ struct llama_server_context
 
             slot.num_prompt_tokens = prompt_tokens.size();
 
+            if (slot.params.n_keep < 0)
+            {
+                slot.params.n_keep = slot.num_prompt_tokens;
+            }
+            slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
+
+            // if input prompt is too big, truncate it
+            if (slot.num_prompt_tokens >= slot.n_ctx)
+            {
+                const int n_left = slot.n_ctx - slot.params.n_keep;
+                const int n_block_size = n_left / 2;
+                const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
+
+                std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
+                new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
+
+                LOG_VERBOSE("input truncated", {
+                    {"n_ctx", slot.n_ctx},
+                    {"n_keep", slot.params.n_keep},
+                    {"n_left", n_left},
+                    {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+                });
+                slot.truncated = true;
+                prompt_tokens = new_tokens;
+
+                slot.num_prompt_tokens = prompt_tokens.size();
+                GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
+            }
+
             if (!slot.params.cache_prompt)
             {
                 llama_sampling_reset(slot.ctx_sampling);
@@ -1566,35 +1595,6 @@ struct llama_server_context
             }
             else
             {
-                if (slot.params.n_keep < 0)
-                {
-                    slot.params.n_keep = slot.num_prompt_tokens;
-                }
-                slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
-
-                // if input prompt is too big, truncate it
-                if (slot.num_prompt_tokens >= slot.n_ctx)
-                {
-                    const int n_left = slot.n_ctx - slot.params.n_keep;
-                    const int n_block_size = n_left / 2;
-                    const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
-
-                    std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
-                    new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
-
-                    LOG_VERBOSE("input truncated", {
-                        {"n_ctx", slot.n_ctx},
-                        {"n_keep", slot.params.n_keep},
-                        {"n_left", n_left},
-                        {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
-                    });
-                    slot.truncated = true;
-                    prompt_tokens = new_tokens;
-
-                    slot.num_prompt_tokens = prompt_tokens.size();
-                    GGML_ASSERT(slot.num_prompt_tokens < slot.n_ctx);
-                }
-
                 // push the prompt into the sampling context (do not apply grammar)
                 for (auto &token : prompt_tokens)
                 {
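The fix moves the prompt-truncation block out of the `else` (prompt-cache) branch, so it now runs for every request before the prompt is processed; the truncation arithmetic itself is unchanged. That arithmetic keeps the first n_keep tokens, then erases whole blocks of n_left / 2 tokens from the middle of the prompt until the remainder fits in the context window. Below is a minimal standalone sketch of the same calculation with made-up numbers (n_ctx = 512, n_keep = 64, a 700-token prompt) and llama_token replaced by int so it compiles on its own; it is an illustration, not the server's actual code path.

#include <cassert>
#include <cstdio>
#include <vector>

int main()
{
    // Illustrative values only (not taken from the commit):
    // a 512-token context, keep the first 64 tokens, 700-token prompt.
    const int n_ctx  = 512;
    const int n_keep = 64;

    std::vector<int> prompt_tokens(700);
    for (int i = 0; i < (int) prompt_tokens.size(); i++) prompt_tokens[i] = i;

    int num_prompt_tokens = (int) prompt_tokens.size();

    if (num_prompt_tokens >= n_ctx)
    {
        // Same block arithmetic as the patch.
        const int n_left        = n_ctx - n_keep;   // 448
        const int n_block_size  = n_left / 2;       // 224
        const int erased_blocks = (num_prompt_tokens - n_keep - n_block_size) / n_block_size; // (700-64-224)/224 = 1

        // Keep the first n_keep tokens, drop erased_blocks * n_block_size tokens
        // from the middle, and keep the tail.
        std::vector<int> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + n_keep);
        new_tokens.insert(new_tokens.end(),
                          prompt_tokens.begin() + n_keep + erased_blocks * n_block_size,
                          prompt_tokens.end());

        prompt_tokens     = new_tokens;
        num_prompt_tokens = (int) prompt_tokens.size(); // 64 + (700 - 64 - 224) = 476

        assert(num_prompt_tokens < n_ctx);              // mirrors the GGML_ASSERT in the patch
    }

    printf("tokens after truncation: %d\n", num_prompt_tokens); // prints 476
    return 0;
}

With these numbers one 224-token block is erased, leaving 476 tokens, which satisfies the assertion. Before this commit the truncation only ran inside the else branch, i.e. when cache_prompt was enabled, so an oversized prompt on the non-cached path was never shortened, which matches the crash described in the commit title.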
