From a0bffea4f2ffd8d1d4f463c49cd1bdc271e6d958 Mon Sep 17 00:00:00 2001 From: Nathan Bowyer Date: Wed, 24 Apr 2024 19:57:02 -0500 Subject: [PATCH] Add llama-3 chat template --- lib/chat_models/chat_bumblebee.ex | 2 +- lib/utils/chat_templates.ex | 45 +++++++++++++++++++- test/utils/chat_templates_test.exs | 66 ++++++++++++++++++++++++++++++ 3 files changed, 111 insertions(+), 2 deletions(-) diff --git a/lib/chat_models/chat_bumblebee.ex b/lib/chat_models/chat_bumblebee.ex index 8b45925..f2454f7 100644 --- a/lib/chat_models/chat_bumblebee.ex +++ b/lib/chat_models/chat_bumblebee.ex @@ -119,7 +119,7 @@ defmodule LangChain.ChatModels.ChatBumblebee do # # more focused and deterministic. # field :temperature, :float, default: 1.0 - field :template_format, Ecto.Enum, values: [:inst, :im_start, :zephyr, :llama_2] + field :template_format, Ecto.Enum, values: [:inst, :im_start, :zephyr, :llama_2, :llama_3] # The bumblebee model may compile differently based on the stream true/false # option on the serving. Therefore, streaming should be enabled on the diff --git a/lib/utils/chat_templates.ex b/lib/utils/chat_templates.ex index e65c191..3d220f3 100644 --- a/lib/utils/chat_templates.ex +++ b/lib/utils/chat_templates.ex @@ -49,6 +49,21 @@ defmodule LangChain.Utils.ChatTemplates do Note: The `:llama_2` format supports specific system messages. It is a variation of the `:inst` format. + ### `:llama_3` + + ``` + <|begin_of_text|> + <|start_header_id|>system<|end_header_id|> + + System message.<|eot_id|> + <|start_header_id|>user<|end_header_id|> + + User message.<|eot_id|> + <|start_header_id|>assistant<|end_header_id|> + + Assistant message.<|eot_id|> + ``` + ### `:zephyr` ``` @@ -69,7 +84,7 @@ defmodule LangChain.Utils.ChatTemplates do alias LangChain.LangChainError alias LangChain.Message - @type chat_format :: :inst | :im_start | :llama_2 | :zephyr + @type chat_format :: :inst | :im_start | :llama_2 | :llama_3 | :zephyr # Option: # - `add_generation_prompt`: boolean. Defaults to False. @@ -275,6 +290,34 @@ defmodule LangChain.Utils.ChatTemplates do EEx.eval_string(text, assigns: [system_text: system_text, first_user: first_user, rest: rest]) end + # Does LLaMa 3 formatted text + def apply_chat_template!(messages, :llama_3, opts) do + # <|begin_of_text|> + # <|start_header_id|>system<|end_header_id|> + # + # You are a helpful assistant.<|eot_id|> + # <|start_header_id|>user<|end_header_id|> + # + # What do you know about elixir?<|eot_id|> + # <|start_header_id|>assistant<|end_header_id|> + + add_generation_prompt = + Keyword.get(opts, :add_generation_prompt, default_add_generation_prompt_value(messages)) + + {system, first_user, rest} = prep_and_validate_messages(messages) + + # intentionally as a single line for explicit control of newlines and spaces. + text = + "<|begin_of_text|>\n<%= for message <- @messages do %><|start_header_id|><%= message.role %><|end_header_id|>\n\n<%= message.content %><|eot_id|>\n<% end %><%= if @add_generation_prompt do %><|start_header_id|>assistant<|end_header_id|>\n\n<% end %>" + + EEx.eval_string(text, + assigns: [ + messages: [system, first_user | rest] |> Enum.drop_while(&(&1 == nil)), + add_generation_prompt: add_generation_prompt + ] + ) + end + # return the desired true/false value. Only set to true when the last message # is a user prompt. defp default_add_generation_prompt_value(messages) do diff --git a/test/utils/chat_templates_test.exs b/test/utils/chat_templates_test.exs index e323127..35b2c63 100644 --- a/test/utils/chat_templates_test.exs +++ b/test/utils/chat_templates_test.exs @@ -399,4 +399,70 @@ defmodule LangChain.Utils.ChatTemplatesTest do assert result == expected end end + + describe "apply_chat_template!/3 - :llama_3 format" do + test "includes provided system message" do + messages = [ + Message.new_system!("system_message"), + Message.new_user!("user_prompt") + ] + + expected = "<|begin_of_text|>\n<|start_header_id|>system<|end_header_id|>\n\nsystem_message<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nuser_prompt<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n" + + result = ChatTemplates.apply_chat_template!(messages, :llama_3) + assert result == expected + end + + test "does not add generation prompt when set to false" do + messages = [ + Message.new_system!("system_message"), + Message.new_user!("user_prompt") + ] + + expected = "<|begin_of_text|>\n<|start_header_id|>system<|end_header_id|>\n\nsystem_message<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nuser_prompt<|eot_id|>\n" + + result = + ChatTemplates.apply_chat_template!(messages, :llama_3, add_generation_prompt: false) + + assert result == expected + end + + test "no system message when not provided" do + messages = [Message.new_user!("user_prompt")] + + expected = "<|begin_of_text|>\n<|start_header_id|>user<|end_header_id|>\n\nuser_prompt<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n" + + result = ChatTemplates.apply_chat_template!(messages, :llama_3) + assert result == expected + end + + test "formats answered question correctly" do + messages = [ + Message.new_system!("system_message"), + Message.new_user!("user_prompt"), + Message.new_assistant!("assistant_response") + ] + + expected = + "<|begin_of_text|>\n<|start_header_id|>system<|end_header_id|>\n\nsystem_message<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nuser_prompt<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\nassistant_response<|eot_id|>\n" + + result = ChatTemplates.apply_chat_template!(messages, :llama_3) + assert result == expected + end + + test "formats 2nd question correctly" do + messages = [ + Message.new_system!("system_message"), + Message.new_user!("user_prompt"), + Message.new_assistant!("assistant_response"), + Message.new_user!("user_2nd") + ] + + expected = + "<|begin_of_text|>\n<|start_header_id|>system<|end_header_id|>\n\nsystem_message<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nuser_prompt<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\nassistant_response<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nuser_2nd<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n" + + result = ChatTemplates.apply_chat_template!(messages, :llama_3) + assert result == expected + end + end end