Fix Chat Templates (#916)
* Update pyproject.toml

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update _utils.py

* Update _utils.py

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* fix_tokenizer

* Update tokenizer_utils.py

* Update tokenizer_utils.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update save.py

* Update loader.py

* Update pyproject.toml

* Update _utils.py

* Update gemma2.py

* Update gemma2.py

* Update _utils.py

* gemma 2 mask

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update _utils.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update _utils.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Update _utils.py

* Torch 2.4 Xformers 0.0.27post2

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Gemma 2 fixes

* Update gemma2.py

* Update llama.py

* Update llama.py

* Update save.py

* Update save.py

* Update llama.py

* Update cross_entropy_loss.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Update dpo.py

* Providing more flexibility for users to customize their llama when using LoRA (#910)

* Update llama.py

* Update llama.py

* Update llama.py

* Update llama.py

* Update chat_templates.py

* return model

* Update tokenizer_utils.py

* Update chat_templates.py

* Update tokenizer_utils.py

---------

Co-authored-by: Po-Lung Wang <[email protected]>
danielhanchen and Brownwang0426 committed Aug 14, 2024
1 parent 3781a03 commit a64b8f6
Showing 3 changed files with 256 additions and 11 deletions.
222 changes: 214 additions & 8 deletions unsloth/chat_templates.py
@@ -508,6 +508,200 @@
CHAT_TEMPLATES["phi-3"] = (phi3_template, phi3_template_eos_token, False, phi3_ollama,)
pass

# =========================================== Llama-3.1
"""
No trimming in Llama 3.1 Instruct!
There is also an extra newline after the Cutting Knowledge Date line.
See https://colab.research.google.com/drive/1Xpqq5xpIgO-B00MQ-UccYMwN2J8QFgBM?usp=sharing
The template should also be called with today's date, for example:
    import datetime
    tokenizer.apply_chat_template(
        messages,
        add_generation_prompt = True,
        tokenize = False,
        date_string = datetime.date.today().strftime("%d %B %Y"),
    )
"""

llama31_template = \
"""{{- bos_token }}
{%- if custom_tools is defined %}
{%- set tools = custom_tools %}
{%- endif %}
{%- if not tools_in_user_message is defined %}
{%- set tools_in_user_message = true %}
{%- endif %}
{%- if not date_string is defined %}
{%- set date_string = "26 July 2024" %}
{%- endif %}
{%- if not tools is defined %}
{%- set tools = none %}
{%- endif %}
{#- This block extracts the system message, so we can slot it into the right place. #}
{%- if messages[0]['role'] == 'system' %}
{%- set system_message = messages[0]['content'] %}
{%- set messages = messages[1:] %}
{%- else %}
{%- set system_message = "" %}
{%- endif %}
{#- System message + builtin tools #}
{{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
{%- if builtin_tools is defined or tools is not none %}
{{- "Environment: ipython\n" }}
{%- endif %}
{%- if builtin_tools is defined %}
{{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}}
{%- endif %}
{{- "Cutting Knowledge Date: December 2023\n" }}
{{- "Today Date: " + date_string + "\n\n" }}
{%- if tools is not none and not tools_in_user_message %}
{{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}
{{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
{{- "Do not use variables.\n\n" }}
{%- for t in tools %}
{{- t | tojson(indent=4) }}
{{- "\n\n" }}
{%- endfor %}
{%- endif %}
{{- system_message }}
{{- "<|eot_id|>" }}
{#- Custom tools are passed in a user message with some extra guidance #}
{%- if tools_in_user_message and not tools is none %}
{#- Extract the first user message so we can plug it in here #}
{%- if messages | length != 0 %}
{%- set first_user_message = messages[0]['content'] %}
{%- set messages = messages[1:] %}
{%- else %}
{{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
{%- endif %}
{{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
{{- "Given the following functions, please respond with a JSON for a function call " }}
{{- "with its proper arguments that best answers the given prompt.\n\n" }}
{{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
{{- "Do not use variables.\n\n" }}
{%- for t in tools %}
{{- t | tojson(indent=4) }}
{{- "\n\n" }}
{%- endfor %}
{{- first_user_message + "<|eot_id|>"}}
{%- endif %}
{%- for message in messages %}
{%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}
{%- elif 'tool_calls' in message %}
{%- if not message.tool_calls|length == 1 %}
{{- raise_exception("This model only supports single tool-calls at once!") }}
{%- endif %}
{%- set tool_call = message.tool_calls[0].function %}
{%- if builtin_tools is defined and tool_call.name in builtin_tools %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
{{- "<|python_tag|>" + tool_call.name + ".call(" }}
{%- for arg_name, arg_val in tool_call.arguments | items %}
{{- arg_name + '="' + arg_val + '"' }}
{%- if not loop.last %}
{{- ", " }}
{%- endif %}
{%- endfor %}
{{- ")" }}
{%- else %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
{{- '{"name": "' + tool_call.name + '", ' }}
{{- '"parameters": ' }}
{{- tool_call.arguments | tojson }}
{{- "}" }}
{%- endif %}
{%- if builtin_tools is defined %}
{#- This means we're in ipython mode #}
{{- "<|eom_id|>" }}
{%- else %}
{{- "<|eot_id|>" }}
{%- endif %}
{%- elif message.role == "tool" or message.role == "ipython" %}
{{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
{%- if message.content is mapping or message.content is iterable %}
{{- message.content | tojson }}
{%- else %}
{{- message.content }}
{%- endif %}
{{- "<|eot_id|>" }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}
"""
pass

# Ollama from https://ollama.com/library/llama3.1 (needs updating!)
llama31_ollama = \
'''
FROM {__FILE_LOCATION__}
TEMPLATE """{{ if .Messages }}
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
{{- if .System }}
{{ .System }}
{{- end }}
{{- if .Tools }}
You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the original user question.
{{- end }}
{{- end }}<|eot_id|>
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 }}
{{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>
{{- if and $.Tools $last }}
Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
{{ $.Tools }}
{{- end }}
{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
{{ end }}
{{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|>
{{- if .ToolCalls }}
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}
{{- else }}
{{ .Content }}{{ if not $last }}<|eot_id|>{{ end }}
{{- end }}
{{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|>
{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
{{ end }}
{{- end }}
{{- end }}
{{- else }}
{{- if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
{{ end }}{{ .Response }}{{ if .Response }}<|eot_id|>{{ end }}"""
PARAMETER stop "<|start_header_id|>"
PARAMETER stop "<|end_header_id|>"
PARAMETER stop "<|eot_id|>"
PARAMETER stop "<|eom_id|>"
'''

llama31_template_eos_token = "eos_token"
CHAT_TEMPLATES["llama-3.1"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,)
CHAT_TEMPLATES["llama-31"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,)
pass
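
For orientation (not part of the diff): once this block is in, the new entry can be selected through the get_chat_template helper defined just below. A minimal sketch, assuming a Llama 3.1 tokenizer has already been loaded as `tokenizer`:

# Sketch only - selects the newly registered "llama-3.1" template.
from unsloth.chat_templates import get_chat_template

tokenizer = get_chat_template(tokenizer, chat_template = "llama-3.1")
messages = [{"role": "user", "content": "Who are you?"}]
text = tokenizer.apply_chat_template(messages, add_generation_prompt = True, tokenize = False)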


def get_chat_template(
tokenizer,
@@ -680,21 +874,33 @@ def get_chat_template(
)
pass

# For ShareGPT role -> from and content -> value
chat_template = chat_template\
.replace("'role'", "'" + mapping["role"] + "'")\
.replace("'content'", "'" + mapping["content"] + "'")\
.replace("'user'", "'" + mapping["user"] + "'")\
.replace("'assistant'", "'" + mapping["assistant"] + "'")

# Careful on Gemma
# bos_token is a must or else losses become too high
if IS_GEMMA and not chat_template.startswith("{{ bos_token }}"):
chat_template = "{{ bos_token }}" + chat_template
pass

# For ShareGPT role -> from and content -> value
new_chat_template = chat_template\
.replace("'role'", "'" + mapping["role"] + "'")\
.replace("'content'", "'" + mapping["content"] + "'")\
.replace("'user'", "'" + mapping["user"] + "'")\
.replace("'assistant'", "'" + mapping["assistant"] + "'")

_, tokenizer = patch_tokenizer(model = None, tokenizer = tokenizer)
tokenizer.padding_side = old_padding_side
tokenizer.padding_side = old_padding_side

# If not normal HF, we add a check to make old templates work
if mapping != {"role" : "role", "content" : "content", "user" : "user", "assistant" : "assistant"}:
chat_template = \
"{% if 'role' in messages[0] %}" + \
chat_template + \
"{% else %}" + \
new_chat_template + \
"{% endif %}"
else:
chat_template = new_chat_template
pass
tokenizer.chat_template = chat_template
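
To make the effect of that wrapper concrete, a hedged illustration (the messages below are hypothetical and not part of this commit): with a ShareGPT mapping such as {"role": "from", "content": "value", "user": "human", "assistant": "gpt"}, the same patched tokenizer should render either message style, because the {% if 'role' in messages[0] %} branch keeps the original keys while the {% else %} branch uses the remapped ones.

# Illustrative only - both calls are expected to succeed on a tokenizer patched
# with a ShareGPT mapping, thanks to the conditional template above.
hf_style       = [{"role": "user",  "content": "Hello!"}]
sharegpt_style = [{"from": "human", "value": "Hello!"}]
a = tokenizer.apply_chat_template(hf_style, tokenize = False, add_generation_prompt = True)
b = tokenizer.apply_chat_template(sharegpt_style, tokenize = False, add_generation_prompt = True)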

# Also fix up other tokens
17 changes: 15 additions & 2 deletions unsloth/models/llama.py
@@ -1873,8 +1873,17 @@ def get_peft_model(
else: modules_to_save.append("embed_tokens")

else:
assert(module in accepted_modules)
final_modules.append(module)
try:
assert(module in accepted_modules)
final_modules.append(module)
except AssertionError as e:
final_modules.append(module)
print(
"Unsloth: You added custom modules, but Unsloth hasn't optimized for this.\n"\
"Beware - your finetuning might be noticeably slower!"
)
pass
pass
pass

# Check if we added new tokens!
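
The practical effect of the change above: unrecognized LoRA target modules no longer abort get_peft_model; they are kept, and a slowdown warning is printed instead. A rough sketch of a call that now goes through (the module name "my_custom_proj" is a hypothetical placeholder, not taken from this diff):

# Sketch only - "my_custom_proj" stands in for any module outside accepted_modules.
from unsloth import FastLanguageModel

model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",
                      "my_custom_proj"],  # previously an AssertionError, now a warning
    lora_alpha = 16,
)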
@@ -2253,6 +2262,8 @@ def for_inference(model):
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "left"
pass

return model
pass


@@ -2291,6 +2302,8 @@ def for_training(model, use_gradient_checkpointing = True):
if hasattr(internal_model, "_saved_temp_tokenizer"):
internal_model._saved_temp_tokenizer.padding_side = "right"
pass

return model
pass
pass
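
Because for_inference and for_training now hand the model back, the calls can be assigned or chained; a brief hedged sketch (model is assumed to be a FastLanguageModel loaded elsewhere):

# Sketch only - the return value makes one-line mode switching possible.
from unsloth import FastLanguageModel

model = FastLanguageModel.for_inference(model)  # left padding, generation mode
# ... generate ...
model = FastLanguageModel.for_training(model)   # right padding, back to finetuning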

28 changes: 27 additions & 1 deletion unsloth/tokenizer_utils.py
@@ -597,8 +597,34 @@ def fix_chat_template(tokenizer):
if chat_template is None: return None

### 1. Check if add_generation_prompt works
# Check for ShareGPT style first
is_sharegpt = None
try:
messages = [
{"role": "user", "content": "Who are you?"},
]
tokenizer.apply_chat_template(messages, add_generation_prompt = False, tokenize = False)
is_sharegpt = False
except:
try:
messages = [
{"from": "human", "value": "Who are you?"},
]
tokenizer.apply_chat_template(messages, add_generation_prompt = False, tokenize = False)
is_sharegpt = True
except:
is_sharegpt = None
pass
pass

# Not ShareGPT or HF style - just return
if is_sharegpt is None: return chat_template

# Tokenize
messages = [
{"role": "user", "content": "Who are you?"},
{"role": "user", "content": "Who are you?"} \
if not is_sharegpt else \
{"from": "human", "value": "Who are you?"}
]
no = tokenizer.apply_chat_template(messages, add_generation_prompt = False, tokenize = False)
yes = tokenizer.apply_chat_template(messages, add_generation_prompt = True, tokenize = False)
