Use llama_chat_apply_template in main (WIP) #6810

Draft · wants to merge 1 commit into base: master
43 changes: 43 additions & 0 deletions common/common.cpp
@@ -2947,3 +2947,46 @@ llama_control_vector_data llama_control_vector_load(const std::vector<llama_cont

    return result;
}

// apply chat template for (n_msgs - 1) and (n_msgs), then get the added part
std::string chat_get_added_part(const std::vector<chat_message> & messages, const std::string & tmpl) {
    auto apply_chat_template = [&tmpl](const std::vector<chat_message> & msgs, size_t delta, bool add_ass) {
        std::vector<llama_chat_message> chat(msgs.size());
        size_t alloc_size = 0;
        size_t chat_size = chat.size() - delta;
        for (size_t i = 0; i < msgs.size(); ++i) {
            chat[i].role = msgs[i].role.c_str();
            chat[i].content = msgs[i].content.c_str();
            alloc_size += msgs[i].role.size() + msgs[i].content.size();
        }

        const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
        std::vector<char> buf(alloc_size * 2);

        // run the first time to get the total output length
        int32_t res = llama_chat_apply_template(nullptr, ptr_tmpl, chat.data(), chat_size, add_ass, buf.data(), buf.size());

        // if it turns out that our buffer is too small, we resize it
        if ((size_t) res > buf.size()) {
            buf.resize(res);
            res = llama_chat_apply_template(nullptr, ptr_tmpl, chat.data(), chat_size, add_ass, buf.data(), buf.size());
        }

        const std::string formatted_chat(buf.data(), res);
        return formatted_chat;
    };

    std::string formatted_chat_last = messages.size() > 0
        ? apply_chat_template(messages, 1, false) // (n_msgs - 1) messages
        : "";
    std::string formatted_chat_curr = apply_chat_template(messages, 0, true);

    // Extract the added part (user prompt)
    auto get_diff_part = [](const std::string & str1, const std::string & str2) {
        size_t i = 0;
        while (i < str1.size() && i < str2.size() && str1[i] == str2[i])
            ++i;
        return str2.substr(i);
    };
    return get_diff_part(formatted_chat_last, formatted_chat_curr);
}
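For illustration, assuming a ChatML-style template that renders each message as <|im_start|>role\ncontent<|im_end|>\n and appends <|im_start|>assistant\n when add_ass is set: given the history user: "Hello", assistant: "Hi there" and a new user message "Who are you", the (n_msgs - 1) pass produces

<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n

while the full pass with add_ass produces the same text followed by

<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n

so chat_get_added_part() returns only that trailing piece, which is exactly the text a caller would need to tokenize and append for the new turn.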
10 changes: 10 additions & 0 deletions common/common.h
@@ -322,3 +322,13 @@ llama_control_vector_data llama_control_vector_load(const std::vector<llama_cont
static const char * const LLM_KV_SPLIT_NO = "split.no";
static const char * const LLM_KV_SPLIT_COUNT = "split.count";
static const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";

//
// Chat templates utils
//
typedef struct chat_message {
    std::string role;
    std::string content;
} chat_message;

std::string chat_get_added_part(const std::vector<chat_message> & messages, const std::string & tmpl);
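Not part of this diff, but a rough sketch of how the helper could be wired into an interactive loop in examples/main. The names chat_history, user_input and params.chat_template are illustrative assumptions, and the tokenize-and-append step mirrors what main already does with embd_inp:

// illustrative sketch only: keep the full conversation, let the helper compute the delta
std::vector<chat_message> chat_history;           // maintained by the caller across turns
chat_history.push_back({"user", user_input});     // user_input: text just typed by the user

// template-format only what this turn adds (new user message + assistant prefix)
// params.chat_template: hypothetical option holding the template ("" = model default)
std::string added = chat_get_added_part(chat_history, params.chat_template);

// tokenize the added text (no BOS, parse special tokens) and append it to the running input
std::vector<llama_token> added_tokens = ::llama_tokenize(ctx, added, false, true);
embd_inp.insert(embd_inp.end(), added_tokens.begin(), added_tokens.end());

After generation, the assistant reply would be pushed back into chat_history so that the next call can diff against it.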
14 changes: 13 additions & 1 deletion tests/test-chat-template.cpp
@@ -7,6 +7,7 @@
#include <cassert>

#include "llama.h"
#include "common.h"

int main(void) {
    llama_chat_message conversation[] = {
@@ -96,8 +97,19 @@ int main(void) {
        );
        formatted_chat.resize(res);
        std::string output(formatted_chat.data(), formatted_chat.size());
-        std::cout << output << "\n-------------------------\n";
+        std::cout << output << "\n-----\n";
        assert(output == expected);

        std::vector<chat_message> v_messages;
        for (size_t i = 0; i < message_count; ++i) {
            v_messages.push_back({
                conversation[i].role,
                conversation[i].content,
            });
        }
        std::cout << "chat_get_added_part(): " << chat_get_added_part(v_messages, custom_template);
        std::cout << "\n-------------------------\n";
        // TODO: chat_get_added_part is currently printed for debugging. Should we add tests for it in the future?
    }
    return 0;
}
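One possible shape for the TODO above, only a sketch and not part of this diff: since get_diff_part() returns a suffix of the fully formatted conversation, and assuming output in this test is rendered with add_ass = true, the result could be checked as a suffix of output:

        // sketch: the added part must be a suffix of the full formatting (add_ass = true)
        std::string added = chat_get_added_part(v_messages, custom_template);
        assert(added.size() <= output.size());
        assert(output.compare(output.size() - added.size(), added.size(), added) == 0);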