Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/update hunyuan #25779

Merged
merged 12 commits into from
Sep 2, 2024
18 changes: 11 additions & 7 deletions libs/community/langchain_community/chat_models/hunyuan.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,41 @@
import json
import logging
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import (
BaseChatModel,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
SystemMessage,
HumanMessage,
HumanMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
pre_init,
)

logger = logging.getLogger(__name__)

Check failure on line 30 in libs/community/langchain_community/chat_models/hunyuan.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.12

Ruff (I001)

langchain_community/chat_models/hunyuan.py:1:1: I001 Import block is un-sorted or un-formatted

Check failure on line 30 in libs/community/langchain_community/chat_models/hunyuan.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.8

Ruff (I001)

langchain_community/chat_models/hunyuan.py:1:1: I001 Import block is un-sorted or un-formatted


def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict: Dict[str, Any]
if isinstance(message, ChatMessage):
message_dict = {"Role": message.role, "Content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"Role": "system", "Content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"Role": "user", "Content": message.content}
elif isinstance(message, AIMessage):
Expand All @@ -45,7 +48,9 @@

def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["Role"]
if role == "user":
if role == "system":
return SystemMessage(content=_dict.get("Content", "") or "")
elif role == "user":
return HumanMessage(content=_dict["Content"])
elif role == "assistant":
return AIMessage(content=_dict.get("Content", "") or "")
Expand Down Expand Up @@ -73,10 +78,10 @@
generations = []
for choice in response["Choices"]:
message = _convert_dict_to_message(choice["Message"])
message.id = response.get('Id','')
generations.append(ChatGeneration(message=message))

token_usage = response["Usage"]
llm_output = {"token_usage": token_usage}
llm_output = {"token_usage": response.get('Usage','')}
return ChatResult(generations=generations, llm_output=llm_output)


Expand Down Expand Up @@ -112,10 +117,10 @@
"""What sampling temperature to use."""
top_p: float = 1.0
"""What probability mass to use."""
model: str = "hunyuan-lite"
model: str = "hunyuan-pro"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will quietly change the model type for anyone relying on the default model, here. Suggest keeping unchanged.

Copy link
Contributor Author

@xander-art xander-art Aug 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK — restored the default model.
image

"""What Model to use.
Optional model:
- hunyuan-lite
- hunyuan-standard
- hunyuan-standard-256K
- hunyuan-pro
Expand Down Expand Up @@ -186,8 +191,6 @@
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Hunyuan API."""
normal_params = {
"Temperature": self.temperature,
"TopP": self.top_p,
xander-art marked this conversation as resolved.
Show resolved Hide resolved
"Model": self.model,
"Stream": self.streaming,
"StreamModeration": self.stream_moderation,
Expand Down Expand Up @@ -233,6 +236,7 @@
chunk = _convert_delta_to_message_chunk(
choice["Delta"], default_chunk_class
)
chunk.id = response.get('Id', '')
default_chunk_class = chunk.__class__
cg_chunk = ChatGenerationChunk(message=chunk)
if run_manager:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest
import uuid
from langchain_core.messages import AIMessage, HumanMessage

from langchain.prompts.chat import ChatPromptTemplate,SystemMessagePromptTemplate,HumanMessagePromptTemplate
from operator import itemgetter
from langchain_community.chat_models.hunyuan import ChatHunyuan


Expand All @@ -11,6 +13,8 @@ def test_chat_hunyuan() -> None:
response = chat.invoke([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert len(response.id) > 0, "request_id 不为空"
assert uuid.UUID(response.id), "无效的UUID"


@pytest.mark.requires("tencentcloud-sdk-python")
Expand All @@ -20,6 +24,8 @@ def test_chat_hunyuan_with_temperature() -> None:
response = chat.invoke([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert len(response.id) > 0, "request_id 不为空"
assert uuid.UUID(response.id), "无效的UUID"


@pytest.mark.requires("tencentcloud-sdk-python")
Expand All @@ -29,6 +35,8 @@ def test_chat_hunyuan_with_model_name() -> None:
response = chat.invoke([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert len(response.id) > 0, "request_id 不为空"
assert uuid.UUID(response.id), "无效的UUID"


@pytest.mark.requires("tencentcloud-sdk-python")
Expand All @@ -38,7 +46,20 @@ def test_chat_hunyuan_with_stream() -> None:
response = chat.invoke([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
assert len(response.id) > 0, "request_id 不为空"
assert uuid.UUID(response.id), "无效的UUID"

@pytest.mark.requires("tencentcloud-sdk-python")
def test_chat_hunyuan_with_prompt_template() -> None:
    """ChatHunyuan works as the model stage of an LCEL prompt pipeline.

    Builds a system+human chat prompt, pipes it into ``ChatHunyuan`` via the
    runnable ``|`` operator, and checks that the response carries a non-empty,
    UUID-shaped request id (the ``Id`` field propagated from the Hunyuan API).
    """
    system_prompt = SystemMessagePromptTemplate.from_template(
        "You are a helpful assistant! Your name is {name}."
    )
    user_prompt = HumanMessagePromptTemplate.from_template("Question: {query}")
    chat_prompt = ChatPromptTemplate.from_messages([system_prompt, user_prompt])

    # Route the two input keys into the prompt, then into the model.
    chain = (
        {"query": itemgetter("query"), "name": itemgetter("name")}
        | chat_prompt
        | ChatHunyuan()
    )
    response = chain.invoke({"query": "Hello", "name": "Tom"})

    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)
    assert len(response.id) > 0, "request_id 不为空"
    assert uuid.UUID(response.id), "无效的UUID"

def test_extra_kwargs() -> None:
chat = ChatHunyuan(temperature=0.88, top_p=0.7)
Expand Down
Loading