diff --git a/chatsky/llm/conditions.py b/chatsky/llm/conditions.py index 1d7e7d66c..5f66fc96b 100644 --- a/chatsky/llm/conditions.py +++ b/chatsky/llm/conditions.py @@ -1,5 +1,6 @@ from chatsky.llm.methods import BaseMethod + def llm_condition(model_name: str, prompt: str, method: BaseMethod): """ Basic function for using LLM in condition cases. diff --git a/chatsky/llm/filters.py b/chatsky/llm/filters.py index 25f5810c2..8808565ba 100644 --- a/chatsky/llm/filters.py +++ b/chatsky/llm/filters.py @@ -14,6 +14,7 @@ class BaseFilter(BaseModel, abc.ABC): """ Base class for all message history filters. """ + @abc.abstractmethod def __call__(self, ctx: Context, request: Message, response: Message, model_name: str) -> bool: """ @@ -29,6 +30,7 @@ class IsImportant(BaseFilter): """ Filter that checks if the "important" field in a Message.misc is True. """ + def __call__( self, ctx: Context = None, request: Message = None, response: Message = None, model_name: str = None ) -> bool: @@ -43,6 +45,7 @@ class FromTheModel(BaseFilter): """ Filter that checks if the message was sent by the model. """ + def __call__( self, ctx: Context = None, request: Message = None, response: Message = None, model_name: str = None ) -> bool: diff --git a/chatsky/llm/llm_api.py b/chatsky/llm/llm_api.py index 95e8b413d..854e4e16f 100644 --- a/chatsky/llm/llm_api.py +++ b/chatsky/llm/llm_api.py @@ -12,6 +12,7 @@ from langchain_mistralai import ChatMistralAI from langchain_core.output_parsers import StrOutputParser from langchain_core.language_models.chat_models import BaseChatModel + langchain_available = True except ImportError: langchain_available = False @@ -78,7 +79,10 @@ async def respond( async def condition(self, prompt: str, method: BaseMethod, return_schema=None): async def process_input(ctx: Context, _: Pipeline) -> bool: - condition_history = [await message_to_langchain(Message(prompt), pipeline=_, source="system"), await message_to_langchain(ctx.last_request, pipeline=_, source="human")] + condition_history = [ + await message_to_langchain(Message(prompt), pipeline=_, source="system"), + await message_to_langchain(ctx.last_request, pipeline=_, source="human"), + ] result = method(ctx, await self.model.agenerate([condition_history], logprobs=True, top_logprobs=10)) return result diff --git a/chatsky/llm/methods.py b/chatsky/llm/methods.py index f157ab2dd..762c9eb54 100644 --- a/chatsky/llm/methods.py +++ b/chatsky/llm/methods.py @@ -15,6 +15,7 @@ class BaseMethod(BaseModel, abc.ABC): """ Base class to evaluate models response as condition. """ + @abc.abstractmethod async def __call__(self, ctx: Context, model_result: LLMResult) -> bool: raise NotImplementedError @@ -35,6 +36,7 @@ class Contains(BaseMethod): :return: True if pattern is contained in model result :rtype: bool """ + pattern: str async def __call__(self, ctx: Context, model_result: LLMResult) -> bool: @@ -52,11 +54,13 @@ class LogProb(BaseMethod): :return: True if logprob is higher then threshold :rtype: bool """ + target_token: str threshold: float = -0.5 + async def __call__(self, ctx: Context, model_result: LLMResult) -> bool: try: - result = model_result.generations[0][0].generation_info['logprobs']['content'][0]['top_logprobs'] + result = model_result.generations[0][0].generation_info["logprobs"]["content"][0]["top_logprobs"] except ValueError: raise ValueError("LogProb method can only be applied to OpenAI models.") for tok in result: diff --git a/chatsky/llm/responses.py b/chatsky/llm/responses.py index 61bb0bc57..0968e3f05 100644 --- a/chatsky/llm/responses.py +++ b/chatsky/llm/responses.py @@ -7,7 +7,14 @@ from pydantic import BaseModel -def llm_response(model_name: str, prompt: str = "", history: int = 5, filter_func: Callable = lambda *args: True, message_schema: Union[None, Type[Message], Type[BaseModel]] = None, max_size: int=1000): +def llm_response( + model_name: str, + prompt: str = "", + history: int = 5, + filter_func: Callable = lambda *args: True, + message_schema: Union[None, Type[Message], Type[BaseModel]] = None, + max_size: int = 1000, +): """ Basic function for receiving LLM responses. :param ctx: Context object. (Assigned automatically) @@ -33,19 +40,25 @@ async def wrapped(ctx: Context, pipeline: Pipeline) -> Message: # populate history with global and local prompts if "global_prompt" in current_misc: global_prompt = current_misc["global_prompt"] - history_messages.append(await message_to_langchain(Message(global_prompt), pipeline=pipeline, source="system")) + history_messages.append( + await message_to_langchain(Message(global_prompt), pipeline=pipeline, source="system") + ) if "local_prompt" in current_misc: local_prompt = current_misc["local_prompt"] - history_messages.append(await message_to_langchain(Message(local_prompt), pipeline=pipeline, source="system")) + history_messages.append( + await message_to_langchain(Message(local_prompt), pipeline=pipeline, source="system") + ) if "prompt" in current_misc: node_prompt = current_misc["prompt"] - history_messages.append(await message_to_langchain(Message(node_prompt), pipeline=pipeline, source="system")) + history_messages.append( + await message_to_langchain(Message(node_prompt), pipeline=pipeline, source="system") + ) # iterate over context to retrieve history messages if not (history == 0 or len(ctx.responses) == 0 or len(ctx.requests) == 0): pairs = zip( - [ctx.requests[x] for x in range(1, len(ctx.requests)+1)], - [ctx.responses[x] for x in range(1, len(ctx.responses)+1)], + [ctx.requests[x] for x in range(1, len(ctx.requests) + 1)], + [ctx.responses[x] for x in range(1, len(ctx.responses) + 1)], ) # print(f"Pairs: {[p for p in pairs]}") if history != -1: @@ -64,4 +77,3 @@ async def wrapped(ctx: Context, pipeline: Pipeline) -> Message: return await model.respond(history_messages, message_schema=message_schema) return wrapped - diff --git a/chatsky/llm/utils.py b/chatsky/llm/utils.py index 55064271c..af101e9e9 100644 --- a/chatsky/llm/utils.py +++ b/chatsky/llm/utils.py @@ -4,7 +4,7 @@ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage -async def message_to_langchain(message: Message, pipeline: Pipeline, source: str = "human", max_size: int=1000): +async def message_to_langchain(message: Message, pipeline: Pipeline, source: str = "human", max_size: int = 1000): """ Creates a langchain message from a ~chatsky.script.core.message.Message object. @@ -19,13 +19,19 @@ async def message_to_langchain(message: Message, pipeline: Pipeline, source: str if len(message.text) > max_size: raise ValueError("Message is too long.") - if message.text is None: message.text = "" + if message.text is None: + message.text = "" content = [{"type": "text", "text": message.text}] if message.attachments: for image in message.attachments: if isinstance(image, Image): - content.append({"type": "image_url", "image_url": {"url": await __attachment_to_content(image, pipeline.messenger_interface)}}) + content.append( + { + "type": "image_url", + "image_url": {"url": await __attachment_to_content(image, pipeline.messenger_interface)}, + } + ) if source == "human": return HumanMessage(content=content) diff --git a/tests/llm/test_llm.py b/tests/llm/test_llm.py index 7df4094e5..496c6e640 100644 --- a/tests/llm/test_llm.py +++ b/tests/llm/test_llm.py @@ -1,4 +1,4 @@ -from chatsky.llm.llm_api import LLM_API +from chatsky.llm.llm_api import LLM_API from chatsky.llm.responses import llm_response from chatsky.llm.conditions import llm_condition from chatsky.llm.utils import message_to_langchain, __attachment_to_content @@ -39,36 +39,40 @@ def with_structured_output(self, message_schema): def respond(self, history: list = [""]): return self.ainvoke(history) + class MockedStructuredModel: def __init__(self, root_model): self.root = root_model async def ainvoke(self, history): - inst = self.root(history = history) + inst = self.root(history=history) return inst() + class MessageSchema(BaseModel): history: list[str] def __call__(self): return {"history": self.history} + @pytest.fixture def mock_structured_model(): return MockedStructuredModel + async def test_structured_output(monkeypatch, mock_structured_model): # Create a mock LLM_API instance llm_api = LLM_API(MockChatOpenAI()) - + # Test data history = ["message1", "message2"] - + # Call the respond method result = await llm_api.respond(message_schema=MessageSchema, history=history) - + # Assert the result - expected_result = Message(text=str({"history": history}), annotations={'__generated_by_model__': ''}) + expected_result = Message(text=str({"history": history}), annotations={"__generated_by_model__": ""}) assert result == expected_result @@ -76,6 +80,7 @@ async def test_structured_output(monkeypatch, mock_structured_model): def mock_model(): return MockChatOpenAI() + class MockPipeline: def __init__(self, mock_model): self.models = {"test_model": LLM_API(mock_model), "struct_model": LLM_API(mock_structured_model)} @@ -86,6 +91,7 @@ def __init__(self, mock_model): def pipeline(mock_model): return MockPipeline(mock_model) + @pytest.fixture def filter_context(): ctx = Context.init(AbsoluteNodeLabel(flow_name="flow", node_name="node")) @@ -121,6 +127,7 @@ def context(): ctx.add_request("Last request") return ctx + async def test_message_to_langchain(pipeline): assert await message_to_langchain(Message(text="hello"), pipeline, source="human") == HumanMessage( content=[{"type": "text", "text": "hello"}] @@ -133,11 +140,22 @@ async def test_message_to_langchain(pipeline): class MockMessengerInterface(MessengerInterfaceWithAttachments): async def connect(self): pass - + async def get_attachment_bytes(self, attachment): return b"mock_bytes" -@pytest.mark.parametrize("img,expected", [(Image(source="https://raw.githubusercontent.com/deeppavlov/chatsky/master/docs/source/_static/images/Chatsky-full-dark.svg"), 'data:image/svg;base64,PHN2ZyB3aWR0aD0iMTM3OCIgaGVpZ2h0PSIyNzYiIHZpZXdCb3g9IjAgMCAxMzc4IDI3NiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTE0NCAxMDguMDE5SDIyOC4wNDhDMjY3LjQ3NCAxMDguMDE5IDMwNi4zNTkgOTguODg4IDM0MS42MjMgODEuMzQ5N0MzNTAuMDY2IDc3LjE1MDcgMzYwIDgzLjI1NzUgMzYwIDkyLjY0NjlWMTA4LjAxOUMzNjAgMTg3LjEyNSAyOTUuNTI5IDI1MS4yNTMgMjE2IDI1MS4yNTNIMTE1LjM1NEMxMDYuMDQ5IDI1MS4yNTMgMTAxLjM5NiAyNTEuMjUzIDk2LjgwNzkgMjUxLjU0NUM4MC42NzU4IDI1Mi41NzEgNjQuODMyNSAyNTYuMjkxIDQ5Ljk0MjMgMjYyLjU1QzQ1LjcwNzMgMjY0LjMzIDQxLjU0NTggMjY2LjM5OSAzMy4yMjI4IDI3MC41MzlMMzIuMTk5NCAyNzEuMDQ4QzI2LjgwNDggMjczLjczMSAyNC4xMDc1IDI3NS4wNzIgMjEuOTg5OCAyNzUuNTUxQzEyLjQwOTcgMjc3LjcxNyAyLjg1MzA3IDI3MS44NDIgMC41MTY0ODggMjYyLjM1QzAgMjYwLjI1MiAwIDI1Ny4yNTIgMCAyNTEuMjUzQzAgMTcyLjE0NyA2NC40NzEgMTA4LjAxOSAxNDQgMTA4LjAxOVoiIGZpbGw9IiMwMEEzRkYiLz4KPHBhdGggZD0iTTI1MC4zNjQgMEMyODkuMjkxIDAgMzIzLjA0MSAyMi41MDUzIDMzOS41NjkgNTUuMzk4N0MzNDAuMDk1IDU2Ljc4ODQgMzQwLjI0MyA1OS4wNjQ3IDMzOS40MTUgNjEuNjg1NEMzMzguNjIxIDY0LjE5NjUgMzM3LjI3OCA2NS45MjM0IDMzNi4wODkgNjYuNzg3NEMzMDEuOTM0IDgzLjM2NTYgMjY0LjMxNiA5MiAyMjYuMTY3IDkySDE3MC4zNTdDMTU5Ljc5MyA5MiAxNTguMDExIDkxLjcwMzYgMTU2LjI5OCA5MC42MzgxQzE1NS45MDIgOTAuMzkyMyAxNTQuOTc4IDg5LjYwMzUgMTUzLjk3NSA4OC4yNzVDMTUyLjk3MyA4Ni45NDY1IDE1Mi40NjQgODUuODM2OSAxNTIuMzMzIDg1LjM4NjRDMTUxLjcyNiA4My4yODg3IDE1MS43OTQgODIuMjEwNCAxNTMuODY2IDc0LjUxNjhDMTY1LjQzNSAzMS41NjU0IDIwNC4yNjggMCAyNTAuMzY0IDBaIiBmaWxsPSIjRkZBRDBEIi8+CjxwYXRoIGQ9Ik01MDAuMDAxIDIxOEM0OTAuOCAyMTggNDgyLjMzNCAyMTYuNjc0IDQ3NC42IDIxNC4wMjJDNDY3IDIxMS4zNyA0NjAuNCAyMDcuMjYgNDU0LjggMjAxLjY5MUM0NDkuMiAxOTUuOTg5IDQ0NC44IDE4OC42MyA0NDEuNiAxNzkuNjEzQzQzOC41MzMgMTcwLjU5NyA0MzcgMTU5LjcyNCA0MzcgMTQ2Ljk5NFYxNDMuMDE3QzQzNyAxMzAuODE4IDQzOC42IDEyMC40MDkgNDQxLjggMTExLjc5QzQ0NS4xMzMgMTAzLjAzOSA0NDkuNiA5NS44Nzg1IDQ1NS4yIDkwLjMwOTRDNDYwLjkzNCA4NC43NDAzIDQ2Ny42IDgwLjYyOTggNDc1LjIgNzcuOTc3OUM0ODIuOTM0IDc1LjMyNiA0OTEuMiA3NCA1MDAuMDAxIDc0QzUwNy44NjcgNzQgNTE1LjI2NyA3NC45MjgyIDUyMi4yMDEgNzYuNzg0NUM1MjkuMTM0IDc4LjY0MDkgNTM1LjIwMSA4MS41NTggNTQwLjQwMSA4NS41MzU5QzU0NS43MzQgODkuMzgxMiA1NTAuMDAxIDk0LjM1MzYgNTUzLjIwMSAxMDAuNDUzQzU1Ni41MzQgMTA2LjU1MiA1NTguNDY4IDExMy43NzkgNTU5LjAwMSAxMjIuMTMzSDUyNy44MDFDNTI2LjMzNCAxMTQuMTc3IDUyMy4xMzQgMTA4LjM0MyA1MTguMjAxIDEwNC42M0M1MTMuMjY3IDEwMC45MTcgNTA3LjIwMSA5OS4wNjA4IDUwMC4wMDEgOTkuMDYwOEM0OTUuODY3IDk5LjA2MDggNDkxLjg2NyA5OS43OTAxIDQ4OCAxMDEuMjQ5QzQ4NC4yNjcgMTAyLjcwNyA0ODAuOTM0IDEwNS4xNiA0NzggMTA4LjYwOEM0NzUuMDY3IDExMi4wNTUgNDcyLjY2NyAxMTYuNTY0IDQ3MC44IDEyMi4xMzNDNDY5LjA2NyAxMjcuNzAyIDQ2OC4yIDEzNC42NjMgNDY4LjIgMTQzLjAxN1YxNDYuOTk0QzQ2OC4yIDE1NS43NDYgNDY5LjEzNCAxNjMuMTA1IDQ3MSAxNjkuMDcyQzQ3Mi44NjcgMTc0LjkwNiA0NzUuMjY3IDE3OS42MTMgNDc4LjIgMTgzLjE5M0M0ODEuMjY3IDE4Ni42NDEgNDg0LjY2NyAxODkuMTYgNDg4LjQgMTkwLjc1MUM0OTIuMjY3IDE5Mi4yMSA0OTYuMTM0IDE5Mi45MzkgNTAwLjAwMSAxOTIuOTM5QzUwNy42MDEgMTkyLjkzOSA1MTMuODY3IDE5MS4wMTcgNTE4LjgwMSAxODcuMTcxQzUyMy43MzQgMTgzLjE5MyA1MjYuNzM0IDE3Ny40MjUgNTI3LjgwMSAxNjkuODY3SDU1OS4wMDFDNTU4LjMzNCAxNzguNjE5IDU1Ni40MDEgMTg2LjA0NCA1NTMuMjAxIDE5Mi4xNDRDNTUwLjAwMSAxOTguMjQzIDU0NS44MDEgMjAzLjIxNSA1NDAuNjAxIDIwNy4wNjFDNTM1LjQwMSAyMTAuOTA2IDUyOS4zMzQgMjEzLjY5MSA1MjIuNDAxIDIxNS40MTRDNTE1LjQ2NyAyMTcuMTM4IDUwOC4wMDEgMjE4IDUwMC4wMDEgMjE4WiIgZmlsbD0iIzAwQTNGRiIvPgo8cGF0aCBkPSJNNTgzLjA0IDc2LjM4NjdINjEzLjA0MVYxMzEuNjhINjczLjA0MVY3Ni4zODY3SDcwMy4wNDFWMjE1LjYxM0g2NzMuMDQxVjE1Ni4zNDNINjEzLjA0MVYyMTUuNjEzSDU4My4wNFY3Ni4zODY3WiIgZmlsbD0iIzAwQTNGRiIvPgo8cGF0aCBkPSJNODEyLjU5NSAxODcuMTcxSDc2MC43OTVMNzUwLjE5NSAyMTUuNjEzSDcyMC45OTVMNzcyLjk5NSA3Ni4zODY3SDgwMi45OTVMODU0Ljk5NiAyMTUuNjEzSDgyMy4zOTVMODEyLjU5NSAxODcuMTcxWk03NjkuOTk1IDE2Mi41MDhIODAzLjU5NUw3ODYuNzk1IDExNC4xNzdMNzY5Ljk5NSAxNjIuNTA4WiIgZmlsbD0iIzAwQTNGRiIvPgo8cGF0aCBkPSJNODkyLjAxOSAxMDEuMDVIODQ3LjAxOVY3Ni4zODY3SDk2Ny4wMlYxMDEuMDVIOTIyLjAyVjIxNS42MTNIODkyLjAxOVYxMDEuMDVaIiBmaWxsPSIjMDBBM0ZGIi8+CjxwYXRoIGQ9Ik0xMDM1Ljk3IDIxOEMxMDE3LjQ0IDIxOCAxMDAzLjQ0IDIxNC4yMjEgOTkzLjk3MyAyMDYuNjYzQzk4NC41MDcgMTk5LjEwNSA5NzkuNTA3IDE4OS4xNiA5NzguOTczIDE3Ni44MjlIMTAwOC45N0MxMDA5LjUxIDE3OS40ODEgMTAxMC4zMSAxODEuODAxIDEwMTEuMzcgMTgzLjc5QzEwMTIuNDQgMTg1Ljc3OSAxMDE0LjA0IDE4Ny40MzYgMTAxNi4xNyAxODguNzYyQzEwMTguMzEgMTkwLjA4OCAxMDIwLjk3IDE5MS4xNDkgMTAyNC4xNyAxOTEuOTQ1QzEwMjcuMzcgMTkyLjYwOCAxMDMxLjMxIDE5Mi45MzkgMTAzNS45NyAxOTIuOTM5QzEwNDUuNTcgMTkyLjkzOSAxMDUyLjQ0IDE5MS42OCAxMDU2LjU3IDE4OS4xNkMxMDYwLjg0IDE4Ni41MDggMTA2Mi45NyAxODIuNzk2IDEwNjIuOTcgMTc4LjAyMkMxMDYyLjk3IDE3Mi4xODggMTA2MC4xNyAxNjcuODEyIDEwNTQuNTcgMTY0Ljg5NUMxMDQ5LjExIDE2MS44NDUgMTA0MC4zMSAxNTkuMTI3IDEwMjguMTcgMTU2Ljc0QzEwMjAuOTcgMTU1LjQxNCAxMDE0LjU3IDE1My42OTEgMTAwOC45NyAxNTEuNTY5QzEwMDMuMzcgMTQ5LjMxNSA5OTguNjQgMTQ2LjUzIDk5NC43NzMgMTQzLjIxNUM5OTAuOTA3IDEzOS43NjggOTg3Ljk3MyAxMzUuNjU3IDk4NS45NzMgMTMwLjg4NEM5ODMuOTczIDEyNi4xMSA5ODIuOTczIDEyMC40MDkgOTgyLjk3MyAxMTMuNzc5Qzk4Mi45NzMgMTA3LjgxMiA5ODQuMTczIDEwMi4zNzYgOTg2LjU3MyA5Ny40Njk2Qzk4OS4xMDcgOTIuNTYzNSA5OTIuNjQgODguMzg2NyA5OTcuMTczIDg0LjkzOTJDMTAwMS44NCA4MS4zNTkxIDEwMDcuNDQgNzguNjQwOSAxMDEzLjk3IDc2Ljc4NDVDMTAyMC41MSA3NC45MjgyIDEwMjcuODQgNzQgMTAzNS45NyA3NEMxMDQ0Ljc3IDc0IDEwNTIuMzcgNzQuOTI4MiAxMDU4Ljc3IDc2Ljc4NDVDMTA2NS4xNyA3OC42NDA5IDEwNzAuNTEgODEuMjkyOCAxMDc0Ljc3IDg0Ljc0MDNDMTA3OS4xNyA4OC4xODc4IDEwODIuNTEgOTIuMjk4MyAxMDg0Ljc3IDk3LjA3MThDMTA4Ny4wNCAxMDEuODQ1IDEwODguNDQgMTA3LjIxNSAxMDg4Ljk3IDExMy4xODJIMTA1OC45N0MxMDU4LjA0IDEwOC40MDkgMTA1NS45MSAxMDQuODk1IDEwNTIuNTcgMTAyLjY0MUMxMDQ5LjI0IDEwMC4yNTQgMTA0My43MSA5OS4wNjA4IDEwMzUuOTcgOTkuMDYwOEMxMDI4LjExIDk5LjA2MDggMTAyMi4zMSAxMDAuMzIgMTAxOC41NyAxMDIuODRDMTAxNC44NCAxMDUuMzU5IDEwMTIuOTcgMTA4Ljc0IDEwMTIuOTcgMTEyLjk4M0MxMDEyLjk3IDExOC4wMjIgMTAxNS41MSAxMjIuMDY2IDEwMjAuNTcgMTI1LjExNkMxMDI1Ljc3IDEyOC4wMzMgMTAzMy45MSAxMzAuNTUyIDEwNDQuOTcgMTMyLjY3NEMxMDUyLjU3IDEzNC4xMzMgMTA1OS4zMSAxMzUuOTIzIDEwNjUuMTcgMTM4LjA0NEMxMDcxLjE3IDE0MC4xNjYgMTA3Ni4xNyAxNDIuODg0IDEwODAuMTcgMTQ2LjE5OUMxMDg0LjMxIDE0OS41MTQgMTA4Ny40NCAxNTMuNjI0IDEwODkuNTcgMTU4LjUzQzEwOTEuODQgMTYzLjMwNCAxMDkyLjk3IDE2OS4wNzIgMTA5Mi45NyAxNzUuODM0QzEwOTIuOTcgMTg4LjY5NiAxMDg3Ljk3IDE5OC45NzIgMTA3Ny45NyAyMDYuNjYzQzEwNjguMTEgMjE0LjIyMSAxMDU0LjExIDIxOCAxMDM1Ljk3IDIxOFoiIGZpbGw9IiMwMEEzRkYiLz4KPHBhdGggZD0iTTExMTguMDEgNzYuMzg2N0gxMTQ4LjAxVjEzNS4wNjFMMTE5OS4yMSA3Ni4zODY3SDEyMzQuMDFMMTE4MS44MSAxMzMuNDdMMTIzOC4wMSAyMTUuNjEzSDEyMDIuODFMMTE2MS42MSAxNTMuMTZMMTE0OC4wMSAxNjcuODc4VjIxNS42MTNIMTExOC4wMVY3Ni4zODY3WiIgZmlsbD0iIzAwQTNGRiIvPgo8cGF0aCBkPSJNMTI5OSAxNjIuNTA4TDEyNDYgNzYuMzg2N0gxMjgxLjJMMTMxNC42IDEzMi4yNzZMMTM0NiA3Ni4zODY3SDEzNzhMMTMyOSAxNjIuNTA4VjIxNS42MTNIMTI5OVYxNjIuNTA4WiIgZmlsbD0iIzAwQTNGRiIvPgo8L3N2Zz4K')]) + +@pytest.mark.parametrize( + "img,expected", + [ + ( + Image( + source="https://raw.githubusercontent.com/deeppavlov/chatsky/master/docs/source/_static/images/Chatsky-full-dark.svg" + ), + "data:image/svg;base64,PHN2ZyB3aWR0aD0iMTM3OCIgaGVpZ2h0PSIyNzYiIHZpZXdCb3g9IjAgMCAxMzc4IDI3NiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTE0NCAxMDguMDE5SDIyOC4wNDhDMjY3LjQ3NCAxMDguMDE5IDMwNi4zNTkgOTguODg4IDM0MS42MjMgODEuMzQ5N0MzNTAuMDY2IDc3LjE1MDcgMzYwIDgzLjI1NzUgMzYwIDkyLjY0NjlWMTA4LjAxOUMzNjAgMTg3LjEyNSAyOTUuNTI5IDI1MS4yNTMgMjE2IDI1MS4yNTNIMTE1LjM1NEMxMDYuMDQ5IDI1MS4yNTMgMTAxLjM5NiAyNTEuMjUzIDk2LjgwNzkgMjUxLjU0NUM4MC42NzU4IDI1Mi41NzEgNjQuODMyNSAyNTYuMjkxIDQ5Ljk0MjMgMjYyLjU1QzQ1LjcwNzMgMjY0LjMzIDQxLjU0NTggMjY2LjM5OSAzMy4yMjI4IDI3MC41MzlMMzIuMTk5NCAyNzEuMDQ4QzI2LjgwNDggMjczLjczMSAyNC4xMDc1IDI3NS4wNzIgMjEuOTg5OCAyNzUuNTUxQzEyLjQwOTcgMjc3LjcxNyAyLjg1MzA3IDI3MS44NDIgMC41MTY0ODggMjYyLjM1QzAgMjYwLjI1MiAwIDI1Ny4yNTIgMCAyNTEuMjUzQzAgMTcyLjE0NyA2NC40NzEgMTA4LjAxOSAxNDQgMTA4LjAxOVoiIGZpbGw9IiMwMEEzRkYiLz4KPHBhdGggZD0iTTI1MC4zNjQgMEMyODkuMjkxIDAgMzIzLjA0MSAyMi41MDUzIDMzOS41NjkgNTUuMzk4N0MzNDAuMDk1IDU2Ljc4ODQgMzQwLjI0MyA1OS4wNjQ3IDMzOS40MTUgNjEuNjg1NEMzMzguNjIxIDY0LjE5NjUgMzM3LjI3OCA2NS45MjM0IDMzNi4wODkgNjYuNzg3NEMzMDEuOTM0IDgzLjM2NTYgMjY0LjMxNiA5MiAyMjYuMTY3IDkySDE3MC4zNTdDMTU5Ljc5MyA5MiAxNTguMDExIDkxLjcwMzYgMTU2LjI5OCA5MC42MzgxQzE1NS45MDIgOTAuMzkyMyAxNTQuOTc4IDg5LjYwMzUgMTUzLjk3NSA4OC4yNzVDMTUyLjk3MyA4Ni45NDY1IDE1Mi40NjQgODUuODM2OSAxNTIuMzMzIDg1LjM4NjRDMTUxLjcyNiA4My4yODg3IDE1MS43OTQgODIuMjEwNCAxNTMuODY2IDc0LjUxNjhDMTY1LjQzNSAzMS41NjU0IDIwNC4yNjggMCAyNTAuMzY0IDBaIiBmaWxsPSIjRkZBRDBEIi8+CjxwYXRoIGQ9Ik01MDAuMDAxIDIxOEM0OTAuOCAyMTggNDgyLjMzNCAyMTYuNjc0IDQ3NC42IDIxNC4wMjJDNDY3IDIxMS4zNyA0NjAuNCAyMDcuMjYgNDU0LjggMjAxLjY5MUM0NDkuMiAxOTUuOTg5IDQ0NC44IDE4OC42MyA0NDEuNiAxNzkuNjEzQzQzOC41MzMgMTcwLjU5NyA0MzcgMTU5LjcyNCA0MzcgMTQ2Ljk5NFYxNDMuMDE3QzQzNyAxMzAuODE4IDQzOC42IDEyMC40MDkgNDQxLjggMTExLjc5QzQ0NS4xMzMgMTAzLjAzOSA0NDkuNiA5NS44Nzg1IDQ1NS4yIDkwLjMwOTRDNDYwLjkzNCA4NC43NDAzIDQ2Ny42IDgwLjYyOTggNDc1LjIgNzcuOTc3OUM0ODIuOTM0IDc1LjMyNiA0OTEuMiA3NCA1MDAuMDAxIDc0QzUwNy44NjcgNzQgNTE1LjI2NyA3NC45MjgyIDUyMi4yMDEgNzYuNzg0NUM1MjkuMTM0IDc4LjY0MDkgNTM1LjIwMSA4MS41NTggNTQwLjQwMSA4NS41MzU5QzU0NS43MzQgODkuMzgxMiA1NTAuMDAxIDk0LjM1MzYgNTUzLjIwMSAxMDAuNDUzQzU1Ni41MzQgMTA2LjU1MiA1NTguNDY4IDExMy43NzkgNTU5LjAwMSAxMjIuMTMzSDUyNy44MDFDNTI2LjMzNCAxMTQuMTc3IDUyMy4xMzQgMTA4LjM0MyA1MTguMjAxIDEwNC42M0M1MTMuMjY3IDEwMC45MTcgNTA3LjIwMSA5OS4wNjA4IDUwMC4wMDEgOTkuMDYwOEM0OTUuODY3IDk5LjA2MDggNDkxLjg2NyA5OS43OTAxIDQ4OCAxMDEuMjQ5QzQ4NC4yNjcgMTAyLjcwNyA0ODAuOTM0IDEwNS4xNiA0NzggMTA4LjYwOEM0NzUuMDY3IDExMi4wNTUgNDcyLjY2NyAxMTYuNTY0IDQ3MC44IDEyMi4xMzNDNDY5LjA2NyAxMjcuNzAyIDQ2OC4yIDEzNC42NjMgNDY4LjIgMTQzLjAxN1YxNDYuOTk0QzQ2OC4yIDE1NS43NDYgNDY5LjEzNCAxNjMuMTA1IDQ3MSAxNjkuMDcyQzQ3Mi44NjcgMTc0LjkwNiA0NzUuMjY3IDE3OS42MTMgNDc4LjIgMTgzLjE5M0M0ODEuMjY3IDE4Ni42NDEgNDg0LjY2NyAxODkuMTYgNDg4LjQgMTkwLjc1MUM0OTIuMjY3IDE5Mi4yMSA0OTYuMTM0IDE5Mi45MzkgNTAwLjAwMSAxOTIuOTM5QzUwNy42MDEgMTkyLjkzOSA1MTMuODY3IDE5MS4wMTcgNTE4LjgwMSAxODcuMTcxQzUyMy43MzQgMTgzLjE5MyA1MjYuNzM0IDE3Ny40MjUgNTI3LjgwMSAxNjkuODY3SDU1OS4wMDFDNTU4LjMzNCAxNzguNjE5IDU1Ni40MDEgMTg2LjA0NCA1NTMuMjAxIDE5Mi4xNDRDNTUwLjAwMSAxOTguMjQzIDU0NS44MDEgMjAzLjIxNSA1NDAuNjAxIDIwNy4wNjFDNTM1LjQwMSAyMTAuOTA2IDUyOS4zMzQgMjEzLjY5MSA1MjIuNDAxIDIxNS40MTRDNTE1LjQ2NyAyMTcuMTM4IDUwOC4wMDEgMjE4IDUwMC4wMDEgMjE4WiIgZmlsbD0iIzAwQTNGRiIvPgo8cGF0aCBkPSJNNTgzLjA0IDc2LjM4NjdINjEzLjA0MVYxMzEuNjhINjczLjA0MVY3Ni4zODY3SDcwMy4wNDFWMjE1LjYxM0g2NzMuMDQxVjE1Ni4zNDNINjEzLjA0MVYyMTUuNjEzSDU4My4wNFY3Ni4zODY3WiIgZmlsbD0iIzAwQTNGRiIvPgo8cGF0aCBkPSJNODEyLjU5NSAxODcuMTcxSDc2MC43OTVMNzUwLjE5NSAyMTUuNjEzSDcyMC45OTVMNzcyLjk5NSA3Ni4zODY3SDgwMi45OTVMODU0Ljk5NiAyMTUuNjEzSDgyMy4zOTVMODEyLjU5NSAxODcuMTcxWk03NjkuOTk1IDE2Mi41MDhIODAzLjU5NUw3ODYuNzk1IDExNC4xNzdMNzY5Ljk5NSAxNjIuNTA4WiIgZmlsbD0iIzAwQTNGRiIvPgo8cGF0aCBkPSJNODkyLjAxOSAxMDEuMDVIODQ3LjAxOVY3Ni4zODY3SDk2Ny4wMlYxMDEuMDVIOTIyLjAyVjIxNS42MTNIODkyLjAxOVYxMDEuMDVaIiBmaWxsPSIjMDBBM0ZGIi8+CjxwYXRoIGQ9Ik0xMDM1Ljk3IDIxOEMxMDE3LjQ0IDIxOCAxMDAzLjQ0IDIxNC4yMjEgOTkzLjk3MyAyMDYuNjYzQzk4NC41MDcgMTk5LjEwNSA5NzkuNTA3IDE4OS4xNiA5NzguOTczIDE3Ni44MjlIMTAwOC45N0MxMDA5LjUxIDE3OS40ODEgMTAxMC4zMSAxODEuODAxIDEwMTEuMzcgMTgzLjc5QzEwMTIuNDQgMTg1Ljc3OSAxMDE0LjA0IDE4Ny40MzYgMTAxNi4xNyAxODguNzYyQzEwMTguMzEgMTkwLjA4OCAxMDIwLjk3IDE5MS4xNDkgMTAyNC4xNyAxOTEuOTQ1QzEwMjcuMzcgMTkyLjYwOCAxMDMxLjMxIDE5Mi45MzkgMTAzNS45NyAxOTIuOTM5QzEwNDUuNTcgMTkyLjkzOSAxMDUyLjQ0IDE5MS42OCAxMDU2LjU3IDE4OS4xNkMxMDYwLjg0IDE4Ni41MDggMTA2Mi45NyAxODIuNzk2IDEwNjIuOTcgMTc4LjAyMkMxMDYyLjk3IDE3Mi4xODggMTA2MC4xNyAxNjcuODEyIDEwNTQuNTcgMTY0Ljg5NUMxMDQ5LjExIDE2MS44NDUgMTA0MC4zMSAxNTkuMTI3IDEwMjguMTcgMTU2Ljc0QzEwMjAuOTcgMTU1LjQxNCAxMDE0LjU3IDE1My42OTEgMTAwOC45NyAxNTEuNTY5QzEwMDMuMzcgMTQ5LjMxNSA5OTguNjQgMTQ2LjUzIDk5NC43NzMgMTQzLjIxNUM5OTAuOTA3IDEzOS43NjggOTg3Ljk3MyAxMzUuNjU3IDk4NS45NzMgMTMwLjg4NEM5ODMuOTczIDEyNi4xMSA5ODIuOTczIDEyMC40MDkgOTgyLjk3MyAxMTMuNzc5Qzk4Mi45NzMgMTA3LjgxMiA5ODQuMTczIDEwMi4zNzYgOTg2LjU3MyA5Ny40Njk2Qzk4OS4xMDcgOTIuNTYzNSA5OTIuNjQgODguMzg2NyA5OTcuMTczIDg0LjkzOTJDMTAwMS44NCA4MS4zNTkxIDEwMDcuNDQgNzguNjQwOSAxMDEzLjk3IDc2Ljc4NDVDMTAyMC41MSA3NC45MjgyIDEwMjcuODQgNzQgMTAzNS45NyA3NEMxMDQ0Ljc3IDc0IDEwNTIuMzcgNzQuOTI4MiAxMDU4Ljc3IDc2Ljc4NDVDMTA2NS4xNyA3OC42NDA5IDEwNzAuNTEgODEuMjkyOCAxMDc0Ljc3IDg0Ljc0MDNDMTA3OS4xNyA4OC4xODc4IDEwODIuNTEgOTIuMjk4MyAxMDg0Ljc3IDk3LjA3MThDMTA4Ny4wNCAxMDEuODQ1IDEwODguNDQgMTA3LjIxNSAxMDg4Ljk3IDExMy4xODJIMTA1OC45N0MxMDU4LjA0IDEwOC40MDkgMTA1NS45MSAxMDQuODk1IDEwNTIuNTcgMTAyLjY0MUMxMDQ5LjI0IDEwMC4yNTQgMTA0My43MSA5OS4wNjA4IDEwMzUuOTcgOTkuMDYwOEMxMDI4LjExIDk5LjA2MDggMTAyMi4zMSAxMDAuMzIgMTAxOC41NyAxMDIuODRDMTAxNC44NCAxMDUuMzU5IDEwMTIuOTcgMTA4Ljc0IDEwMTIuOTcgMTEyLjk4M0MxMDEyLjk3IDExOC4wMjIgMTAxNS41MSAxMjIuMDY2IDEwMjAuNTcgMTI1LjExNkMxMDI1Ljc3IDEyOC4wMzMgMTAzMy45MSAxMzAuNTUyIDEwNDQuOTcgMTMyLjY3NEMxMDUyLjU3IDEzNC4xMzMgMTA1OS4zMSAxMzUuOTIzIDEwNjUuMTcgMTM4LjA0NEMxMDcxLjE3IDE0MC4xNjYgMTA3Ni4xNyAxNDIuODg0IDEwODAuMTcgMTQ2LjE5OUMxMDg0LjMxIDE0OS41MTQgMTA4Ny40NCAxNTMuNjI0IDEwODkuNTcgMTU4LjUzQzEwOTEuODQgMTYzLjMwNCAxMDkyLjk3IDE2OS4wNzIgMTA5Mi45NyAxNzUuODM0QzEwOTIuOTcgMTg4LjY5NiAxMDg3Ljk3IDE5OC45NzIgMTA3Ny45NyAyMDYuNjYzQzEwNjguMTEgMjE0LjIyMSAxMDU0LjExIDIxOCAxMDM1Ljk3IDIxOFoiIGZpbGw9IiMwMEEzRkYiLz4KPHBhdGggZD0iTTExMTguMDEgNzYuMzg2N0gxMTQ4LjAxVjEzNS4wNjFMMTE5OS4yMSA3Ni4zODY3SDEyMzQuMDFMMTE4MS44MSAxMzMuNDdMMTIzOC4wMSAyMTUuNjEzSDEyMDIuODFMMTE2MS42MSAxNTMuMTZMMTE0OC4wMSAxNjcuODc4VjIxNS42MTNIMTExOC4wMVY3Ni4zODY3WiIgZmlsbD0iIzAwQTNGRiIvPgo8cGF0aCBkPSJNMTI5OSAxNjIuNTA4TDEyNDYgNzYuMzg2N0gxMjgxLjJMMTMxNC42IDEzMi4yNzZMMTM0NiA3Ni4zODY3SDEzNzhMMTMyOSAxNjIuNTA4VjIxNS42MTNIMTI5OVYxNjIuNTA4WiIgZmlsbD0iIzAwQTNGRiIvPgo8L3N2Zz4K", + ) + ], +) async def test_attachments(img, expected): script = {"flow": {"node": {RESPONSE: Message(), TRANSITIONS: [Tr(dst="node", cnd=True)]}}} pipe = Pipeline(script=script, start_label=("flow", "node"), messenger_interface=MockMessengerInterface()) diff --git a/tutorials/llm/1_basics.py b/tutorials/llm/1_basics.py index 61327f4f0..dc9d87038 100644 --- a/tutorials/llm/1_basics.py +++ b/tutorials/llm/1_basics.py @@ -17,7 +17,7 @@ Pipeline, Transition as Tr, conditions as cnd, - destinations as dst + destinations as dst, # all the aliases used in tutorials are available for direct import # e.g. you can do `from chatsky import Tr` instead ) @@ -65,18 +65,29 @@ }, "greeting_node": { RESPONSE: llm_response(model_name="barista_model", history=0), - TRANSITIONS: [Tr(dst="main_node", cnd=cnd.ExactMatch("Who are you?"))], + TRANSITIONS: [ + Tr(dst="main_node", cnd=cnd.ExactMatch("Who are you?")) + ], }, "main_node": { RESPONSE: llm_response(model_name="barista_model"), TRANSITIONS: [ - Tr(dst="latte_art_node", cnd=cnd.ExactMatch("Tell me about latte art.")), - Tr(dst="image_desc_node", cnd=cnd.ExactMatch("Tell me what coffee is it?")), - Tr(dst="boss_node", cnd=llm_condition( - model_name="barista_model", - prompt="Return only TRUE if your customer says that he is your boss, or FALSE if he don't. Only ONE word must be in the output.", - method=Contains(pattern="TRUE"), - )), + Tr( + dst="latte_art_node", + cnd=cnd.ExactMatch("Tell me about latte art."), + ), + Tr( + dst="image_desc_node", + cnd=cnd.ExactMatch("Tell me what coffee is it?"), + ), + Tr( + dst="boss_node", + cnd=llm_condition( + model_name="barista_model", + prompt="Return only TRUE if your customer says that he is your boss, or FALSE if he don't. Only ONE word must be in the output.", + method=Contains(pattern="TRUE"), + ), + ), Tr(dst=dst.Current(), cnd=cnd.true()), ], }, @@ -92,7 +103,9 @@ model_name="barista_model", prompt="PROMPT: pretend that you have never heard about latte art before and DO NOT answer the following questions. Instead ask a person about it.", ), - TRANSITIONS: [Tr(dst="main_node", cnd=cnd.ExactMatch("Ok, goodbye."))], + TRANSITIONS: [ + Tr(dst="main_node", cnd=cnd.ExactMatch("Ok, goodbye.")) + ], }, "image_desc_node": { # we expect user to send some images of coffee. diff --git a/tutorials/llm/2_prompt_usage.py b/tutorials/llm/2_prompt_usage.py index de6550efc..f4b927dd7 100644 --- a/tutorials/llm/2_prompt_usage.py +++ b/tutorials/llm/2_prompt_usage.py @@ -35,7 +35,7 @@ Transition as Tr, conditions as cnd, destinations as dst, - labels as lbl + labels as lbl, ) from chatsky.utils.testing import ( is_interactive_mode, @@ -79,8 +79,13 @@ "greeting_node": { RESPONSE: llm_response(model_name="bank_model", history=0), TRANSITIONS: [ - Tr(dst=("loan_flow", "start_node"), cnd=cnd.ExactMatch("/loan")), - Tr(dst=("hr_flow", "start_node"), cnd=cnd.ExactMatch("/vacancies")), + Tr( + dst=("loan_flow", "start_node"), cnd=cnd.ExactMatch("/loan") + ), + Tr( + dst=("hr_flow", "start_node"), + cnd=cnd.ExactMatch("/vacancies"), + ), Tr(dst=dst.Current(), cnd=cnd.true()), ], }, @@ -100,7 +105,10 @@ "start_node": { RESPONSE: llm_response(model_name="bank_model"), TRANSITIONS: [ - Tr(dst=("greeting_flow", "greeting_node"), cnd=cnd.ExactMatch("/end")), + Tr( + dst=("greeting_flow", "greeting_node"), + cnd=cnd.ExactMatch("/end"), + ), Tr(dst=dst.Current(), cnd=cnd.true()), ], }, @@ -115,7 +123,10 @@ "start_node": { RESPONSE: llm_response(model_name="bank_model"), TRANSITIONS: [ - Tr(dst=("greeting_flow", "greeting_node"), cnd=cnd.ExactMatch("/end")), + Tr( + dst=("greeting_flow", "greeting_node"), + cnd=cnd.ExactMatch("/end"), + ), Tr(dst="cook_node", cnd=cnd.Regexp(r".*cook.*")), Tr(dst=dst.Current(), cnd=cnd.true()), ], diff --git a/tutorials/llm/3_filtering_history.py b/tutorials/llm/3_filtering_history.py index e791d90c0..0b2d3a013 100644 --- a/tutorials/llm/3_filtering_history.py +++ b/tutorials/llm/3_filtering_history.py @@ -17,7 +17,7 @@ Transition as Tr, conditions as cnd, destinations as dst, - labels as lbl + labels as lbl, ) from chatsky.utils.testing import ( is_interactive_mode, @@ -73,7 +73,9 @@ def __call__( }, "greeting_node": { RESPONSE: llm_response(model_name="assistant_model", history=0), - TRANSITIONS: [Tr(dst="main_node", cnd=cnd.ExactMatch("Who are you?"))], + TRANSITIONS: [ + Tr(dst="main_node", cnd=cnd.ExactMatch("Who are you?")) + ], }, "main_node": { RESPONSE: llm_response(model_name="assistant_model", history=3), diff --git a/tutorials/llm/4_structured_output.py b/tutorials/llm/4_structured_output.py index cd40f5538..8a77b7dbd 100644 --- a/tutorials/llm/4_structured_output.py +++ b/tutorials/llm/4_structured_output.py @@ -17,7 +17,7 @@ Transition as Tr, conditions as cnd, destinations as dst, - labels as lbl + labels as lbl, ) from chatsky.utils.testing import ( is_interactive_mode, @@ -42,7 +42,7 @@ """ # %% assistant_model = LLM_API(ChatOpenAI(model="gpt-3.5-turbo")) -movie_model = LLM_API(ChatAnthropic(model='claude-3-opus-20240229')) +movie_model = LLM_API(ChatAnthropic(model="claude-3-opus-20240229")) # %% [markdown] """ @@ -50,6 +50,8 @@ The `Movie`, inherited from the `BaseModel` will act as a schema for the response _text_, that will contain valid JSON containing desribed information. The `ImportantMessage`, inherited from the `Message` class, will otherwise define the fields of the output `Message`. In this example we will use this to mark the message as important. """ + + # %% class Movie(BaseModel): name: str = Field(description="Name of the movie") @@ -60,17 +62,23 @@ class Movie(BaseModel): class ImportantMessage(Message): text: str = Field(description="Text of the note") - misc: dict = Field(description="A dictonary with 'important' key and true/false value in it") + misc: dict = Field( + description="A dictonary with 'important' key and true/false value in it" + ) + # %% script = { GLOBAL: { TRANSITIONS: [ - Tr(dst=("greeting_flow", "start_node"), cnd=cnd.ExactMatch("/start")), + Tr( + dst=("greeting_flow", "start_node"), + cnd=cnd.ExactMatch("/start"), + ), Tr(dst=("movie_flow", "main_node"), cnd=cnd.ExactMatch("/movie")), Tr(dst=("note_flow", "main_node"), cnd=cnd.ExactMatch("/note")), - ] + ] }, "greeting_flow": { "start_node": { @@ -79,20 +87,28 @@ class ImportantMessage(Message): "fallback_node": { RESPONSE: Message("I did not quite understand you..."), TRANSITIONS: [Tr(dst="start_node", cnd=cnd.true())], - } + }, }, "movie_flow": { "main_node": { - RESPONSE: llm_response("movie_model", prompt="Ask user to request you for movie ideas.", message_schema=Movie), + RESPONSE: llm_response( + "movie_model", + prompt="Ask user to request you for movie ideas.", + message_schema=Movie, + ), TRANSITIONS: [Tr(dst=dst.Current(), cnd=cnd.true())], } }, "note_flow": { "main_node": { - RESPONSE: llm_response("note_model", prompt="Help user take notes and mark the important ones.", message_schema=ImportantMessage), + RESPONSE: llm_response( + "note_model", + prompt="Help user take notes and mark the important ones.", + message_schema=ImportantMessage, + ), TRANSITIONS: [Tr(dst=dst.Current(), cnd=cnd.true())], } - } + }, } # %%