Add chatml fallback for cpp llama_chat_apply_template (#8160)
* add chatml fallback for cpp `llama_chat_apply_template`

* remove redundant code
ngxson authored Jun 27, 2024
1 parent ab36791 commit 16791b8
Showing 2 changed files with 20 additions and 1 deletion.
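For reference, the chatml fallback introduced here formats the conversation with ChatML role markers. A rough sketch of the output for a short chat (the exact whitespace depends on llama.cpp's built-in `chatml` template, and the trailing assistant header is only appended when `add_ass` is true):

<|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant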
19 changes: 18 additions & 1 deletion common/common.cpp
@@ -2618,6 +2618,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
         const std::vector<llama_chat_msg> & msgs,
         bool add_ass) {
     int alloc_size = 0;
+    bool fallback = false; // indicate if we must fallback to default chatml
     std::vector<llama_chat_message> chat;
     for (auto & msg : msgs) {
         chat.push_back({msg.role.c_str(), msg.content.c_str()});
@@ -2630,10 +2631,26 @@ std::string llama_chat_apply_template(const struct llama_model * model,
     // run the first time to get the total output length
     int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
 
+    // error: chat template is not supported
+    if (res < 0) {
+        if (ptr_tmpl != nullptr) {
+            // if the custom "tmpl" is not supported, we throw an error
+            // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
+            throw std::runtime_error("this custom template is not supported");
+        } else {
+            // If the built-in template is not supported, we default to chatml
+            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+            fallback = true;
+        }
+    }
+
     // if it turns out that our buffer is too small, we resize it
     if ((size_t) res > buf.size()) {
         buf.resize(res);
-        res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+        res = llama_chat_apply_template(
+            fallback ? nullptr : model,
+            fallback ? "chatml" : ptr_tmpl,
+            chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }
 
     std::string formatted_chat(buf.data(), res);
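A minimal usage sketch of the updated wrapper (hypothetical caller code, not part of this commit; it assumes, as the nullptr check above suggests, that an empty `tmpl` selects the model's built-in template):

#include "common.h"

// Hypothetical caller: format a conversation using the model's built-in chat
// template. With this change, an unsupported built-in template falls back to
// chatml formatting instead of failing.
static std::string build_prompt(const struct llama_model * model) {
    std::vector<llama_chat_msg> msgs = {
        {"system", "You are a helpful assistant"},
        {"user",   "Hello"},
    };
    // empty tmpl -> use the template stored in the model (chatml as last resort)
    return llama_chat_apply_template(model, "", msgs, /* add_ass = */ true);
}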
2 changes: 2 additions & 0 deletions common/common.h
@@ -380,6 +380,8 @@ struct llama_chat_msg {
 bool llama_chat_verify_template(const std::string & tmpl);
 
 // CPP wrapper for llama_chat_apply_template
+// If the built-in template is not supported, we default to chatml
+// If the custom "tmpl" is not supported, we throw an error
 std::string llama_chat_apply_template(const struct llama_model * model,
         const std::string & tmpl,
         const std::vector<llama_chat_msg> & chat,
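The custom-template path still throws, so a caller that accepts a user-supplied template string can validate it up front with `llama_chat_verify_template`. A sketch under the same assumptions (hypothetical helper, not part of this commit):

#include <cstdio>

#include "common.h"

// Hypothetical: drop back to the model's built-in template (and ultimately to
// chatml) when a user-supplied template string is not supported.
static std::string apply_user_template(const struct llama_model * model,
                                       std::string user_tmpl,
                                       const std::vector<llama_chat_msg> & msgs) {
    if (!user_tmpl.empty() && !llama_chat_verify_template(user_tmpl)) {
        fprintf(stderr, "unsupported custom chat template, using model default\n");
        user_tmpl.clear(); // empty tmpl -> built-in template, chatml as fallback
    }
    return llama_chat_apply_template(model, user_tmpl, msgs, true);
}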
