diff --git a/chatsky/llm/conditions.py b/chatsky/llm/conditions.py
index 1d7e7d66c..5f66fc96b 100644
--- a/chatsky/llm/conditions.py
+++ b/chatsky/llm/conditions.py
@@ -1,5 +1,6 @@
 from chatsky.llm.methods import BaseMethod
 
+
 def llm_condition(model_name: str, prompt: str, method: BaseMethod):
     """
     Basic function for using LLM in condition cases.
diff --git a/chatsky/llm/filters.py b/chatsky/llm/filters.py
index 25f5810c2..8808565ba 100644
--- a/chatsky/llm/filters.py
+++ b/chatsky/llm/filters.py
@@ -14,6 +14,7 @@ class BaseFilter(BaseModel, abc.ABC):
     """
     Base class for all message history filters.
     """
+
     @abc.abstractmethod
     def __call__(self, ctx: Context, request: Message, response: Message, model_name: str) -> bool:
         """
@@ -29,6 +30,7 @@ class IsImportant(BaseFilter):
     """
     Filter that checks if the "important" field in a Message.misc is True.
    """
+
     def __call__(
         self, ctx: Context = None, request: Message = None, response: Message = None, model_name: str = None
     ) -> bool:
@@ -43,6 +45,7 @@ class FromTheModel(BaseFilter):
     """
     Filter that checks if the message was sent by the model.
     """
+
     def __call__(
         self, ctx: Context = None, request: Message = None, response: Message = None, model_name: str = None
     ) -> bool:
diff --git a/chatsky/llm/llm_api.py b/chatsky/llm/llm_api.py
index 95e8b413d..854e4e16f 100644
--- a/chatsky/llm/llm_api.py
+++ b/chatsky/llm/llm_api.py
@@ -12,6 +12,7 @@
     from langchain_mistralai import ChatMistralAI
     from langchain_core.output_parsers import StrOutputParser
     from langchain_core.language_models.chat_models import BaseChatModel
+
     langchain_available = True
 except ImportError:
     langchain_available = False
@@ -78,7 +79,10 @@ async def respond(
 
     async def condition(self, prompt: str, method: BaseMethod, return_schema=None):
         async def process_input(ctx: Context, _: Pipeline) -> bool:
-            condition_history = [await message_to_langchain(Message(prompt), pipeline=_, source="system"), await message_to_langchain(ctx.last_request, pipeline=_, source="human")]
+            condition_history = [
+                await message_to_langchain(Message(prompt), pipeline=_, source="system"),
+                await message_to_langchain(ctx.last_request, pipeline=_, source="human"),
+            ]
             result = method(ctx, await self.model.agenerate([condition_history], logprobs=True, top_logprobs=10))
             return result
 
diff --git a/chatsky/llm/methods.py b/chatsky/llm/methods.py
index f157ab2dd..762c9eb54 100644
--- a/chatsky/llm/methods.py
+++ b/chatsky/llm/methods.py
@@ -15,6 +15,7 @@ class BaseMethod(BaseModel, abc.ABC):
     """
     Base class to evaluate models response as condition.
""" + @abc.abstractmethod async def __call__(self, ctx: Context, model_result: LLMResult) -> bool: raise NotImplementedError @@ -35,6 +36,7 @@ class Contains(BaseMethod): :return: True if pattern is contained in model result :rtype: bool """ + pattern: str async def __call__(self, ctx: Context, model_result: LLMResult) -> bool: @@ -52,11 +54,13 @@ class LogProb(BaseMethod): :return: True if logprob is higher then threshold :rtype: bool """ + target_token: str threshold: float = -0.5 + async def __call__(self, ctx: Context, model_result: LLMResult) -> bool: try: - result = model_result.generations[0][0].generation_info['logprobs']['content'][0]['top_logprobs'] + result = model_result.generations[0][0].generation_info["logprobs"]["content"][0]["top_logprobs"] except ValueError: raise ValueError("LogProb method can only be applied to OpenAI models.") for tok in result: diff --git a/chatsky/llm/responses.py b/chatsky/llm/responses.py index 61bb0bc57..0968e3f05 100644 --- a/chatsky/llm/responses.py +++ b/chatsky/llm/responses.py @@ -7,7 +7,14 @@ from pydantic import BaseModel -def llm_response(model_name: str, prompt: str = "", history: int = 5, filter_func: Callable = lambda *args: True, message_schema: Union[None, Type[Message], Type[BaseModel]] = None, max_size: int=1000): +def llm_response( + model_name: str, + prompt: str = "", + history: int = 5, + filter_func: Callable = lambda *args: True, + message_schema: Union[None, Type[Message], Type[BaseModel]] = None, + max_size: int = 1000, +): """ Basic function for receiving LLM responses. :param ctx: Context object. (Assigned automatically) @@ -33,19 +40,25 @@ async def wrapped(ctx: Context, pipeline: Pipeline) -> Message: # populate history with global and local prompts if "global_prompt" in current_misc: global_prompt = current_misc["global_prompt"] - history_messages.append(await message_to_langchain(Message(global_prompt), pipeline=pipeline, source="system")) + history_messages.append( + await message_to_langchain(Message(global_prompt), pipeline=pipeline, source="system") + ) if "local_prompt" in current_misc: local_prompt = current_misc["local_prompt"] - history_messages.append(await message_to_langchain(Message(local_prompt), pipeline=pipeline, source="system")) + history_messages.append( + await message_to_langchain(Message(local_prompt), pipeline=pipeline, source="system") + ) if "prompt" in current_misc: node_prompt = current_misc["prompt"] - history_messages.append(await message_to_langchain(Message(node_prompt), pipeline=pipeline, source="system")) + history_messages.append( + await message_to_langchain(Message(node_prompt), pipeline=pipeline, source="system") + ) # iterate over context to retrieve history messages if not (history == 0 or len(ctx.responses) == 0 or len(ctx.requests) == 0): pairs = zip( - [ctx.requests[x] for x in range(1, len(ctx.requests)+1)], - [ctx.responses[x] for x in range(1, len(ctx.responses)+1)], + [ctx.requests[x] for x in range(1, len(ctx.requests) + 1)], + [ctx.responses[x] for x in range(1, len(ctx.responses) + 1)], ) # print(f"Pairs: {[p for p in pairs]}") if history != -1: @@ -64,4 +77,3 @@ async def wrapped(ctx: Context, pipeline: Pipeline) -> Message: return await model.respond(history_messages, message_schema=message_schema) return wrapped - diff --git a/chatsky/llm/utils.py b/chatsky/llm/utils.py index 55064271c..af101e9e9 100644 --- a/chatsky/llm/utils.py +++ b/chatsky/llm/utils.py @@ -4,7 +4,7 @@ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage -async def 
message_to_langchain(message: Message, pipeline: Pipeline, source: str = "human", max_size: int=1000): +async def message_to_langchain(message: Message, pipeline: Pipeline, source: str = "human", max_size: int = 1000): """ Creates a langchain message from a ~chatsky.script.core.message.Message object. @@ -19,13 +19,19 @@ async def message_to_langchain(message: Message, pipeline: Pipeline, source: str if len(message.text) > max_size: raise ValueError("Message is too long.") - if message.text is None: message.text = "" + if message.text is None: + message.text = "" content = [{"type": "text", "text": message.text}] if message.attachments: for image in message.attachments: if isinstance(image, Image): - content.append({"type": "image_url", "image_url": {"url": await __attachment_to_content(image, pipeline.messenger_interface)}}) + content.append( + { + "type": "image_url", + "image_url": {"url": await __attachment_to_content(image, pipeline.messenger_interface)}, + } + ) if source == "human": return HumanMessage(content=content) diff --git a/tests/llm/test_llm.py b/tests/llm/test_llm.py index 7df4094e5..496c6e640 100644 --- a/tests/llm/test_llm.py +++ b/tests/llm/test_llm.py @@ -1,4 +1,4 @@ -from chatsky.llm.llm_api import LLM_API +from chatsky.llm.llm_api import LLM_API from chatsky.llm.responses import llm_response from chatsky.llm.conditions import llm_condition from chatsky.llm.utils import message_to_langchain, __attachment_to_content @@ -39,36 +39,40 @@ def with_structured_output(self, message_schema): def respond(self, history: list = [""]): return self.ainvoke(history) + class MockedStructuredModel: def __init__(self, root_model): self.root = root_model async def ainvoke(self, history): - inst = self.root(history = history) + inst = self.root(history=history) return inst() + class MessageSchema(BaseModel): history: list[str] def __call__(self): return {"history": self.history} + @pytest.fixture def mock_structured_model(): return MockedStructuredModel + async def test_structured_output(monkeypatch, mock_structured_model): # Create a mock LLM_API instance llm_api = LLM_API(MockChatOpenAI()) - + # Test data history = ["message1", "message2"] - + # Call the respond method result = await llm_api.respond(message_schema=MessageSchema, history=history) - + # Assert the result - expected_result = Message(text=str({"history": history}), annotations={'__generated_by_model__': ''}) + expected_result = Message(text=str({"history": history}), annotations={"__generated_by_model__": ""}) assert result == expected_result @@ -76,6 +80,7 @@ async def test_structured_output(monkeypatch, mock_structured_model): def mock_model(): return MockChatOpenAI() + class MockPipeline: def __init__(self, mock_model): self.models = {"test_model": LLM_API(mock_model), "struct_model": LLM_API(mock_structured_model)} @@ -86,6 +91,7 @@ def __init__(self, mock_model): def pipeline(mock_model): return MockPipeline(mock_model) + @pytest.fixture def filter_context(): ctx = Context.init(AbsoluteNodeLabel(flow_name="flow", node_name="node")) @@ -121,6 +127,7 @@ def context(): ctx.add_request("Last request") return ctx + async def test_message_to_langchain(pipeline): assert await message_to_langchain(Message(text="hello"), pipeline, source="human") == HumanMessage( content=[{"type": "text", "text": "hello"}] @@ -133,11 +140,22 @@ async def test_message_to_langchain(pipeline): class MockMessengerInterface(MessengerInterfaceWithAttachments): async def connect(self): pass - + async def get_attachment_bytes(self, attachment): 
return b"mock_bytes" -@pytest.mark.parametrize("img,expected", [(Image(source="https://raw.githubusercontent.com/deeppavlov/chatsky/master/docs/source/_static/images/Chatsky-full-dark.svg"), 'data:image/svg;base64,PHN2ZyB3aWR0aD0iMTM3OCIgaGVpZ2h0PSIyNzYiIHZpZXdCb3g9IjAgMCAxMzc4IDI3NiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTE0NCAxMDguMDE5SDIyOC4wNDhDMjY3LjQ3NCAxMDguMDE5IDMwNi4zNTkgOTguODg4IDM0MS42MjMgODEuMzQ5N0MzNTAuMDY2IDc3LjE1MDcgMzYwIDgzLjI1NzUgMzYwIDkyLjY0NjlWMTA4LjAxOUMzNjAgMTg3LjEyNSAyOTUuNTI5IDI1MS4yNTMgMjE2IDI1MS4yNTNIMTE1LjM1NEMxMDYuMDQ5IDI1MS4yNTMgMTAxLjM5NiAyNTEuMjUzIDk2LjgwNzkgMjUxLjU0NUM4MC42NzU4IDI1Mi41NzEgNjQuODMyNSAyNTYuMjkxIDQ5Ljk0MjMgMjYyLjU1QzQ1LjcwNzMgMjY0LjMzIDQxLjU0NTggMjY2LjM5OSAzMy4yMjI4IDI3MC41MzlMMzIuMTk5NCAyNzEuMDQ4QzI2LjgwNDggMjczLjczMSAyNC4xMDc1IDI3NS4wNzIgMjEuOTg5OCAyNzUuNTUxQzEyLjQwOTcgMjc3LjcxNyAyLjg1MzA3IDI3MS44NDIgMC41MTY0ODggMjYyLjM1QzAgMjYwLjI1MiAwIDI1Ny4yNTIgMCAyNTEuMjUzQzAgMTcyLjE0NyA2NC40NzEgMTA4LjAxOSAxNDQgMTA4LjAxOVoiIGZpbGw9IiMwMEEzRkYiLz4KPHBhdGggZD0iTTI1MC4zNjQgMEMyODkuMjkxIDAgMzIzLjA0MSAyMi41MDUzIDMzOS41NjkgNTUuMzk4N0MzNDAuMDk1IDU2Ljc4ODQgMzQwLjI0MyA1OS4wNjQ3IDMzOS40MTUgNjEuNjg1NEMzMzguNjIxIDY0LjE5NjUgMzM3LjI3OCA2NS45MjM0IDMzNi4wODkgNjYuNzg3NEMzMDEuOTM0IDgzLjM2NTYgMjY0LjMxNiA5MiAyMjYuMTY3IDkySDE3MC4zNTdDMTU5Ljc5MyA5MiAxNTguMDExIDkxLjcwMzYgMTU2LjI5OCA5MC42MzgxQzE1NS45MDIgOTAuMzkyMyAxNTQuOTc4IDg5LjYwMzUgMTUzLjk3NSA4OC4yNzVDMTUyLjk3MyA4Ni45NDY1IDE1Mi40NjQgODUuODM2OSAxNTIuMzMzIDg1LjM4NjRDMTUxLjcyNiA4My4yODg3IDE1MS43OTQgODIuMjEwNCAxNTMuODY2IDc0LjUxNjhDMTY1LjQzNSAzMS41NjU0IDIwNC4yNjggMCAyNTAuMzY0IDBaIiBmaWxsPSIjRkZBRDBEIi8+CjxwYXRoIGQ9Ik01MDAuMDAxIDIxOEM0OTAuOCAyMTggNDgyLjMzNCAyMTYuNjc0IDQ3NC42IDIxNC4wMjJDNDY3IDIxMS4zNyA0NjAuNCAyMDcuMjYgNDU0LjggMjAxLjY5MUM0NDkuMiAxOTUuOTg5IDQ0NC44IDE4OC42MyA0NDEuNiAxNzkuNjEzQzQzOC41MzMgMTcwLjU5NyA0MzcgMTU5LjcyNCA0MzcgMTQ2Ljk5NFYxNDMuMDE3QzQzNyAxMzAuODE4IDQzOC42IDEyMC40MDkgNDQxLjggMTExLjc5QzQ0NS4xMzMgMTAzLjAzOSA0NDkuNiA5NS44Nzg1IDQ1NS4yIDkwLjMwOTRDNDYwLjkzNCA4NC43NDAzIDQ2Ny42IDgwLjYyOTggNDc1LjIgNzcuOTc3OUM0ODIuOTM0IDc1LjMyNiA0OTEuMiA3NCA1MDAuMDAxIDc0QzUwNy44NjcgNzQgNTE1LjI2NyA3NC45MjgyIDUyMi4yMDEgNzYuNzg0NUM1MjkuMTM0IDc4LjY0MDkgNTM1LjIwMSA4MS41NTggNTQwLjQwMSA4NS41MzU5QzU0NS43MzQgODkuMzgxMiA1NTAuMDAxIDk0LjM1MzYgNTUzLjIwMSAxMDAuNDUzQzU1Ni41MzQgMTA2LjU1MiA1NTguNDY4IDExMy43NzkgNTU5LjAwMSAxMjIuMTMzSDUyNy44MDFDNTI2LjMzNCAxMTQuMTc3IDUyMy4xMzQgMTA4LjM0MyA1MTguMjAxIDEwNC42M0M1MTMuMjY3IDEwMC45MTcgNTA3LjIwMSA5OS4wNjA4IDUwMC4wMDEgOTkuMDYwOEM0OTUuODY3IDk5LjA2MDggNDkxLjg2NyA5OS43OTAxIDQ4OCAxMDEuMjQ5QzQ4NC4yNjcgMTAyLjcwNyA0ODAuOTM0IDEwNS4xNiA0NzggMTA4LjYwOEM0NzUuMDY3IDExMi4wNTUgNDcyLjY2NyAxMTYuNTY0IDQ3MC44IDEyMi4xMzNDNDY5LjA2NyAxMjcuNzAyIDQ2OC4yIDEzNC42NjMgNDY4LjIgMTQzLjAxN1YxNDYuOTk0QzQ2OC4yIDE1NS43NDYgNDY5LjEzNCAxNjMuMTA1IDQ3MSAxNjkuMDcyQzQ3Mi44NjcgMTc0LjkwNiA0NzUuMjY3IDE3OS42MTMgNDc4LjIgMTgzLjE5M0M0ODEuMjY3IDE4Ni42NDEgNDg0LjY2NyAxODkuMTYgNDg4LjQgMTkwLjc1MUM0OTIuMjY3IDE5Mi4yMSA0OTYuMTM0IDE5Mi45MzkgNTAwLjAwMSAxOTIuOTM5QzUwNy42MDEgMTkyLjkzOSA1MTMuODY3IDE5MS4wMTcgNTE4LjgwMSAxODcuMTcxQzUyMy43MzQgMTgzLjE5MyA1MjYuNzM0IDE3Ny40MjUgNTI3LjgwMSAxNjkuODY3SDU1OS4wMDFDNTU4LjMzNCAxNzguNjE5IDU1Ni40MDEgMTg2LjA0NCA1NTMuMjAxIDE5Mi4xNDRDNTUwLjAwMSAxOTguMjQzIDU0NS44MDEgMjAzLjIxNSA1NDAuNjAxIDIwNy4wNjFDNTM1LjQwMSAyMTAuOTA2IDUyOS4zMzQgMjEzLjY5MSA1MjIuNDAxIDIxNS40MTRDNTE1LjQ2NyAyMTcuMTM4IDUwOC4wMDEgMjE4IDUwMC4wMDEgMjE4WiIgZmlsbD0iIzAwQTNGRiIvPgo8cGF0aCBkPSJNNTgzLjA0IDc2LjM4NjdINjEzLjA0MVYxMzEuNjhINjczLjA0MVY3Ni4zODY3SDcwMy4wNDFWMjE1LjYxM0g2NzMuMDQxVjE1Ni4zNDNINjEzLjA0MVYyMTUuNjEzSDU4My4wNFY3Ni4zODY3WiIgZmlsbD0iIzAwQTNGR
iIvPgo8cGF0aCBkPSJNODEyLjU5NSAxODcuMTcxSDc2MC43OTVMNzUwLjE5NSAyMTUuNjEzSDcyMC45OTVMNzcyLjk5NSA3Ni4zODY3SDgwMi45OTVMODU0Ljk5NiAyMTUuNjEzSDgyMy4zOTVMODEyLjU5NSAxODcuMTcxWk03NjkuOTk1IDE2Mi41MDhIODAzLjU5NUw3ODYuNzk1IDExNC4xNzdMNzY5Ljk5NSAxNjIuNTA4WiIgZmlsbD0iIzAwQTNGRiIvPgo8cGF0aCBkPSJNODkyLjAxOSAxMDEuMDVIODQ3LjAxOVY3Ni4zODY3SDk2Ny4wMlYxMDEuMDVIOTIyLjAyVjIxNS42MTNIODkyLjAxOVYxMDEuMDVaIiBmaWxsPSIjMDBBM0ZGIi8+CjxwYXRoIGQ9Ik0xMDM1Ljk3IDIxOEMxMDE3LjQ0IDIxOCAxMDAzLjQ0IDIxNC4yMjEgOTkzLjk3MyAyMDYuNjYzQzk4NC41MDcgMTk5LjEwNSA5NzkuNTA3IDE4OS4xNiA5NzguOTczIDE3Ni44MjlIMTAwOC45N0MxMDA5LjUxIDE3OS40ODEgMTAxMC4zMSAxODEuODAxIDEwMTEuMzcgMTgzLjc5QzEwMTIuNDQgMTg1Ljc3OSAxMDE0LjA0IDE4Ny40MzYgMTAxNi4xNyAxODguNzYyQzEwMTguMzEgMTkwLjA4OCAxMDIwLjk3IDE5MS4xNDkgMTAyNC4xNyAxOTEuOTQ1QzEwMjcuMzcgMTkyLjYwOCAxMDMxLjMxIDE5Mi45MzkgMTAzNS45NyAxOTIuOTM5QzEwNDUuNTcgMTkyLjkzOSAxMDUyLjQ0IDE5MS42OCAxMDU2LjU3IDE4OS4xNkMxMDYwLjg0IDE4Ni41MDggMTA2Mi45NyAxODIuNzk2IDEwNjIuOTcgMTc4LjAyMkMxMDYyLjk3IDE3Mi4xODggMTA2MC4xNyAxNjcuODEyIDEwNTQuNTcgMTY0Ljg5NUMxMDQ5LjExIDE2MS44NDUgMTA0MC4zMSAxNTkuMTI3IDEwMjguMTcgMTU2Ljc0QzEwMjAuOTcgMTU1LjQxNCAxMDE0LjU3IDE1My42OTEgMTAwOC45NyAxNTEuNTY5QzEwMDMuMzcgMTQ5LjMxNSA5OTguNjQgMTQ2LjUzIDk5NC43NzMgMTQzLjIxNUM5OTAuOTA3IDEzOS43NjggOTg3Ljk3MyAxMzUuNjU3IDk4NS45NzMgMTMwLjg4NEM5ODMuOTczIDEyNi4xMSA5ODIuOTczIDEyMC40MDkgOTgyLjk3MyAxMTMuNzc5Qzk4Mi45NzMgMTA3LjgxMiA5ODQuMTczIDEwMi4zNzYgOTg2LjU3MyA5Ny40Njk2Qzk4OS4xMDcgOTIuNTYzNSA5OTIuNjQgODguMzg2NyA5OTcuMTczIDg0LjkzOTJDMTAwMS44NCA4MS4zNTkxIDEwMDcuNDQgNzguNjQwOSAxMDEzLjk3IDc2Ljc4NDVDMTAyMC41MSA3NC45MjgyIDEwMjcuODQgNzQgMTAzNS45NyA3NEMxMDQ0Ljc3IDc0IDEwNTIuMzcgNzQuOTI4MiAxMDU4Ljc3IDc2Ljc4NDVDMTA2NS4xNyA3OC42NDA5IDEwNzAuNTEgODEuMjkyOCAxMDc0Ljc3IDg0Ljc0MDNDMTA3OS4xNyA4OC4xODc4IDEwODIuNTEgOTIuMjk4MyAxMDg0Ljc3IDk3LjA3MThDMTA4Ny4wNCAxMDEuODQ1IDEwODguNDQgMTA3LjIxNSAxMDg4Ljk3IDExMy4xODJIMTA1OC45N0MxMDU4LjA0IDEwOC40MDkgMTA1NS45MSAxMDQuODk1IDEwNTIuNTcgMTAyLjY0MUMxMDQ5LjI0IDEwMC4yNTQgMTA0My43MSA5OS4wNjA4IDEwMzUuOTcgOTkuMDYwOEMxMDI4LjExIDk5LjA2MDggMTAyMi4zMSAxMDAuMzIgMTAxOC41NyAxMDIuODRDMTAxNC44NCAxMDUuMzU5IDEwMTIuOTcgMTA4Ljc0IDEwMTIuOTcgMTEyLjk4M0MxMDEyLjk3IDExOC4wMjIgMTAxNS41MSAxMjIuMDY2IDEwMjAuNTcgMTI1LjExNkMxMDI1Ljc3IDEyOC4wMzMgMTAzMy45MSAxMzAuNTUyIDEwNDQuOTcgMTMyLjY3NEMxMDUyLjU3IDEzNC4xMzMgMTA1OS4zMSAxMzUuOTIzIDEwNjUuMTcgMTM4LjA0NEMxMDcxLjE3IDE0MC4xNjYgMTA3Ni4xNyAxNDIuODg0IDEwODAuMTcgMTQ2LjE5OUMxMDg0LjMxIDE0OS41MTQgMTA4Ny40NCAxNTMuNjI0IDEwODkuNTcgMTU4LjUzQzEwOTEuODQgMTYzLjMwNCAxMDkyLjk3IDE2OS4wNzIgMTA5Mi45NyAxNzUuODM0QzEwOTIuOTcgMTg4LjY5NiAxMDg3Ljk3IDE5OC45NzIgMTA3Ny45NyAyMDYuNjYzQzEwNjguMTEgMjE0LjIyMSAxMDU0LjExIDIxOCAxMDM1Ljk3IDIxOFoiIGZpbGw9IiMwMEEzRkYiLz4KPHBhdGggZD0iTTExMTguMDEgNzYuMzg2N0gxMTQ4LjAxVjEzNS4wNjFMMTE5OS4yMSA3Ni4zODY3SDEyMzQuMDFMMTE4MS44MSAxMzMuNDdMMTIzOC4wMSAyMTUuNjEzSDEyMDIuODFMMTE2MS42MSAxNTMuMTZMMTE0OC4wMSAxNjcuODc4VjIxNS42MTNIMTExOC4wMVY3Ni4zODY3WiIgZmlsbD0iIzAwQTNGRiIvPgo8cGF0aCBkPSJNMTI5OSAxNjIuNTA4TDEyNDYgNzYuMzg2N0gxMjgxLjJMMTMxNC42IDEzMi4yNzZMMTM0NiA3Ni4zODY3SDEzNzhMMTMyOSAxNjIuNTA4VjIxNS42MTNIMTI5OVYxNjIuNTA4WiIgZmlsbD0iIzAwQTNGRiIvPgo8L3N2Zz4K')]) + +@pytest.mark.parametrize( + "img,expected", + [ + ( + Image( + source="https://raw.githubusercontent.com/deeppavlov/chatsky/master/docs/source/_static/images/Chatsky-full-dark.svg" + ), + 
"data:image/svg;base64,PHN2ZyB3aWR0aD0iMTM3OCIgaGVpZ2h0PSIyNzYiIHZpZXdCb3g9IjAgMCAxMzc4IDI3NiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTE0NCAxMDguMDE5SDIyOC4wNDhDMjY3LjQ3NCAxMDguMDE5IDMwNi4zNTkgOTguODg4IDM0MS42MjMgODEuMzQ5N0MzNTAuMDY2IDc3LjE1MDcgMzYwIDgzLjI1NzUgMzYwIDkyLjY0NjlWMTA4LjAxOUMzNjAgMTg3LjEyNSAyOTUuNTI5IDI1MS4yNTMgMjE2IDI1MS4yNTNIMTE1LjM1NEMxMDYuMDQ5IDI1MS4yNTMgMTAxLjM5NiAyNTEuMjUzIDk2LjgwNzkgMjUxLjU0NUM4MC42NzU4IDI1Mi41NzEgNjQuODMyNSAyNTYuMjkxIDQ5Ljk0MjMgMjYyLjU1QzQ1LjcwNzMgMjY0LjMzIDQxLjU0NTggMjY2LjM5OSAzMy4yMjI4IDI3MC41MzlMMzIuMTk5NCAyNzEuMDQ4QzI2LjgwNDggMjczLjczMSAyNC4xMDc1IDI3NS4wNzIgMjEuOTg5OCAyNzUuNTUxQzEyLjQwOTcgMjc3LjcxNyAyLjg1MzA3IDI3MS44NDIgMC41MTY0ODggMjYyLjM1QzAgMjYwLjI1MiAwIDI1Ny4yNTIgMCAyNTEuMjUzQzAgMTcyLjE0NyA2NC40NzEgMTA4LjAxOSAxNDQgMTA4LjAxOVoiIGZpbGw9IiMwMEEzRkYiLz4KPHBhdGggZD0iTTI1MC4zNjQgMEMyODkuMjkxIDAgMzIzLjA0MSAyMi41MDUzIDMzOS41NjkgNTUuMzk4N0MzNDAuMDk1IDU2Ljc4ODQgMzQwLjI0MyA1OS4wNjQ3IDMzOS40MTUgNjEuNjg1NEMzMzguNjIxIDY0LjE5NjUgMzM3LjI3OCA2NS45MjM0IDMzNi4wODkgNjYuNzg3NEMzMDEuOTM0IDgzLjM2NTYgMjY0LjMxNiA5MiAyMjYuMTY3IDkySDE3MC4zNTdDMTU5Ljc5MyA5MiAxNTguMDExIDkxLjcwMzYgMTU2LjI5OCA5MC42MzgxQzE1NS45MDIgOTAuMzkyMyAxNTQuOTc4IDg5LjYwMzUgMTUzLjk3NSA4OC4yNzVDMTUyLjk3MyA4Ni45NDY1IDE1Mi40NjQgODUuODM2OSAxNTIuMzMzIDg1LjM4NjRDMTUxLjcyNiA4My4yODg3IDE1MS43OTQgODIuMjEwNCAxNTMuODY2IDc0LjUxNjhDMTY1LjQzNSAzMS41NjU0IDIwNC4yNjggMCAyNTAuMzY0IDBaIiBmaWxsPSIjRkZBRDBEIi8+CjxwYXRoIGQ9Ik01MDAuMDAxIDIxOEM0OTAuOCAyMTggNDgyLjMzNCAyMTYuNjc0IDQ3NC42IDIxNC4wMjJDNDY3IDIxMS4zNyA0NjAuNCAyMDcuMjYgNDU0LjggMjAxLjY5MUM0NDkuMiAxOTUuOTg5IDQ0NC44IDE4OC42MyA0NDEuNiAxNzkuNjEzQzQzOC41MzMgMTcwLjU5NyA0MzcgMTU5LjcyNCA0MzcgMTQ2Ljk5NFYxNDMuMDE3QzQzNyAxMzAuODE4IDQzOC42IDEyMC40MDkgNDQxLjggMTExLjc5QzQ0NS4xMzMgMTAzLjAzOSA0NDkuNiA5NS44Nzg1IDQ1NS4yIDkwLjMwOTRDNDYwLjkzNCA4NC43NDAzIDQ2Ny42IDgwLjYyOTggNDc1LjIgNzcuOTc3OUM0ODIuOTM0IDc1LjMyNiA0OTEuMiA3NCA1MDAuMDAxIDc0QzUwNy44NjcgNzQgNTE1LjI2NyA3NC45MjgyIDUyMi4yMDEgNzYuNzg0NUM1MjkuMTM0IDc4LjY0MDkgNTM1LjIwMSA4MS41NTggNTQwLjQwMSA4NS41MzU5QzU0NS43MzQgODkuMzgxMiA1NTAuMDAxIDk0LjM1MzYgNTUzLjIwMSAxMDAuNDUzQzU1Ni41MzQgMTA2LjU1MiA1NTguNDY4IDExMy43NzkgNTU5LjAwMSAxMjIuMTMzSDUyNy44MDFDNTI2LjMzNCAxMTQuMTc3IDUyMy4xMzQgMTA4LjM0MyA1MTguMjAxIDEwNC42M0M1MTMuMjY3IDEwMC45MTcgNTA3LjIwMSA5OS4wNjA4IDUwMC4wMDEgOTkuMDYwOEM0OTUuODY3IDk5LjA2MDggNDkxLjg2NyA5OS43OTAxIDQ4OCAxMDEuMjQ5QzQ4NC4yNjcgMTAyLjcwNyA0ODAuOTM0IDEwNS4xNiA0NzggMTA4LjYwOEM0NzUuMDY3IDExMi4wNTUgNDcyLjY2NyAxMTYuNTY0IDQ3MC44IDEyMi4xMzNDNDY5LjA2NyAxMjcuNzAyIDQ2OC4yIDEzNC42NjMgNDY4LjIgMTQzLjAxN1YxNDYuOTk0QzQ2OC4yIDE1NS43NDYgNDY5LjEzNCAxNjMuMTA1IDQ3MSAxNjkuMDcyQzQ3Mi44NjcgMTc0LjkwNiA0NzUuMjY3IDE3OS42MTMgNDc4LjIgMTgzLjE5M0M0ODEuMjY3IDE4Ni42NDEgNDg0LjY2NyAxODkuMTYgNDg4LjQgMTkwLjc1MUM0OTIuMjY3IDE5Mi4yMSA0OTYuMTM0IDE5Mi45MzkgNTAwLjAwMSAxOTIuOTM5QzUwNy42MDEgMTkyLjkzOSA1MTMuODY3IDE5MS4wMTcgNTE4LjgwMSAxODcuMTcxQzUyMy43MzQgMTgzLjE5MyA1MjYuNzM0IDE3Ny40MjUgNTI3LjgwMSAxNjkuODY3SDU1OS4wMDFDNTU4LjMzNCAxNzguNjE5IDU1Ni40MDEgMTg2LjA0NCA1NTMuMjAxIDE5Mi4xNDRDNTUwLjAwMSAxOTguMjQzIDU0NS44MDEgMjAzLjIxNSA1NDAuNjAxIDIwNy4wNjFDNTM1LjQwMSAyMTAuOTA2IDUyOS4zMzQgMjEzLjY5MSA1MjIuNDAxIDIxNS40MTRDNTE1LjQ2NyAyMTcuMTM4IDUwOC4wMDEgMjE4IDUwMC4wMDEgMjE4WiIgZmlsbD0iIzAwQTNGRiIvPgo8cGF0aCBkPSJNNTgzLjA0IDc2LjM4NjdINjEzLjA0MVYxMzEuNjhINjczLjA0MVY3Ni4zODY3SDcwMy4wNDFWMjE1LjYxM0g2NzMuMDQxVjE1Ni4zNDNINjEzLjA0MVYyMTUuNjEzSDU4My4wNFY3Ni4zODY3WiIgZmlsbD0iIzAwQTNGRiIvPgo8cGF0aCBkPSJNODEyLjU5NSAxODcuMTcxSDc2MC43OTVMNzUwLjE5NSAyMTUuNjEzSDcyMC45OTVMNzcyLjk5NSA3Ni4zODY3SDgwMi45OTVMODU0Ljk5NiAyMTUuNjEzSDgyMy4zOTVMODEyLjU5NSAxODcuMTcxWk03NjkuOTk1IDE2Mi41MDhI
ODAzLjU5NUw3ODYuNzk1IDExNC4xNzdMNzY5Ljk5NSAxNjIuNTA4WiIgZmlsbD0iIzAwQTNGRiIvPgo8cGF0aCBkPSJNODkyLjAxOSAxMDEuMDVIODQ3LjAxOVY3Ni4zODY3SDk2Ny4wMlYxMDEuMDVIOTIyLjAyVjIxNS42MTNIODkyLjAxOVYxMDEuMDVaIiBmaWxsPSIjMDBBM0ZGIi8+CjxwYXRoIGQ9Ik0xMDM1Ljk3IDIxOEMxMDE3LjQ0IDIxOCAxMDAzLjQ0IDIxNC4yMjEgOTkzLjk3MyAyMDYuNjYzQzk4NC41MDcgMTk5LjEwNSA5NzkuNTA3IDE4OS4xNiA5NzguOTczIDE3Ni44MjlIMTAwOC45N0MxMDA5LjUxIDE3OS40ODEgMTAxMC4zMSAxODEuODAxIDEwMTEuMzcgMTgzLjc5QzEwMTIuNDQgMTg1Ljc3OSAxMDE0LjA0IDE4Ny40MzYgMTAxNi4xNyAxODguNzYyQzEwMTguMzEgMTkwLjA4OCAxMDIwLjk3IDE5MS4xNDkgMTAyNC4xNyAxOTEuOTQ1QzEwMjcuMzcgMTkyLjYwOCAxMDMxLjMxIDE5Mi45MzkgMTAzNS45NyAxOTIuOTM5QzEwNDUuNTcgMTkyLjkzOSAxMDUyLjQ0IDE5MS42OCAxMDU2LjU3IDE4OS4xNkMxMDYwLjg0IDE4Ni41MDggMTA2Mi45NyAxODIuNzk2IDEwNjIuOTcgMTc4LjAyMkMxMDYyLjk3IDE3Mi4xODggMTA2MC4xNyAxNjcuODEyIDEwNTQuNTcgMTY0Ljg5NUMxMDQ5LjExIDE2MS44NDUgMTA0MC4zMSAxNTkuMTI3IDEwMjguMTcgMTU2Ljc0QzEwMjAuOTcgMTU1LjQxNCAxMDE0LjU3IDE1My42OTEgMTAwOC45NyAxNTEuNTY5QzEwMDMuMzcgMTQ5LjMxNSA5OTguNjQgMTQ2LjUzIDk5NC43NzMgMTQzLjIxNUM5OTAuOTA3IDEzOS43NjggOTg3Ljk3MyAxMzUuNjU3IDk4NS45NzMgMTMwLjg4NEM5ODMuOTczIDEyNi4xMSA5ODIuOTczIDEyMC40MDkgOTgyLjk3MyAxMTMuNzc5Qzk4Mi45NzMgMTA3LjgxMiA5ODQuMTczIDEwMi4zNzYgOTg2LjU3MyA5Ny40Njk2Qzk4OS4xMDcgOTIuNTYzNSA5OTIuNjQgODguMzg2NyA5OTcuMTczIDg0LjkzOTJDMTAwMS44NCA4MS4zNTkxIDEwMDcuNDQgNzguNjQwOSAxMDEzLjk3IDc2Ljc4NDVDMTAyMC41MSA3NC45MjgyIDEwMjcuODQgNzQgMTAzNS45NyA3NEMxMDQ0Ljc3IDc0IDEwNTIuMzcgNzQuOTI4MiAxMDU4Ljc3IDc2Ljc4NDVDMTA2NS4xNyA3OC42NDA5IDEwNzAuNTEgODEuMjkyOCAxMDc0Ljc3IDg0Ljc0MDNDMTA3OS4xNyA4OC4xODc4IDEwODIuNTEgOTIuMjk4MyAxMDg0Ljc3IDk3LjA3MThDMTA4Ny4wNCAxMDEuODQ1IDEwODguNDQgMTA3LjIxNSAxMDg4Ljk3IDExMy4xODJIMTA1OC45N0MxMDU4LjA0IDEwOC40MDkgMTA1NS45MSAxMDQuODk1IDEwNTIuNTcgMTAyLjY0MUMxMDQ5LjI0IDEwMC4yNTQgMTA0My43MSA5OS4wNjA4IDEwMzUuOTcgOTkuMDYwOEMxMDI4LjExIDk5LjA2MDggMTAyMi4zMSAxMDAuMzIgMTAxOC41NyAxMDIuODRDMTAxNC44NCAxMDUuMzU5IDEwMTIuOTcgMTA4Ljc0IDEwMTIuOTcgMTEyLjk4M0MxMDEyLjk3IDExOC4wMjIgMTAxNS41MSAxMjIuMDY2IDEwMjAuNTcgMTI1LjExNkMxMDI1Ljc3IDEyOC4wMzMgMTAzMy45MSAxMzAuNTUyIDEwNDQuOTcgMTMyLjY3NEMxMDUyLjU3IDEzNC4xMzMgMTA1OS4zMSAxMzUuOTIzIDEwNjUuMTcgMTM4LjA0NEMxMDcxLjE3IDE0MC4xNjYgMTA3Ni4xNyAxNDIuODg0IDEwODAuMTcgMTQ2LjE5OUMxMDg0LjMxIDE0OS41MTQgMTA4Ny40NCAxNTMuNjI0IDEwODkuNTcgMTU4LjUzQzEwOTEuODQgMTYzLjMwNCAxMDkyLjk3IDE2OS4wNzIgMTA5Mi45NyAxNzUuODM0QzEwOTIuOTcgMTg4LjY5NiAxMDg3Ljk3IDE5OC45NzIgMTA3Ny45NyAyMDYuNjYzQzEwNjguMTEgMjE0LjIyMSAxMDU0LjExIDIxOCAxMDM1Ljk3IDIxOFoiIGZpbGw9IiMwMEEzRkYiLz4KPHBhdGggZD0iTTExMTguMDEgNzYuMzg2N0gxMTQ4LjAxVjEzNS4wNjFMMTE5OS4yMSA3Ni4zODY3SDEyMzQuMDFMMTE4MS44MSAxMzMuNDdMMTIzOC4wMSAyMTUuNjEzSDEyMDIuODFMMTE2MS42MSAxNTMuMTZMMTE0OC4wMSAxNjcuODc4VjIxNS42MTNIMTExOC4wMVY3Ni4zODY3WiIgZmlsbD0iIzAwQTNGRiIvPgo8cGF0aCBkPSJNMTI5OSAxNjIuNTA4TDEyNDYgNzYuMzg2N0gxMjgxLjJMMTMxNC42IDEzMi4yNzZMMTM0NiA3Ni4zODY3SDEzNzhMMTMyOSAxNjIuNTA4VjIxNS42MTNIMTI5OVYxNjIuNTA4WiIgZmlsbD0iIzAwQTNGRiIvPgo8L3N2Zz4K", + ) + ], +) async def test_attachments(img, expected): script = {"flow": {"node": {RESPONSE: Message(), TRANSITIONS: [Tr(dst="node", cnd=True)]}}} pipe = Pipeline(script=script, start_label=("flow", "node"), messenger_interface=MockMessengerInterface()) diff --git a/tutorials/llm/1_basics.py b/tutorials/llm/1_basics.py index 61327f4f0..dc9d87038 100644 --- a/tutorials/llm/1_basics.py +++ b/tutorials/llm/1_basics.py @@ -17,7 +17,7 @@ Pipeline, Transition as Tr, conditions as cnd, - destinations as dst + destinations as dst, # all the aliases used in tutorials are available for direct import # e.g. 
you can do `from chatsky import Tr` instead ) @@ -65,18 +65,29 @@ }, "greeting_node": { RESPONSE: llm_response(model_name="barista_model", history=0), - TRANSITIONS: [Tr(dst="main_node", cnd=cnd.ExactMatch("Who are you?"))], + TRANSITIONS: [ + Tr(dst="main_node", cnd=cnd.ExactMatch("Who are you?")) + ], }, "main_node": { RESPONSE: llm_response(model_name="barista_model"), TRANSITIONS: [ - Tr(dst="latte_art_node", cnd=cnd.ExactMatch("Tell me about latte art.")), - Tr(dst="image_desc_node", cnd=cnd.ExactMatch("Tell me what coffee is it?")), - Tr(dst="boss_node", cnd=llm_condition( - model_name="barista_model", - prompt="Return only TRUE if your customer says that he is your boss, or FALSE if he don't. Only ONE word must be in the output.", - method=Contains(pattern="TRUE"), - )), + Tr( + dst="latte_art_node", + cnd=cnd.ExactMatch("Tell me about latte art."), + ), + Tr( + dst="image_desc_node", + cnd=cnd.ExactMatch("Tell me what coffee is it?"), + ), + Tr( + dst="boss_node", + cnd=llm_condition( + model_name="barista_model", + prompt="Return only TRUE if your customer says that he is your boss, or FALSE if he don't. Only ONE word must be in the output.", + method=Contains(pattern="TRUE"), + ), + ), Tr(dst=dst.Current(), cnd=cnd.true()), ], }, @@ -92,7 +103,9 @@ model_name="barista_model", prompt="PROMPT: pretend that you have never heard about latte art before and DO NOT answer the following questions. Instead ask a person about it.", ), - TRANSITIONS: [Tr(dst="main_node", cnd=cnd.ExactMatch("Ok, goodbye."))], + TRANSITIONS: [ + Tr(dst="main_node", cnd=cnd.ExactMatch("Ok, goodbye.")) + ], }, "image_desc_node": { # we expect user to send some images of coffee. diff --git a/tutorials/llm/2_prompt_usage.py b/tutorials/llm/2_prompt_usage.py index de6550efc..f4b927dd7 100644 --- a/tutorials/llm/2_prompt_usage.py +++ b/tutorials/llm/2_prompt_usage.py @@ -35,7 +35,7 @@ Transition as Tr, conditions as cnd, destinations as dst, - labels as lbl + labels as lbl, ) from chatsky.utils.testing import ( is_interactive_mode, @@ -79,8 +79,13 @@ "greeting_node": { RESPONSE: llm_response(model_name="bank_model", history=0), TRANSITIONS: [ - Tr(dst=("loan_flow", "start_node"), cnd=cnd.ExactMatch("/loan")), - Tr(dst=("hr_flow", "start_node"), cnd=cnd.ExactMatch("/vacancies")), + Tr( + dst=("loan_flow", "start_node"), cnd=cnd.ExactMatch("/loan") + ), + Tr( + dst=("hr_flow", "start_node"), + cnd=cnd.ExactMatch("/vacancies"), + ), Tr(dst=dst.Current(), cnd=cnd.true()), ], }, @@ -100,7 +105,10 @@ "start_node": { RESPONSE: llm_response(model_name="bank_model"), TRANSITIONS: [ - Tr(dst=("greeting_flow", "greeting_node"), cnd=cnd.ExactMatch("/end")), + Tr( + dst=("greeting_flow", "greeting_node"), + cnd=cnd.ExactMatch("/end"), + ), Tr(dst=dst.Current(), cnd=cnd.true()), ], }, @@ -115,7 +123,10 @@ "start_node": { RESPONSE: llm_response(model_name="bank_model"), TRANSITIONS: [ - Tr(dst=("greeting_flow", "greeting_node"), cnd=cnd.ExactMatch("/end")), + Tr( + dst=("greeting_flow", "greeting_node"), + cnd=cnd.ExactMatch("/end"), + ), Tr(dst="cook_node", cnd=cnd.Regexp(r".*cook.*")), Tr(dst=dst.Current(), cnd=cnd.true()), ], diff --git a/tutorials/llm/3_filtering_history.py b/tutorials/llm/3_filtering_history.py index e791d90c0..0b2d3a013 100644 --- a/tutorials/llm/3_filtering_history.py +++ b/tutorials/llm/3_filtering_history.py @@ -17,7 +17,7 @@ Transition as Tr, conditions as cnd, destinations as dst, - labels as lbl + labels as lbl, ) from chatsky.utils.testing import ( is_interactive_mode, @@ -73,7 +73,9 @@ def __call__( 
}, "greeting_node": { RESPONSE: llm_response(model_name="assistant_model", history=0), - TRANSITIONS: [Tr(dst="main_node", cnd=cnd.ExactMatch("Who are you?"))], + TRANSITIONS: [ + Tr(dst="main_node", cnd=cnd.ExactMatch("Who are you?")) + ], }, "main_node": { RESPONSE: llm_response(model_name="assistant_model", history=3), diff --git a/tutorials/llm/4_structured_output.py b/tutorials/llm/4_structured_output.py index cd40f5538..8a77b7dbd 100644 --- a/tutorials/llm/4_structured_output.py +++ b/tutorials/llm/4_structured_output.py @@ -17,7 +17,7 @@ Transition as Tr, conditions as cnd, destinations as dst, - labels as lbl + labels as lbl, ) from chatsky.utils.testing import ( is_interactive_mode, @@ -42,7 +42,7 @@ """ # %% assistant_model = LLM_API(ChatOpenAI(model="gpt-3.5-turbo")) -movie_model = LLM_API(ChatAnthropic(model='claude-3-opus-20240229')) +movie_model = LLM_API(ChatAnthropic(model="claude-3-opus-20240229")) # %% [markdown] """ @@ -50,6 +50,8 @@ The `Movie`, inherited from the `BaseModel` will act as a schema for the response _text_, that will contain valid JSON containing desribed information. The `ImportantMessage`, inherited from the `Message` class, will otherwise define the fields of the output `Message`. In this example we will use this to mark the message as important. """ + + # %% class Movie(BaseModel): name: str = Field(description="Name of the movie") @@ -60,17 +62,23 @@ class Movie(BaseModel): class ImportantMessage(Message): text: str = Field(description="Text of the note") - misc: dict = Field(description="A dictonary with 'important' key and true/false value in it") + misc: dict = Field( + description="A dictonary with 'important' key and true/false value in it" + ) + # %% script = { GLOBAL: { TRANSITIONS: [ - Tr(dst=("greeting_flow", "start_node"), cnd=cnd.ExactMatch("/start")), + Tr( + dst=("greeting_flow", "start_node"), + cnd=cnd.ExactMatch("/start"), + ), Tr(dst=("movie_flow", "main_node"), cnd=cnd.ExactMatch("/movie")), Tr(dst=("note_flow", "main_node"), cnd=cnd.ExactMatch("/note")), - ] + ] }, "greeting_flow": { "start_node": { @@ -79,20 +87,28 @@ class ImportantMessage(Message): "fallback_node": { RESPONSE: Message("I did not quite understand you..."), TRANSITIONS: [Tr(dst="start_node", cnd=cnd.true())], - } + }, }, "movie_flow": { "main_node": { - RESPONSE: llm_response("movie_model", prompt="Ask user to request you for movie ideas.", message_schema=Movie), + RESPONSE: llm_response( + "movie_model", + prompt="Ask user to request you for movie ideas.", + message_schema=Movie, + ), TRANSITIONS: [Tr(dst=dst.Current(), cnd=cnd.true())], } }, "note_flow": { "main_node": { - RESPONSE: llm_response("note_model", prompt="Help user take notes and mark the important ones.", message_schema=ImportantMessage), + RESPONSE: llm_response( + "note_model", + prompt="Help user take notes and mark the important ones.", + message_schema=ImportantMessage, + ), TRANSITIONS: [Tr(dst=dst.Current(), cnd=cnd.true())], } - } + }, } # %%