Improve Support for Mistral-Instruct #2547

Merged · 2 commits · Oct 12, 2023
31 changes: 24 additions & 7 deletions fastchat/conversation.py
@@ -28,6 +28,7 @@ class SeparatorStyle(IntEnum):
PHOENIX = auto()
ROBIN = auto()
FALCON_CHAT = auto()
MISTRAL_INSTRUCT = auto()


@dataclasses.dataclass
@@ -212,6 +213,17 @@ def get_prompt(self) -> str:
ret += role + ":"

return ret
elif self.sep_style == SeparatorStyle.MISTRAL_INSTRUCT:
ret = self.sep
for i, (role, message) in enumerate(self.messages):
if role == "user":
if self.system_message and i == 0:
ret += "[INST] " + system_prompt + " " + message + " [/INST]"
else:
ret += "[INST] " + message + " [/INST]"
elif role == "assistant" and message:
ret += message + self.sep2 + " "
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
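
For reference, here is a minimal standalone sketch (illustrative only, not part of the diff) of the string the new MISTRAL_INSTRUCT branch assembles, assuming the sep="<s>" and sep2="</s>" values registered below and a plain system message:

# Standalone sketch of the MISTRAL_INSTRUCT prompt assembly shown above.
def render_mistral_instruct(messages, system_message="", sep="<s>", sep2="</s>"):
    ret = sep
    for i, (role, message) in enumerate(messages):
        if role == "user":
            # The system message is folded into the first user turn only.
            if system_message and i == 0:
                ret += "[INST] " + system_message + " " + message + " [/INST]"
            else:
                ret += "[INST] " + message + " [/INST]"
        elif role == "assistant" and message:
            ret += message + sep2 + " "
    return ret

print(render_mistral_instruct(
    [("user", "Hello"), ("assistant", "Hi there!"), ("user", "How are you?")],
    system_message="Always assist with care.",
))
# -> <s>[INST] Always assist with care. Hello [/INST]Hi there!</s> [INST] How are you? [/INST]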

@@ -840,16 +852,21 @@ def get_conv_template(name: str) -> Conversation:
)
)

# Mistral template
# Mistral instruct template
# source: https://docs.mistral.ai/llm/mistral-instruct-v0.1#chat-template
# https://docs.mistral.ai/usage/guardrailing/
# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/blob/main/tokenizer_config.json
register_conv_template(
Conversation(
name="mistral",
system_template="",
roles=("[INST] ", " [/INST]"),
sep_style=SeparatorStyle.LLAMA2,
sep="",
sep2=" </s>",
name="mistral-instruct",
system_message="Always assist with care, respect, and truth. "
"Respond with utmost utility yet securely. "
"Avoid harmful, unethical, prejudiced, or negative content. "
"Ensure replies promote fairness and positivity.",
roles=("user", "assistant"),
sep_style=SeparatorStyle.MISTRAL_INSTRUCT,
sep="<s>",
sep2="</s>",
)
)
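
A short usage sketch (assumed usage, not part of the PR) of fetching the newly registered template through FastChat's existing Conversation API; the rendered prompt follows the format in the Mistral links above:

from fastchat.conversation import get_conv_template

conv = get_conv_template("mistral-instruct")
conv.append_message(conv.roles[0], "Hello")  # user turn
conv.append_message(conv.roles[1], None)     # empty slot for the model's reply
print(conv.get_prompt())
# Expected shape: <s>[INST] <default system message> Hello [/INST]
# where <default system message> is the guardrail text registered above.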

10 changes: 5 additions & 5 deletions fastchat/model/model_adapter.py
@@ -1283,11 +1283,11 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("starchat")


class MistralAdapter(BaseModelAdapter):
"""The model adapter for Mistral AI models"""
class MistralInstructAdapter(BaseModelAdapter):
"""The model adapter for Mistral Instruct AI models"""

def match(self, model_path: str):
return "mistral" in model_path.lower()
return "mistral" in model_path.lower() and "instruct" in model_path.lower()

def load_model(self, model_path: str, from_pretrained_kwargs: dict):
model, tokenizer = super().load_model(model_path, from_pretrained_kwargs)
@@ -1296,7 +1296,7 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
return model, tokenizer

def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("mistral")
return get_conv_template("mistral-instruct")


class Llama2Adapter(BaseModelAdapter):
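
A quick illustration (assumed behaviour, not part of the diff) of how the stricter match() predicate above routes model paths: only paths containing both "mistral" and "instruct" pick up the new template, while base Mistral models fall through to other adapters.

def matches_mistral_instruct(model_path: str) -> bool:
    # Mirrors the match() check in MistralInstructAdapter above.
    path = model_path.lower()
    return "mistral" in path and "instruct" in path

assert matches_mistral_instruct("mistralai/Mistral-7B-Instruct-v0.1") is True
assert matches_mistral_instruct("mistralai/Mistral-7B-v0.1") is False  # base model: no longer matched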
@@ -1716,7 +1716,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
register_model_adapter(InternLMChatAdapter)
register_model_adapter(StarChatAdapter)
register_model_adapter(Llama2Adapter)
register_model_adapter(MistralAdapter)
register_model_adapter(MistralInstructAdapter)
register_model_adapter(CuteGPTAdapter)
register_model_adapter(OpenOrcaAdapter)
register_model_adapter(WizardCoderAdapter)
2 changes: 1 addition & 1 deletion fastchat/model/model_registry.py
@@ -308,7 +308,7 @@ def get_model_info(name: str) -> ModelInfo:
)
register_model_info(
["mistral-7b-instruct"],
"Mistral",
"Mistral-Instruct",
"https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1",
"a large language model by Mistral AI team",
)