Skip to content

Commit

Permalink
fix(LLM): mistral ignoring assistant messages (#1954)
Browse files Browse the repository at this point in the history
* fix: mistral ignoring assistant messages

* fix: typing

* fix: fix tests
  • Loading branch information
pabloogc authored May 30, 2024
1 parent 3b3e96a commit c7212ac
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 19 deletions.
26 changes: 15 additions & 11 deletions private_gpt/components/llm/prompt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,18 +173,22 @@ def _completion_to_prompt(self, completion: str) -> str:

class MistralPromptStyle(AbstractPromptStyle):
# Diff hunk for MistralPromptStyle._messages_to_prompt. The scrape dropped the
# +/- markers and the indentation, so removed and added lines are interleaved.
# NOTE(review): the old/new attribution in the comments below is inferred from
# the "15 additions & 11 deletions" header and from which accumulator
# (`prompt` vs. `text`/`inst_buffer`) each line touches - confirm on the commit.
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
# (old) Single accumulator, opened with the "<s>" marker.
prompt = "<s>"
# (new) `inst_buffer` holds system/user texts not yet emitted; `text` is the
# prompt being built.
inst_buffer = []
text = ""
for message in messages:
# (old) Loop body: only the "system" and "user" roles have a branch and there
# is no else, so a message with any other role adds nothing to the prompt.
role = message.role
content = message.content or ""
if role.lower() == "system":
message_from_user = f"[INST] {content.strip()} [/INST]"
prompt += message_from_user
elif role.lower() == "user":
# (old) "</s>" is appended before each user message's [INST] block.
prompt += "</s>"
message_from_user = f"[INST] {content.strip()} [/INST]"
prompt += message_from_user
return prompt
# (new) Loop body: system and user texts are stripped and buffered.
# NOTE(review): unlike the old `message.content or ""`, `str(message.content)`
# turns a None content into the literal text "None" - confirm whether content
# can be None here.
if message.role == MessageRole.SYSTEM or message.role == MessageRole.USER:
inst_buffer.append(str(message.content).strip())
elif message.role == MessageRole.ASSISTANT:
# (new) An assistant message flushes the buffer: the buffered texts are joined
# with "\n" inside "<s>[INST] ... [/INST]", then the stripped assistant text
# follows after a space and is closed with "</s>".
text += "<s>[INST] " + "\n".join(inst_buffer) + " [/INST]"
text += " " + str(message.content).strip() + "</s>"
inst_buffer.clear()
else:
# (new) Any other role is rejected instead of being silently skipped.
raise ValueError(f"Unknown message role {message.role}")

# (new) Texts still buffered after the loop are emitted as a final
# "<s>[INST] ... [/INST]" block with no assistant text after it.
if len(inst_buffer) > 0:
text += "<s>[INST] " + "\n".join(inst_buffer) + " [/INST]"

return text

def _completion_to_prompt(self, completion: str) -> str:
return self._messages_to_prompt(
Expand Down
20 changes: 12 additions & 8 deletions tests/test_prompt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,21 @@ def test_tag_prompt_style_format_with_system_prompt():
# Diff hunk for test_mistral_prompt_style_format. The scrape dropped the +/-
# markers and the indentation, so removed and added lines are interleaved.
# NOTE(review): the old/new attribution in the comments below is inferred from
# the "12 additions & 8 deletions" header - confirm on the commit.
def test_mistral_prompt_style_format():
prompt_style = MistralPromptStyle()
messages = [
# (old) Fixture: one system message and one user message with full sentences.
ChatMessage(content="You are an AI assistant.", role=MessageRole.SYSTEM),
ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
# (new) Fixture: the same two roles with single-letter contents.
ChatMessage(content="A", role=MessageRole.SYSTEM),
ChatMessage(content="B", role=MessageRole.USER),
]

# (old) Expectation: system and user each get their own [INST] block, with
# "</s>" between the two blocks.
expected_prompt = (
"<s>[INST] You are an AI assistant. [/INST]</s>"
"[INST] Hello, how are you doing? [/INST]"
)

# (new) Expectation: system and user texts share one [INST] block, joined by a
# newline.
expected_prompt = "<s>[INST] A\nB [/INST]"
assert prompt_style.messages_to_prompt(messages) == expected_prompt

# (new) Second case: a conversation that includes an assistant turn followed
# by another user message.
messages2 = [
ChatMessage(content="A", role=MessageRole.SYSTEM),
ChatMessage(content="B", role=MessageRole.USER),
ChatMessage(content="C", role=MessageRole.ASSISTANT),
ChatMessage(content="D", role=MessageRole.USER),
]
# The assistant text "C" follows the first [INST] block and is closed with
# "</s>"; the later user message "D" opens a new "<s>[INST] ... [/INST]" block.
expected_prompt2 = "<s>[INST] A\nB [/INST] C</s><s>[INST] D [/INST]"
assert prompt_style.messages_to_prompt(messages2) == expected_prompt2


def test_chatml_prompt_style_format():
prompt_style = ChatMLPromptStyle()
Expand Down

0 comments on commit c7212ac

Please sign in to comment.