diff --git a/common/common.cpp b/common/common.cpp
index cf69535e2d1f5..8c11863cde5eb 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -2947,3 +2947,46 @@ llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos) {
 
     return result;
 }
+
+std::string chat_get_added_part(const std::vector<llama_chat_msg> & messages, const std::string & tmpl) {
+    auto apply_chat_template = [&tmpl](const std::vector<llama_chat_msg> & msgs, size_t delta, bool add_ass) {
+        std::vector<llama_chat_message> chat(msgs.size());
+        size_t alloc_size = 0;
+        size_t chat_size = chat.size() - delta;
+        for (size_t i = 0; i < msgs.size(); ++i) {
+            chat[i].role = msgs[i].role.c_str();
+            chat[i].content = msgs[i].content.c_str();
+            alloc_size += msgs[i].role.size() + msgs[i].content.size();
+        }
+
+        const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
+        std::vector<char> buf(alloc_size * 2);
+
+        // run the first time to get the total output length
+        int32_t res = llama_chat_apply_template(nullptr, ptr_tmpl, chat.data(), chat_size, add_ass, buf.data(), buf.size());
+
+        // if it turns out that our buffer is too small, we resize it
+        if ((size_t) res > buf.size()) {
+            buf.resize(res);
+            res = llama_chat_apply_template(nullptr, ptr_tmpl, chat.data(), chat_size, add_ass, buf.data(), buf.size());
+        }
+
+        const std::string formatted_chat(buf.data(), res);
+        return formatted_chat;
+    };
+
+    std::string formatted_chat_last = messages.size() > 0
+        ? apply_chat_template(messages, 1, false) // (n_msgs - 1) messages
+        : "";
+    std::string formatted_chat_curr = apply_chat_template(messages, 0, true);
+
+    // Extract the added part (user prompt)
+    auto get_diff_part = [](const std::string & str1, const std::string & str2) {
+        size_t i = 0;
+        while (i < str1.size() && i < str2.size() && str1[i] == str2[i])
+            ++i;
+        return str2.substr(i);
+    };
+    return get_diff_part(formatted_chat_last, formatted_chat_curr);
+}
diff --git a/common/common.h b/common/common.h
index cca44268e6df5..f31994f1a61ca 100644
--- a/common/common.h
+++ b/common/common.h
@@ -322,3 +322,13 @@ llama_control_vector_data llama_control_vector_load(const std::vector<llama_control_vector_load_info> & load_infos);
+
+//
+// Chat template utils
+//
+
+// same as llama_chat_message, but uses std::string
+struct llama_chat_msg {
+    std::string role;
+    std::string content;
+};
+
+// Apply the chat template to the messages and return only the part added by the last message
+std::string chat_get_added_part(const std::vector<llama_chat_msg> & messages, const std::string & tmpl);
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index 522cc7d0d9e84..728e65621a44a 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -7,6 +7,7 @@
 #include <cassert>
 
 #include "llama.h"
+#include "common.h"
 
 int main(void) {
     llama_chat_message conversation[] = {
@@ -96,8 +97,19 @@ int main(void) {
         );
         formatted_chat.resize(res);
         std::string output(formatted_chat.data(), formatted_chat.size());
-        std::cout << output << "\n-------------------------\n";
+        std::cout << output << "\n-----\n";
         assert(output == expected);
+
+        std::vector<llama_chat_msg> v_messages;
+        for (size_t i = 0; i < message_count; ++i) {
+            v_messages.push_back({
+                conversation[i].role,
+                conversation[i].content,
+            });
+        }
+        std::cout << "chat_get_added_part(): " << chat_get_added_part(v_messages, custom_template);
+        std::cout << "\n-------------------------\n";
+        // TODO: chat_get_added_part is currently printed for debugging. Should we add tests for it in the future?
     }
     return 0;
 }
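
For context, here is a minimal standalone sketch of the idea behind `chat_get_added_part()`: format the conversation once without the newest message and once with it (plus the assistant prefix), then return only the suffix that the newest message added. The `msg` struct and the `toy_template()` helper are hypothetical stand-ins invented for this sketch so it compiles on its own; the patch itself goes through `llama_chat_apply_template()` instead.

```cpp
// Standalone sketch (not part of the patch) of the prefix-diff trick used by
// chat_get_added_part(). toy_template() is a hypothetical stand-in for
// llama_chat_apply_template().
#include <iostream>
#include <string>
#include <vector>

struct msg { std::string role; std::string content; };

// hypothetical chat template: <|role|>content per message, optional assistant prefix
static std::string toy_template(const std::vector<msg> & msgs, size_t n, bool add_ass) {
    std::string out;
    for (size_t i = 0; i < n; ++i) {
        out += "<|" + msgs[i].role + "|>" + msgs[i].content + "\n";
    }
    if (add_ass) {
        out += "<|assistant|>";
    }
    return out;
}

int main() {
    std::vector<msg> msgs = {
        {"system", "You are a helpful assistant."},
        {"user",   "Hello!"},
    };

    // format (n - 1) messages without the assistant prefix, then all n messages with it
    std::string prev = toy_template(msgs, msgs.size() - 1, false);
    std::string curr = toy_template(msgs, msgs.size(),     true);

    // the added part is the suffix of `curr` after the longest common prefix
    size_t i = 0;
    while (i < prev.size() && i < curr.size() && prev[i] == curr[i]) {
        ++i;
    }
    std::cout << curr.substr(i) << std::endl;
    // prints: <|user|>Hello!
    //         <|assistant|>
    return 0;
}
```

The practical point of returning only the added part is that an interactive caller can tokenize and evaluate just the newly formatted text for the latest message instead of re-formatting and re-processing the whole history on every turn.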