Add chat template support for llama-cli #8068

Merged: 11 commits, Jun 25, 2024
Changes from 4 commits
47 changes: 46 additions & 1 deletion common/common.cpp
@@ -1814,7 +1814,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "main", " --cfg-negative-prompt-file FNAME",
"negative prompt file to use for guidance" });
options.push_back({ "main", " --cfg-scale N", "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale });

options.push_back({ "main", " --chat-template JINJA_TEMPLATE",
"set custom jinja chat template (default: template taken from model's metadata)\n"
"only commonly used templates are accepted:\n"
"https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
options.push_back({ "grammar" });
options.push_back({ "*", " --grammar GRAMMAR", "BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", sparams.grammar.c_str() });
options.push_back({ "*", " --grammar-file FNAME", "file to read grammar from" });
@@ -2967,12 +2970,54 @@ bool llama_should_add_bos_token(const llama_model * model) {
return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
}

//
// Chat template utils
//

bool llama_chat_verify_template(const std::string & tmpl) {
llama_chat_message chat[] = {{"user", "test"}};
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
return res >= 0;
}
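As a quick illustration of how this check is meant to be used by a front-end, here is a minimal sketch (not part of the diff; the error message and exit-on-failure behaviour are assumptions):

// sketch: reject an unsupported --chat-template value at argument-parsing time
#include "common.h"
#include <cstdio>
#include <cstdlib>
#include <string>

static void validate_chat_template(const std::string & tmpl) {
    // an empty value means "use the template from the model metadata", so it is allowed
    if (!tmpl.empty() && !llama_chat_verify_template(tmpl)) {
        fprintf(stderr, "error: chat template '%s' is not supported\n", tmpl.c_str());
        exit(1);
    }
}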

std::string llama_chat_apply_template(const struct llama_model * model,
const std::string & tmpl,
const std::vector<llama_chat_msg> & msgs,
bool add_ass) {
std::vector<llama_chat_message> chat;
for (auto & msg : msgs) {
chat.push_back({msg.role.c_str(), msg.content.c_str()});
}

const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
std::vector<char> buf;

// run the first time to get the total output length
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());

// if it turns out that our buffer is too small, we resize it
if ((size_t) res > buf.size()) {
buf.resize(res);
res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
}

const std::string formatted_chat(buf.data(), res);
return formatted_chat;
}

std::string llama_chat_format_single(const struct llama_model * model,
const std::string & tmpl,
const std::vector<llama_chat_msg> & past_msg,
const llama_chat_msg & new_msg,
bool add_ass) {
auto fmt_past_msg = llama_chat_apply_template(model, tmpl, past_msg, false);
std::vector<llama_chat_msg> chat_new(past_msg);
chat_new.push_back(new_msg);
auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass);
auto formatted = fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
return formatted;
}
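To make the difference between the two wrappers concrete, here is a small self-contained sketch. It mirrors the chatml case from the test added in tests/test-chat-template.cpp at the end of this PR; passing a null model is fine here because the template name is supplied explicitly:

// sketch: formatting the whole history vs. formatting only the newest message
#include "common.h"
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
    std::vector<llama_chat_msg> history = {
        {"system",    "You are a helpful assistant"},
        {"user",      "Hello"},
        {"assistant", "I am assistant"},
    };
    llama_chat_msg new_msg{"user", "How are you"};

    // formats every message in the history
    std::string full = llama_chat_apply_template(nullptr, "chatml", history, false);

    // formats only the suffix introduced by new_msg, plus the assistant prefix
    std::string delta = llama_chat_format_single(nullptr, "chatml", history, new_msg, true);
    assert(delta == "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");

    printf("full:\n%s\n\ndelta:\n%s\n", full.c_str(), delta.c_str());
    return 0;
}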

//
// KV cache utils
//
19 changes: 19 additions & 0 deletions common/common.h
@@ -360,9 +360,28 @@ bool llama_should_add_bos_token(const llama_model * model);
// Chat template utils
//

// same as llama_chat_message, but uses std::string
struct llama_chat_msg {
std::string role;
std::string content;
};

// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool llama_chat_verify_template(const std::string & tmpl);

// CPP wrapper for llama_chat_apply_template
std::string llama_chat_apply_template(const struct llama_model * model,
const std::string & tmpl,
const std::vector<llama_chat_msg> & chat,
bool add_ass);

// Format single message, while taking into account the position of that message in chat history
std::string llama_chat_format_single(const struct llama_model * model,
const std::string & tmpl,
const std::vector<llama_chat_msg> & past_msg,
const llama_chat_msg & new_msg,
bool add_ass);

//
// KV cache utils
//
67 changes: 49 additions & 18 deletions examples/main/main.cpp
@@ -31,20 +31,21 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

static llama_context ** g_ctx;
static llama_model ** g_model;
static gpt_params * g_params;
static std::vector<llama_token> * g_input_tokens;
static std::ostringstream * g_output_ss;
static std::vector<llama_token> * g_output_tokens;
static llama_context ** g_ctx;
static llama_model ** g_model;
static gpt_params * g_params;
static std::vector<llama_token> * g_input_tokens;
static std::ostringstream * g_output_ss;
static std::vector<llama_token> * g_output_tokens;
static std::vector<llama_chat_msg> * g_chat_msgs;
static bool is_interacting = false;

static bool file_exists(const std::string &path) {
static bool file_exists(const std::string & path) {
std::ifstream f(path.c_str());
return f.good();
}

static bool file_is_empty(const std::string &path) {
static bool file_is_empty(const std::string & path) {
std::ifstream f;
f.exceptions(std::ifstream::failbit | std::ifstream::badbit);
f.open(path.c_str(), std::ios::in | std::ios::binary | std::ios::ate);
@@ -117,6 +118,14 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v
LOG_TEE("%s", text);
}

static std::string chat_add_and_format(std::string role, std::string content) {
llama_chat_msg new_msg{role, content};
auto formatted = llama_chat_format_single(
*g_model, g_params->chat_template, *g_chat_msgs, new_msg, role == "user");
g_chat_msgs->push_back({role, content});
return formatted;
}
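A note on the design of the helper above: llama_chat_format_single returns only the newly formatted suffix, so the tokens already evaluated for earlier turns never need to be re-tokenized; each call also appends the raw message to the shared chat_msgs history, so the next turn is formatted against the full conversation.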

int main(int argc, char ** argv) {
gpt_params params;
g_params = &params;
@@ -190,8 +199,10 @@ int main(int argc, char ** argv) {
llama_model * model;
llama_context * ctx;
llama_context * ctx_guidance = NULL;
std::vector<llama_chat_msg> chat_msgs;
g_model = &model;
g_ctx = &ctx;
g_chat_msgs = &chat_msgs;

// load the model and apply lora adapter, if any
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
@@ -249,16 +260,21 @@ int main(int argc, char ** argv) {

std::vector<llama_token> embd_inp;

if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
LOG("tokenize the prompt\n");
embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
} else {
LOG("use session tokens\n");
embd_inp = session_tokens;
}
{
auto prompt = params.conversation
? chat_add_and_format("system", params.prompt) // format the system prompt in conversation mode
: params.prompt;
if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
LOG("tokenize the prompt\n");
embd_inp = ::llama_tokenize(ctx, prompt, true, true);
} else {
LOG("use session tokens\n");
embd_inp = session_tokens;
}

LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
LOG("prompt: \"%s\"\n", log_tostr(prompt));
LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
}

// Should not run without any tokens
if (embd_inp.empty()) {
@@ -478,6 +494,7 @@ int main(int argc, char ** argv) {
std::vector<int> input_tokens; g_input_tokens = &input_tokens;
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
std::ostringstream output_ss; g_output_ss = &output_ss;
std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode

// the first thing we will do is to output the prompt, so set color accordingly
console::set_display(console::prompt);
@@ -793,11 +810,18 @@ int main(int argc, char ** argv) {
is_antiprompt = true;
}

chat_add_and_format("system", assistant_ss.str());
is_interacting = true;
printf("\n");
}
}

// if current token is not EOG, we add it to current assistant message
if (params.conversation) {
auto id = llama_sampling_last(ctx_sampling);
assistant_ss << llama_token_to_piece(ctx, id, false);
}

if (n_past > 0 && is_interacting) {
LOG("waiting for user input\n");

@@ -848,8 +872,12 @@ int main(int argc, char ** argv) {
string_process_escapes(buffer);
}

std::string user_inp = params.conversation
? chat_add_and_format("user", buffer)
: buffer;
// TODO: one inconvenience of the current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
Review discussion on this change:

Owner: When params.conversation == false there is an extra string copy that should be avoided here.
Regarding the comment: can you illustrate with an example? I'm not sure what the issue is.

Author (ngxson): An example would be a prompt like this: Which one is correct HTML tag? <s> or <a>?
Some models that use <s> as BOS will see the prompt as: Which one is correct HTML tag? BOS or <a>?
Leaving special == false would fix that, but it would also break the chat template, since the template now adds special tokens around the user's text. This could be avoided with some more code, but IMO it's not a big deal, assuming that special tokens are unlikely to accidentally appear in the text.

Author (ngxson): I added a std::move(buffer) since we no longer use buffer after this line. Is it OK to do so?

Owner: Aha, got it. Yes, for now let's go with the simple solution.
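To illustrate the issue being discussed, here is a hedged sketch (ctx is assumed to be an initialised llama_context; whether a literal "<s>" maps to the BOS token depends on the model's vocabulary):

// sketch only: how parse_special changes the tokenization of user input
#include "common.h"
#include <cstdio>
#include <string>
#include <vector>

static void show_special_token_effect(llama_context * ctx) {
    const std::string buffer = "Which one is correct HTML tag? <s> or <a>?";
    // parse_special == true: a literal "<s>" may be mapped to the BOS special token
    std::vector<llama_token> as_special = ::llama_tokenize(ctx, buffer, false, true);
    // parse_special == false: "<s>" is tokenized as ordinary text
    std::vector<llama_token> as_text = ::llama_tokenize(ctx, buffer, false, false);
    printf("parse_special=true: %zu tokens, parse_special=false: %zu tokens\n",
           as_special.size(), as_text.size());
}

As the diff below shows, the simple behaviour is kept: the formatted user input is tokenized with parse_special set to params.conversation, accepting the small risk that a special-token string typed by the user is parsed as such.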

const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
const auto line_inp = ::llama_tokenize(ctx, buffer, false, false);
const auto line_inp = ::llama_tokenize(ctx, user_inp, false, params.conversation);
const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);

LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
Expand All @@ -864,6 +892,9 @@ int main(int argc, char ** argv) {
output_ss << llama_token_to_piece(ctx, token);
}

// reset assistant message
assistant_ss.str("");

n_remain -= line_inp.size();
LOG("n_remain: %d\n", n_remain);
} else {
29 changes: 5 additions & 24 deletions examples/server/utils.hpp
@@ -118,36 +118,17 @@ static inline void server_log(const char * level, const char * function, int lin

// Format given chat. If tmpl is empty, we take the template from model metadata
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
size_t alloc_size = 0;
// vector holding all allocated string to be passed to llama_chat_apply_template
std::vector<std::string> str(messages.size() * 2);
std::vector<llama_chat_message> chat(messages.size());
std::vector<llama_chat_msg> chat;

for (size_t i = 0; i < messages.size(); ++i) {
const auto & curr_msg = messages[i];
str[i*2 + 0] = json_value(curr_msg, "role", std::string(""));
str[i*2 + 1] = json_value(curr_msg, "content", std::string(""));
alloc_size += str[i*2 + 1].length();
chat[i].role = str[i*2 + 0].c_str();
chat[i].content = str[i*2 + 1].c_str();
std::string role = json_value(curr_msg, "role", std::string(""));
std::string content = json_value(curr_msg, "content", std::string(""));
chat.push_back({role, content});
}

const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
std::vector<char> buf(alloc_size * 2);

// run the first time to get the total output length
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());

// if it turns out that our buffer is too small, we resize it
if ((size_t) res > buf.size()) {
buf.resize(res);
res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
}

const std::string formatted_chat(buf.data(), res);

auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true);
LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});

return formatted_chat;
}
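In effect, the server-side helper now only converts the incoming JSON messages into a std::vector<llama_chat_msg> and delegates the two-pass sizing call to the common wrapper, which removes the manual alloc_size bookkeeping and the raw llama_chat_message pointer juggling that were here before.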

4 changes: 2 additions & 2 deletions llama.cpp
@@ -18589,10 +18589,10 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<|im_start|>assistant\n";
}
} else if (tmpl == "llama2" || tmpl.find("[INST]") != std::string::npos) {
} else if (tmpl == "llama2" || tmpl == "mistral" || tmpl.find("[INST]") != std::string::npos) {
// llama2 template and its variants
// [variant] support system message
bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos;
bool support_system_message = tmpl.find("<<SYS>>") != std::string::npos || tmpl == "mistral";
// [variant] space before + after response
bool space_around_response = tmpl.find("' ' + eos_token") != std::string::npos;
// [variant] add BOS inside history
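With this change, a template value of "mistral" (for example via --chat-template mistral) is routed to the llama2 family of templates and is treated as supporting a system message, in addition to the existing detection based on the [INST] substring.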
20 changes: 20 additions & 0 deletions tests/test-chat-template.cpp
@@ -7,6 +7,7 @@
#include <cassert>

#include "llama.h"
#include "common.h"

int main(void) {
llama_chat_message conversation[] = {
@@ -119,5 +120,24 @@ int main(void) {
std::cout << output << "\n-------------------------\n";
assert(output == expected);
}

// test llama_chat_format_single
std::cout << "\n\n=== llama_chat_format_single ===\n\n";
std::vector<llama_chat_msg> chat2;
chat2.push_back({"system", "You are a helpful assistant"});
chat2.push_back({"user", "Hello"});
chat2.push_back({"assistant", "I am assistant"});
llama_chat_msg new_msg{"user", "How are you"};

auto fmt_single = [&](std::string tmpl) {
auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true);
std::cout << "fmt_single(" << tmpl << ")\n" << output << "\n-------------------------\n";
return output;
};
assert(fmt_single("chatml") == "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
assert(fmt_single("llama2") == "[INST] How are you [/INST]");
assert(fmt_single("gemma") == "<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");

return 0;
}