llama_chat_apply_template: use term "chat" everywhere
ngxson committed Feb 18, 2024
1 parent dba4337 commit 73fbd67
Showing 2 changed files with 23 additions and 15 deletions.
llama.cpp: 28 changes (17 additions, 11 deletions)
@@ -12459,7 +12459,10 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
return 0;
}

-int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_template, std::vector<const llama_chat_message *> conversation, bool add_ass);
+int32_t llama_chat_apply_template_internal(
+    const std::string & chat_template,
+    const std::vector<const llama_chat_message *> & chat,
+    std::string & dest, bool add_ass);

// trim whitespace from the beginning and end of a string
static std::string trim(const std::string & str) {
@@ -12476,12 +12479,15 @@ static std::string trim(const std::string & str) {

// Simple version of "llama_apply_chat_template" that only works with strings
// This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
-int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_template, std::vector<const llama_chat_message *> conversation, bool add_ass) {
+int32_t llama_chat_apply_template_internal(
+    const std::string & chat_template,
+    const std::vector<const llama_chat_message *> & chat,
+    std::string & dest, bool add_ass) {
// Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
std::stringstream ss;
if (chat_template.find("<|im_start|>") != std::string::npos) {
// chatml template
-for (auto message : conversation) {
+for (auto message : chat) {
ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
}
if (add_ass) {
@@ -12500,7 +12506,7 @@ int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_t
// construct the prompt
bool is_inside_turn = true; // skip BOS at the beginning
ss << "[INST] ";
-for (auto message : conversation) {
+for (auto message : chat) {
std::string content = strip_message ? trim(message->content) : message->content;
std::string role(message->role);
if (!is_inside_turn) {
@@ -12524,7 +12530,7 @@ int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_t
// llama2 templates seem to not care about "add_generation_prompt"
} else if (chat_template.find("<|user|>") != std::string::npos) {
// zephyr template
-for (auto message : conversation) {
+for (auto message : chat) {
ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
}
if (add_ass) {
@@ -12541,7 +12547,7 @@ int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_t
LLAMA_API int32_t llama_chat_apply_template(
const struct llama_model * model,
const char * custom_template,
-const struct llama_chat_message * msg,
+const struct llama_chat_message * chat,
size_t n_msg,
bool add_ass,
char * buf,
@@ -12560,14 +12566,14 @@ LLAMA_API int32_t llama_chat_apply_template(
current_template = std::string(model_template.data(), model_template.size());
}
}
-// format the conversation to string
-std::vector<const llama_chat_message *> conversation_vec;
-conversation_vec.resize(n_msg);
+// format the chat to string
+std::vector<const llama_chat_message *> chat_vec;
+chat_vec.resize(n_msg);
for (size_t i = 0; i < n_msg; i++) {
-conversation_vec[i] = &msg[i];
+chat_vec[i] = &chat[i];
}
std::string formatted_chat;
-int32_t res = llama_chat_apply_template_internal(formatted_chat, current_template, conversation_vec, add_ass);
+int32_t res = llama_chat_apply_template_internal(current_template, chat_vec, formatted_chat, add_ass);
if (res < 0) {
return res;
}
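For reference, here is a minimal standalone sketch (not part of the commit) of the ChatML branch shown above: the same loop over role/content pairs, but with a simplified message struct in place of llama_chat_message so it compiles on its own. The exact string appended inside the truncated add_ass block is an assumption.

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Simplified stand-in for llama_chat_message (role + content strings).
struct chat_msg { std::string role, content; };

// Mirrors the "<|im_start|>" (ChatML) branch of llama_chat_apply_template_internal.
static std::string format_chatml(const std::vector<chat_msg> & chat, bool add_ass) {
    std::stringstream ss;
    for (const auto & message : chat) {
        ss << "<|im_start|>" << message.role << "\n" << message.content << "<|im_end|>\n";
    }
    if (add_ass) {
        // Assumed: open an assistant turn so generation continues from here.
        ss << "<|im_start|>assistant\n";
    }
    return ss.str();
}

int main() {
    std::cout << format_chatml({{"system", "You are a helpful assistant."}, {"user", "Hello!"}}, true);
}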
llama.h: 10 changes (6 additions, 4 deletions)
@@ -704,18 +704,20 @@ extern "C" {
char * buf,
int32_t length);

-/// Apply chat template and maybe tokenize it. Inspired by hf apply_chat_template() on python.
+/// Apply chat template. Inspired by hf apply_chat_template() on python.
/// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
/// NOTE: This function only support some known jinja templates. It is not a jinja parser.
-/// @param custom_template A Jinja template to use for this conversion. If this is nullptr, the model’s default chat template will be used instead.
-/// @param msg Pointer to a list of multiple llama_chat_message
+/// @param custom_template A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead.
+/// @param chat Pointer to a list of multiple llama_chat_message
/// @param n_msg Number of llama_chat_message in this chat
/// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message.
/// @param buf A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)
/// @param length The size of the allocated buffer
/// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template.
LLAMA_API int32_t llama_chat_apply_template(
const struct llama_model * model,
const char * custom_template,
-const struct llama_chat_message * msg,
+const struct llama_chat_message * chat,
size_t n_msg,
bool add_ass,
char * buf,
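And a hedged usage sketch of the public entry point declared above, following the header comments: "custom_template" takes precedence over "model" (so nullptr works for the model handle), the buffer is sized at roughly 2x the total message characters, and a return value larger than the buffer means re-alloc and re-apply. The minimal template string and the llama_chat_message field names (role, content) are assumptions inferred from the loops in llama.cpp above.

#include <cstdio>
#include <cstring>
#include <vector>

#include "llama.h"

int main() {
    // Assumed layout: llama_chat_message carries C-string "role" and "content"
    // fields, as the message->role / message->content accesses above suggest.
    std::vector<llama_chat_message> chat = {
        { "system", "You are a helpful assistant." },
        { "user",   "Hello!"                       },
    };

    // Any template containing "<|im_start|>" selects the ChatML heuristic,
    // so no model handle is needed when a custom template is supplied.
    const char * custom_template = "<|im_start|>"; // hypothetical minimal trigger

    // Recommended buffer size: 2 * (total characters of all messages).
    size_t n_chars = 0;
    for (const auto & m : chat) {
        n_chars += strlen(m.role) + strlen(m.content);
    }
    std::vector<char> buf(2 * n_chars);

    int32_t res = llama_chat_apply_template(
        /* model           */ nullptr,
        /* custom_template */ custom_template,
        /* chat            */ chat.data(),
        /* n_msg           */ chat.size(),
        /* add_ass         */ true,
        /* buf             */ buf.data(),
        /* length          */ (int32_t) buf.size());

    if (res < 0) {
        fprintf(stderr, "unknown or unsupported template\n");
        return 1;
    }
    if ((size_t) res > buf.size()) {
        // Output did not fit: re-alloc and re-apply, per the header comment.
        buf.resize(res);
        res = llama_chat_apply_template(nullptr, custom_template, chat.data(),
                                        chat.size(), true, buf.data(), (int32_t) buf.size());
    }
    printf("%.*s\n", res, buf.data());
    return 0;
}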
