Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add llama-3 chat template #102

Merged
merged 1 commit into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/chat_models/chat_bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ defmodule LangChain.ChatModels.ChatBumblebee do
# # more focused and deterministic.
# field :temperature, :float, default: 1.0

field :template_format, Ecto.Enum, values: [:inst, :im_start, :zephyr, :llama_2]
field :template_format, Ecto.Enum, values: [:inst, :im_start, :zephyr, :llama_2, :llama_3]

# The bumblebee model may compile differently based on the stream true/false
# option on the serving. Therefore, streaming should be enabled on the
Expand Down
45 changes: 44 additions & 1 deletion lib/utils/chat_templates.ex
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,21 @@ defmodule LangChain.Utils.ChatTemplates do
Note: The `:llama_2` format supports specific system messages. It is a
variation of the `:inst` format.

### `:llama_3`

```
<|begin_of_text|>
<|start_header_id|>system<|end_header_id|>

System message.<|eot_id|>
<|start_header_id|>user<|end_header_id|>

User message.<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>

Assistant message.<|eot_id|>
```

### `:zephyr`

```
Expand All @@ -69,7 +84,7 @@ defmodule LangChain.Utils.ChatTemplates do
alias LangChain.LangChainError
alias LangChain.Message

@type chat_format :: :inst | :im_start | :llama_2 | :zephyr
@type chat_format :: :inst | :im_start | :llama_2 | :llama_3 | :zephyr

# Option:
# - `add_generation_prompt`: boolean. Defaults to False.
Expand Down Expand Up @@ -275,6 +290,34 @@ defmodule LangChain.Utils.ChatTemplates do
EEx.eval_string(text, assigns: [system_text: system_text, first_user: first_user, rest: rest])
end

# Renders the message list using the LLaMa 3 prompt format.
#
# Output shape:
#
#   <|begin_of_text|>
#   <|start_header_id|>system<|end_header_id|>
#
#   You are a helpful assistant.<|eot_id|>
#   <|start_header_id|>user<|end_header_id|>
#
#   What do you know about elixir?<|eot_id|>
#   <|start_header_id|>assistant<|end_header_id|>
def apply_chat_template!(messages, :llama_3, opts) do
  # Unless the caller overrides it, the trailing assistant header is added
  # only when the conversation ends with a user prompt.
  generation_prompt? =
    Keyword.get(opts, :add_generation_prompt, default_add_generation_prompt_value(messages))

  {system, first_user, rest} = prep_and_validate_messages(messages)

  # A nil at the head means no system message was provided; drop it so the
  # template only iterates over real messages.
  rendered_messages = Enum.drop_while([system, first_user | rest], &(&1 == nil))

  # Kept as a single line for explicit control of newlines and spaces.
  template =
    "<|begin_of_text|>\n<%= for message <- @messages do %><|start_header_id|><%= message.role %><|end_header_id|>\n\n<%= message.content %><|eot_id|>\n<% end %><%= if @add_generation_prompt do %><|start_header_id|>assistant<|end_header_id|>\n\n<% end %>"

  EEx.eval_string(template,
    assigns: [
      messages: rendered_messages,
      add_generation_prompt: generation_prompt?
    ]
  )
end

# return the desired true/false value. Only set to true when the last message
# is a user prompt.
defp default_add_generation_prompt_value(messages) do
Expand Down
66 changes: 66 additions & 0 deletions test/utils/chat_templates_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -399,4 +399,70 @@ defmodule LangChain.Utils.ChatTemplatesTest do
assert result == expected
end
end

describe "apply_chat_template!/3 - :llama_3 format" do
  test "includes provided system message" do
    result =
      [Message.new_system!("system_message"), Message.new_user!("user_prompt")]
      |> ChatTemplates.apply_chat_template!(:llama_3)

    # Ends with an open assistant header: the default generation prompt.
    assert result ==
             "<|begin_of_text|>\n<|start_header_id|>system<|end_header_id|>\n\nsystem_message<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nuser_prompt<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n"
  end

  test "does not add generation prompt when set to false" do
    result =
      [Message.new_system!("system_message"), Message.new_user!("user_prompt")]
      |> ChatTemplates.apply_chat_template!(:llama_3, add_generation_prompt: false)

    # No trailing assistant header when explicitly disabled.
    assert result ==
             "<|begin_of_text|>\n<|start_header_id|>system<|end_header_id|>\n\nsystem_message<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nuser_prompt<|eot_id|>\n"
  end

  test "no system message when not provided" do
    result =
      [Message.new_user!("user_prompt")]
      |> ChatTemplates.apply_chat_template!(:llama_3)

    assert result ==
             "<|begin_of_text|>\n<|start_header_id|>user<|end_header_id|>\n\nuser_prompt<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n"
  end

  test "formats answered question correctly" do
    result =
      [
        Message.new_system!("system_message"),
        Message.new_user!("user_prompt"),
        Message.new_assistant!("assistant_response")
      ]
      |> ChatTemplates.apply_chat_template!(:llama_3)

    # Last message is from the assistant, so no generation prompt is appended.
    assert result ==
             "<|begin_of_text|>\n<|start_header_id|>system<|end_header_id|>\n\nsystem_message<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nuser_prompt<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\nassistant_response<|eot_id|>\n"
  end

  test "formats 2nd question correctly" do
    result =
      [
        Message.new_system!("system_message"),
        Message.new_user!("user_prompt"),
        Message.new_assistant!("assistant_response"),
        Message.new_user!("user_2nd")
      ]
      |> ChatTemplates.apply_chat_template!(:llama_3)

    # Follow-up user prompt means the generation prompt returns.
    assert result ==
             "<|begin_of_text|>\n<|start_header_id|>system<|end_header_id|>\n\nsystem_message<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nuser_prompt<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\nassistant_response<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nuser_2nd<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n"
  end
end
end
Loading