Skip to content

Commit

Permalink
Feature/update hunyuan (#25779)
Browse files Browse the repository at this point in the history
Description: 
    - Add system templates and user templates in integration testing
    - initialize the response id field value to request_id
    - Adjust the default model to hunyuan-pro
    - Remove the default values of Temperature and TopP
    - Add SystemMessage

all the integration tests have passed.
1、Execute integration tests for the first time
<img width="1359" alt="71ca77a2-e9be-4af6-acdc-4d665002bd9b"
src="https://github.com/user-attachments/assets/9298dc3a-aa26-4bfa-968b-c011a4e699c9">

2、Run the integration test a second time
<img width="1501" alt="image"
src="https://github.com/user-attachments/assets/61335416-4a67-4840-bb89-090ba668e237">

Issue: None
Dependencies: None
Twitter handle: None

---------

Co-authored-by: Chester Curme <[email protected]>
  • Loading branch information
xander-art and ccurme authored Sep 2, 2024
1 parent 566e9ba commit 6cd452d
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 7 deletions.
11 changes: 9 additions & 2 deletions libs/community/langchain_community/chat_models/hunyuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
ChatMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
Expand All @@ -33,6 +34,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"Role": message.role, "Content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"Role": "system", "Content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"Role": "user", "Content": message.content}
elif isinstance(message, AIMessage):
Expand All @@ -45,7 +48,9 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:

def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["Role"]
if role == "user":
if role == "system":
return SystemMessage(content=_dict.get("Content", "") or "")
elif role == "user":
return HumanMessage(content=_dict["Content"])
elif role == "assistant":
return AIMessage(content=_dict.get("Content", "") or "")
Expand Down Expand Up @@ -73,6 +78,7 @@ def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
generations = []
for choice in response["Choices"]:
message = _convert_dict_to_message(choice["Message"])
message.id = response.get("Id", "")
generations.append(ChatGeneration(message=message))

token_usage = response["Usage"]
Expand Down Expand Up @@ -115,7 +121,7 @@ def lc_serializable(self) -> bool:
model: str = "hunyuan-lite"
"""What Model to use.
Optional model:
- hunyuan-lite
- hunyuan-lite
- hunyuan-standard
- hunyuan-standard-256K
- hunyuan-pro
Expand Down Expand Up @@ -233,6 +239,7 @@ def _stream(
chunk = _convert_delta_to_message_chunk(
choice["Delta"], default_chunk_class
)
chunk.id = response.get("Id", "")
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
if run_manager:
Expand Down
37 changes: 37 additions & 0 deletions libs/community/tests/integration_tests/chat_models/test_hunyuan.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
import uuid
from operator import itemgetter
from typing import Any

import pytest
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables.base import RunnableSerializable

from langchain_community.chat_models.hunyuan import ChatHunyuan

Expand All @@ -11,6 +21,8 @@ def test_chat_hunyuan() -> None:
response = chat.invoke([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert response.id is not None, "request_id is empty"
assert uuid.UUID(response.id), "Invalid UUID"


@pytest.mark.requires("tencentcloud-sdk-python")
Expand All @@ -20,6 +32,8 @@ def test_chat_hunyuan_with_temperature() -> None:
response = chat.invoke([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert response.id is not None, "request_id is empty"
assert uuid.UUID(response.id), "Invalid UUID"


@pytest.mark.requires("tencentcloud-sdk-python")
Expand All @@ -29,6 +43,8 @@ def test_chat_hunyuan_with_model_name() -> None:
response = chat.invoke([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert response.id is not None, "request_id is empty"
assert uuid.UUID(response.id), "Invalid UUID"


@pytest.mark.requires("tencentcloud-sdk-python")
Expand All @@ -38,6 +54,27 @@ def test_chat_hunyuan_with_stream() -> None:
response = chat.invoke([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert response.id is not None, "request_id is empty"
assert uuid.UUID(response.id), "Invalid UUID"


@pytest.mark.requires("tencentcloud-sdk-python")
def test_chat_hunyuan_with_prompt_template() -> None:
system_prompt = SystemMessagePromptTemplate.from_template(
"You are a helpful assistant! Your name is {name}."
)
user_prompt = HumanMessagePromptTemplate.from_template("Question: {query}")
chat_prompt = ChatPromptTemplate.from_messages([system_prompt, user_prompt])
chat: RunnableSerializable[Any, Any] = (
{"query": itemgetter("query"), "name": itemgetter("name")}
| chat_prompt
| ChatHunyuan()
)
response = chat.invoke({"query": "Hello", "name": "Tom"})
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert response.id is not None, "request_id is empty"
assert uuid.UUID(response.id), "Invalid UUID"


def test_extra_kwargs() -> None:
Expand Down
17 changes: 12 additions & 5 deletions libs/community/tests/unit_tests/chat_models/test_hunyuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ def test__convert_message_to_dict_ai() -> None:

def test__convert_message_to_dict_system() -> None:
message = SystemMessage(content="foo")
with pytest.raises(TypeError) as e:
_convert_message_to_dict(message)
assert "Got unknown type" in str(e)
result = _convert_message_to_dict(message)
expected_output = {"Role": "system", "Content": "foo"}
assert result == expected_output


def test__convert_message_to_dict_function() -> None:
Expand All @@ -58,10 +58,17 @@ def test__convert_dict_to_message_ai() -> None:
assert result == expected_output


def test__convert_dict_to_message_other_role() -> None:
def test__convert_dict_to_message_system() -> None:
message_dict = {"Role": "system", "Content": "foo"}
result = _convert_dict_to_message(message_dict)
expected_output = ChatMessage(role="system", content="foo")
expected_output = SystemMessage(content="foo")
assert result == expected_output


def test__convert_dict_to_message_other_role() -> None:
message_dict = {"Role": "other", "Content": "foo"}
result = _convert_dict_to_message(message_dict)
expected_output = ChatMessage(role="other", content="foo")
assert result == expected_output


Expand Down

0 comments on commit 6cd452d

Please sign in to comment.