diff --git a/e2e/preview/components/test_chatgpt_generator.py b/e2e/preview/components/test_chatgpt_generator.py
new file mode 100644
index 0000000000..c3fad4038d
--- /dev/null
+++ b/e2e/preview/components/test_chatgpt_generator.py
@@ -0,0 +1,64 @@
+import os
+import pytest
+from haystack.preview.components.generators.openai.chatgpt import ChatGPTGenerator
+
+
+@pytest.mark.skipif(
+    not os.environ.get("OPENAI_API_KEY", None),
+    reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
+)
+def test_chatgpt_generator_run():
+    component = ChatGPTGenerator(api_key=os.environ.get("OPENAI_API_KEY"))
+    results = component.run(
+        prompts=["What's the capital of France?", "What's the capital of Germany?"], model_parameters={"n": 1}
+    )
+
+    assert len(results["replies"]) == 2
+    assert len(results["replies"][0]) == 1
+    assert "Paris" in results["replies"][0][0]
+    assert len(results["replies"][1]) == 1
+    assert "Berlin" in results["replies"][1][0]
+
+    assert len(results["metadata"]) == 2
+    assert len(results["metadata"][0]) == 1
+    assert "gpt-3.5-turbo" in results["metadata"][0][0]["model"]
+    assert "stop" == results["metadata"][0][0]["finish_reason"]
+    assert len(results["metadata"][1]) == 1
+    assert "gpt-3.5-turbo" in results["metadata"][1][0]["model"]
+    assert "stop" == results["metadata"][1][0]["finish_reason"]
+
+
+@pytest.mark.skipif(
+    not os.environ.get("OPENAI_API_KEY", None),
+    reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
+)
+def test_chatgpt_generator_run_streaming():
+    class Callback:
+        def __init__(self):
+            self.responses = ""
+
+        def __call__(self, token, event_data):
+            self.responses += token
+            return token
+
+    callback = Callback()
+    component = ChatGPTGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback)
+    results = component.run(
+        prompts=["What's the capital of France?", "What's the capital of Germany?"], model_parameters={"n": 1}
+    )
+
+    assert len(results["replies"]) == 2
+    assert len(results["replies"][0]) == 1
+    assert "Paris" in results["replies"][0][0]
+    assert len(results["replies"][1]) == 1
+    assert "Berlin" in results["replies"][1][0]
+
+    assert callback.responses == results["replies"][0][0] + results["replies"][1][0]
+
+    assert len(results["metadata"]) == 2
+    assert len(results["metadata"][0]) == 1
+    assert "gpt-3.5-turbo" in results["metadata"][0][0]["model"]
+    assert "stop" == results["metadata"][0][0]["finish_reason"]
+    assert len(results["metadata"][1]) == 1
+    assert "gpt-3.5-turbo" in results["metadata"][1][0]["model"]
+    assert "stop" == results["metadata"][1][0]["finish_reason"]
diff --git a/haystack/preview/components/generators/openai/_helpers.py b/haystack/preview/components/generators/openai/_helpers.py
deleted file mode 100644
index 946901b644..0000000000
--- a/haystack/preview/components/generators/openai/_helpers.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import logging
-
-from haystack.preview.lazy_imports import LazyImport
-
-with LazyImport("Run 'pip install tiktoken'") as tiktoken_import:
-    import tiktoken
-
-
-logger = logging.getLogger(__name__)
-
-
-def enforce_token_limit(prompt: str, tokenizer: "tiktoken.Encoding", max_tokens_limit: int) -> str:
-    """
-    Ensure that the length of the prompt is within the max tokens limit of the model.
-    If needed, truncate the prompt text so that it fits within the limit.
-
-    :param prompt: Prompt text to be sent to the generative model.
-    :param tokenizer: The tokenizer used to encode the prompt.
-    :param max_tokens_limit: The max tokens limit of the model.
-    :return: The prompt text that fits within the max tokens limit of the model.
-    """
-    tiktoken_import.check()
-    tokens = tokenizer.encode(prompt)
-    tokens_count = len(tokens)
-    if tokens_count > max_tokens_limit:
-        logger.warning(
-            "The prompt has been truncated from %s tokens to %s tokens to fit within the max token limit. "
-            "Reduce the length of the prompt to prevent it from being cut off.",
-            tokens_count,
-            max_tokens_limit,
-        )
-        prompt = tokenizer.decode(tokens[:max_tokens_limit])
-    return prompt
diff --git a/haystack/preview/components/generators/openai/chatgpt.py b/haystack/preview/components/generators/openai/chatgpt.py
new file mode 100644
index 0000000000..fcb9047d75
--- /dev/null
+++ b/haystack/preview/components/generators/openai/chatgpt.py
@@ -0,0 +1,201 @@
+from typing import Optional, List, Callable, Dict, Any
+
+import sys
+import builtins
+import logging
+
+from haystack.preview import component, default_from_dict, default_to_dict, DeserializationError
+from haystack.preview.llm_backends.openai.chatgpt import ChatGPTBackend
+from haystack.preview.llm_backends.chat_message import ChatMessage
+
+
+logger = logging.getLogger(__name__)
+
+
+TOKENS_PER_MESSAGE_OVERHEAD = 4
+
+
+def default_streaming_callback(token: str, **kwargs):
+    """
+    Default callback function for streaming responses from OpenAI API.
+    Prints the tokens to stdout as soon as they are received and returns them.
+    """
+    print(token, flush=True, end="")
+    return token
+
+
+@component
+class ChatGPTGenerator:
+    """
+    ChatGPT LLM Generator.
+
+    Queries ChatGPT using OpenAI's ChatGPT API. Invocations are made using the REST API.
+    See [OpenAI ChatGPT API](https://platform.openai.com/docs/guides/chat) for more details.
+    """
+
+    # TODO support function calling!
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        model_name: str = "gpt-3.5-turbo",
+        system_prompt: Optional[str] = None,
+        model_parameters: Optional[Dict[str, Any]] = None,
+        streaming_callback: Optional[Callable] = None,
+        api_base_url: str = "https://api.openai.com/v1",
+    ):
+        """
+        Creates an instance of ChatGPTGenerator for OpenAI's GPT-3.5 model.
+
+        :param api_key: The OpenAI API key.
+        :param model_name: The name of the model to use.
+        :param system_prompt: The prompt to be prepended to the user prompt.
+        :param streaming_callback: A callback function that is called when a new token is received from the stream.
+            The callback function should accept two parameters: the token received from the stream and **kwargs.
+            The callback function should return the token to be sent to the stream. If the callback function is not
+            provided, the token is printed to stdout.
+        :param api_base_url: The OpenAI API base URL, defaults to `https://api.openai.com/v1`.
+        :param model_parameters: A dictionary of parameters to use for the model. See OpenAI
+            [documentation](https://platform.openai.com/docs/api-reference/chat) for more details. Some of the supported
+            parameters:
+            - `max_tokens`: The maximum number of tokens the output text can have.
+            - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
+                Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
+            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
+                considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
+                comprising the top 10% probability mass are considered.
+            - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
+                it will generate two completions for each of the three prompts, ending up with 6 completions in total.
+            - `stop`: One or more sequences after which the LLM should stop generating tokens.
+            - `presence_penalty`: What penalty to apply if a token is already present at all. Bigger values mean
+                the model will be less likely to repeat the same token in the text.
+            - `frequency_penalty`: What penalty to apply if a token has already been generated in the text.
+                Bigger values mean the model will be less likely to repeat the same token in the text.
+            - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens and the
+                values are the bias to add to that token.
+            - `openai_organization`: The OpenAI organization ID.
+        """
+        self.llm = ChatGPTBackend(
+            api_key=api_key,
+            model_name=model_name,
+            model_parameters=model_parameters,
+            streaming_callback=streaming_callback,
+            api_base_url=api_base_url,
+        )
+        self.system_prompt = system_prompt
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this component to a dictionary.
+        """
+        if self.llm.streaming_callback:
+            module = sys.modules.get(self.llm.streaming_callback.__module__)
+            if not module:
+                raise ValueError("Could not locate the import module.")
+            if module == builtins:
+                callback_name = self.llm.streaming_callback.__name__
+            else:
+                callback_name = f"{module.__name__}.{self.llm.streaming_callback.__name__}"
+        else:
+            callback_name = None
+
+        return default_to_dict(
+            self,
+            api_key=self.llm.api_key,
+            model_name=self.llm.model_name,
+            model_parameters=self.llm.model_parameters,
+            system_prompt=self.system_prompt,
+            streaming_callback=callback_name,
+            api_base_url=self.llm.api_base_url,
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "ChatGPTGenerator":
+        """
+        Deserialize this component from a dictionary.
+        """
+        init_params = data.get("init_parameters", {})
+        streaming_callback = None
+        if "streaming_callback" in init_params:
+            parts = init_params["streaming_callback"].split(".")
+            module_name = ".".join(parts[:-1])
+            function_name = parts[-1]
+            module = sys.modules.get(module_name, None)
+            if not module:
+                raise DeserializationError(f"Could not locate the module of the streaming callback: {module_name}")
+            streaming_callback = getattr(module, function_name, None)
+            if not streaming_callback:
+                raise DeserializationError(f"Could not locate the streaming callback: {function_name}")
+            data["init_parameters"]["streaming_callback"] = streaming_callback
+        return default_from_dict(cls, data)
+
+    @component.output_types(replies=List[List[str]], metadata=List[Dict[str, Any]])
+    def run(
+        self,
+        prompts: List[str],
+        api_key: Optional[str] = None,
+        model_name: str = "gpt-3.5-turbo",
+        system_prompt: Optional[str] = None,
+        model_parameters: Optional[Dict[str, Any]] = None,
+        streaming_callback: Optional[Callable] = None,
+        api_base_url: str = "https://api.openai.com/v1",
+    ):
+        """
+        Queries the LLM with the prompts to produce replies.
+
+        :param prompts: The prompts to be sent to the generative model.
+        :param api_key: The OpenAI API key.
+        :param model_name: The name of the model to use.
+        :param system_prompt: The prompt to be prepended to the user prompt.
+        :param streaming_callback: A callback function that is called when a new token is received from the stream.
+            The callback function should accept two parameters: the token received from the stream and **kwargs.
+            The callback function should return the token to be sent to the stream. If the callback function is not
+            provided, the token is printed to stdout.
+        :param api_base_url: The OpenAI API base URL, defaults to `https://api.openai.com/v1`.
+        :param model_parameters: A dictionary of parameters to use for the model. See OpenAI
+            [documentation](https://platform.openai.com/docs/api-reference/chat) for more details. Some of the supported
+            parameters:
+            - `max_tokens`: The maximum number of tokens the output text can have.
+            - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
+                Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
+            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
+                considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
+                comprising the top 10% probability mass are considered.
+            - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
+                it will generate two completions for each of the three prompts, ending up with 6 completions in total.
+            - `stop`: One or more sequences after which the LLM should stop generating tokens.
+            - `presence_penalty`: What penalty to apply if a token is already present at all. Bigger values mean
+                the model will be less likely to repeat the same token in the text.
+            - `frequency_penalty`: What penalty to apply if a token has already been generated in the text.
+                Bigger values mean the model will be less likely to repeat the same token in the text.
+            - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens and the
+                values are the bias to add to that token.
+            - `openai_organization`: The OpenAI organization ID.
+
+        See the [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat) for more details.
+        """
+        system_prompt = system_prompt if system_prompt is not None else self.system_prompt
+        if system_prompt:
+            system_message = ChatMessage(content=system_prompt, role="system")
+        chats = []
+        for prompt in prompts:
+            message = ChatMessage(content=prompt, role="user")
+            if system_prompt:
+                chats.append([system_message, message])
+            else:
+                chats.append([message])
+
+        replies, metadata = [], []
+        for chat in chats:
+            reply, meta = self.llm.complete(
+                chat=chat,
+                api_key=api_key,
+                model_name=model_name,
+                model_parameters=model_parameters,
+                streaming_callback=streaming_callback,
+                api_base_url=api_base_url,
+            )
+            replies.append(reply)
+            metadata.append(meta)
+
+        return {"replies": replies, "metadata": metadata}
diff --git a/haystack/preview/llm_backends/openai/_helpers.py b/haystack/preview/llm_backends/openai/_helpers.py
index 431fd72ef6..f87611d4fc 100644
--- a/haystack/preview/llm_backends/openai/_helpers.py
+++ b/haystack/preview/llm_backends/openai/_helpers.py
@@ -47,15 +47,6 @@
 )
 
 
-def default_streaming_callback(token: str, **kwargs):
-    """
-    Default callback function for streaming responses from OpenAI API.
-    Prints the tokens to stdout as soon as they are received and returns them.
- """ - print(token, flush=True, end="") - return token - - @openai_retry def complete(url: str, headers: Dict[str, str], payload: Dict[str, Any]) -> Tuple[List[str], List[Dict[str, Any]]]: """ diff --git a/haystack/preview/llm_backends/openai/chatgpt.py b/haystack/preview/llm_backends/openai/chatgpt.py index 34e2e4211c..6b00f090f0 100644 --- a/haystack/preview/llm_backends/openai/chatgpt.py +++ b/haystack/preview/llm_backends/openai/chatgpt.py @@ -58,7 +58,7 @@ def __init__( Creates an instance of ChatGPTGenerator for OpenAI's GPT-3.5 model. :param api_key: The OpenAI API key. - :param model_name: The name or path of the underlying model. + :param model_name: The name of the model to use. :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function should accept two parameters: the token received from the stream and **kwargs. The callback function should return the token to be sent to the stream. If the callback function is not @@ -126,7 +126,7 @@ def complete( :param chat: The chat to be sent to the generative model. :param api_key: The OpenAI API key. - :param model_name: The name or path of the underlying model. + :param model_name: The name of the model to use. :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function should accept two parameters: the token received from the stream and **kwargs. The callback function should return the token to be sent to the stream. If the callback function is not diff --git a/test/preview/components/generators/openai/test_chatgpt_generator.py b/test/preview/components/generators/openai/test_chatgpt_generator.py new file mode 100644 index 0000000000..944dd84d5e --- /dev/null +++ b/test/preview/components/generators/openai/test_chatgpt_generator.py @@ -0,0 +1,149 @@ +from unittest.mock import patch + +import pytest + +from haystack.preview.components.generators.openai.chatgpt import ChatGPTGenerator +from haystack.preview.components.generators.openai.chatgpt import default_streaming_callback +from haystack.preview.llm_backends.openai.chatgpt import ChatGPTBackend, DEFAULT_OPENAI_PARAMS + + +class TestChatGPTGenerator: + @pytest.mark.unit + def test_init_default(self, caplog): + with patch("haystack.preview.llm_backends.openai.chatgpt.tiktoken") as tiktoken_patch: + component = ChatGPTGenerator() + assert component.system_prompt is None + assert component.llm.api_key is None + assert component.llm.model_name == "gpt-3.5-turbo" + assert component.llm.streaming_callback is None + assert component.llm.api_base_url == "https://api.openai.com/v1" + assert component.llm.model_parameters == DEFAULT_OPENAI_PARAMS + assert isinstance(component.llm, ChatGPTBackend) + + @pytest.mark.unit + def test_init_with_parameters(self, caplog): + with patch("haystack.preview.llm_backends.openai.chatgpt.tiktoken") as tiktoken_patch: + callback = lambda x: x + component = ChatGPTGenerator( + api_key="test-api-key", + model_name="gpt-4", + system_prompt="test-system-prompt", + model_parameters={"max_tokens": 10, "some-test-param": "test-params"}, + streaming_callback=callback, + api_base_url="test-base-url", + ) + assert component.system_prompt == "test-system-prompt" + assert component.llm.api_key == "test-api-key" + assert component.llm.model_name == "gpt-4" + assert component.llm.streaming_callback == callback + assert component.llm.api_base_url == "test-base-url" + assert component.llm.model_parameters == { + **DEFAULT_OPENAI_PARAMS, + 
"max_tokens": 10, + "some-test-param": "test-params", + } + + @pytest.mark.unit + def test_to_dict_default(self): + with patch("haystack.preview.llm_backends.openai.chatgpt.tiktoken") as tiktoken_patch: + component = ChatGPTGenerator() + data = component.to_dict() + assert data == { + "type": "ChatGPTGenerator", + "init_parameters": { + "api_key": None, + "model_name": "gpt-3.5-turbo", + "system_prompt": None, + "model_parameters": DEFAULT_OPENAI_PARAMS, + "streaming_callback": None, + "api_base_url": "https://api.openai.com/v1", + }, + } + + @pytest.mark.unit + def test_to_dict_with_parameters(self): + with patch("haystack.preview.llm_backends.openai.chatgpt.tiktoken") as tiktoken_patch: + component = ChatGPTGenerator( + api_key="test-api-key", + model_name="gpt-4", + system_prompt="test-system-prompt", + model_parameters={"max_tokens": 10, "some-test-params": "test-params"}, + streaming_callback=default_streaming_callback, + api_base_url="test-base-url", + ) + data = component.to_dict() + assert data == { + "type": "ChatGPTGenerator", + "init_parameters": { + "api_key": "test-api-key", + "model_name": "gpt-4", + "system_prompt": "test-system-prompt", + "model_parameters": {**DEFAULT_OPENAI_PARAMS, "max_tokens": 10, "some-test-params": "test-params"}, + "api_base_url": "test-base-url", + "streaming_callback": "haystack.preview.components.generators.openai.chatgpt.default_streaming_callback", + }, + } + + @pytest.mark.unit + def test_from_dict(self): + with patch("haystack.preview.llm_backends.openai.chatgpt.tiktoken") as tiktoken_patch: + data = { + "type": "ChatGPTGenerator", + "init_parameters": { + "api_key": "test-api-key", + "model_name": "gpt-4", + "system_prompt": "test-system-prompt", + "model_parameters": {"max_tokens": 10, "some-test-params": "test-params"}, + "api_base_url": "test-base-url", + "streaming_callback": "haystack.preview.components.generators.openai.chatgpt.default_streaming_callback", + }, + } + component = ChatGPTGenerator.from_dict(data) + assert component.system_prompt == "test-system-prompt" + assert component.llm.api_key == "test-api-key" + assert component.llm.model_name == "gpt-4" + assert component.llm.streaming_callback == default_streaming_callback + assert component.llm.api_base_url == "test-base-url" + assert component.llm.model_parameters == { + **DEFAULT_OPENAI_PARAMS, + "max_tokens": 10, + "some-test-params": "test-params", + } + + @pytest.mark.unit + def test_run_no_api_key(self): + with patch("haystack.preview.llm_backends.openai.chatgpt.tiktoken") as tiktoken_patch: + component = ChatGPTGenerator() + with pytest.raises(ValueError, match="OpenAI API key is missing. 
Please provide an API key."): + component.run(prompts=["test"]) + + @pytest.mark.unit + def test_run_no_system_prompt(self): + with patch("haystack.preview.components.generators.openai.chatgpt.ChatGPTBackend") as chatgpt_patch: + chatgpt_patch.return_value.complete.side_effect = lambda chat, **kwargs: ( + [f"{msg.role}: {msg.content}" for msg in chat], + {"some_info": None}, + ) + component = ChatGPTGenerator(api_key="test-api-key") + results = component.run(prompts=["test-prompt-1", "test-prompt-2"]) + assert results == { + "replies": [["user: test-prompt-1"], ["user: test-prompt-2"]], + "metadata": [{"some_info": None}, {"some_info": None}], + } + + @pytest.mark.unit + def test_run_with_system_prompt(self): + with patch("haystack.preview.components.generators.openai.chatgpt.ChatGPTBackend") as chatgpt_patch: + chatgpt_patch.return_value.complete.side_effect = lambda chat, **kwargs: ( + [f"{msg.role}: {msg.content}" for msg in chat], + {"some_info": None}, + ) + component = ChatGPTGenerator(api_key="test-api-key", system_prompt="test-system-prompt") + results = component.run(prompts=["test-prompt-1", "test-prompt-2"]) + assert results == { + "replies": [ + ["system: test-system-prompt", "user: test-prompt-1"], + ["system: test-system-prompt", "user: test-prompt-2"], + ], + "metadata": [{"some_info": None}, {"some_info": None}], + } diff --git a/test/preview/components/generators/openai/test_openai_helpers.py b/test/preview/components/generators/openai/test_openai_helpers.py deleted file mode 100644 index 736d7f3dd5..0000000000 --- a/test/preview/components/generators/openai/test_openai_helpers.py +++ /dev/null @@ -1,252 +0,0 @@ -from unittest.mock import Mock, patch -import json - -import pytest - -from haystack.preview.llm_backends.openai.errors import OpenAIUnauthorizedError, OpenAIError, OpenAIRateLimitError -from haystack.preview.llm_backends.openai._helpers import ( - ChatMessage, - raise_for_status, - check_truncated_answers, - complete, - complete_stream, - enforce_token_limit, - enforce_token_limit_chat, - OPENAI_TIMEOUT, - OPENAI_MAX_RETRIES, -) - - -@pytest.mark.unit -def test_raise_for_status_200(): - response = Mock() - response.status_code = 200 - raise_for_status(response) - - -@pytest.mark.unit -def test_raise_for_status_401(): - response = Mock() - response.status_code = 401 - with pytest.raises(OpenAIUnauthorizedError): - raise_for_status(response) - - -@pytest.mark.unit -def test_raise_for_status_429(): - response = Mock() - response.status_code = 429 - with pytest.raises(OpenAIRateLimitError): - raise_for_status(response) - - -@pytest.mark.unit -def test_raise_for_status_500(): - response = Mock() - response.status_code = 500 - response.text = "Internal Server Error" - with pytest.raises(OpenAIError): - raise_for_status(response) - - -@pytest.mark.unit -def test_check_truncated_answers(caplog): - result = { - "choices": [ - {"finish_reason": "length"}, - {"finish_reason": "content_filter"}, - {"finish_reason": "length"}, - {"finish_reason": "stop"}, - ] - } - payload = {"n": 4} - check_truncated_answers(result, payload) - assert caplog.records[0].message == ( - "2 out of the 4 completions have been truncated before reaching a natural " - "stopping point. Increase the max_tokens parameter to allow for longer completions." 
-    )
-
-
-@pytest.mark.unit
-def test_query_chat_model():
-    with patch("haystack.preview.llm_backends.openai._helpers.requests.post") as mock_post:
-        response = Mock()
-        response.status_code = 200
-        response.text = """
-        {
-            "model": "test-model",
-            "choices": [
-                {
-                    "index": 0,
-                    "finish_reason": "stop",
-                    "message": {"content": " Hello, how are you? "}
-                }
-            ],
-            "usage": {
-                "prompt_tokens": 4,
-                "completion_tokens": 5,
-                "total_tokens": 9
-            }
-
-        }"""
-        mock_post.return_value = response
-        replies, metadata = complete(url="test-url", headers={"header": "test-header"}, payload={"param": "test-param"})
-        mock_post.assert_called_once_with(
-            "test-url",
-            headers={"header": "test-header"},
-            data=json.dumps({"param": "test-param"}),
-            timeout=OPENAI_TIMEOUT,
-        )
-        assert replies == ["Hello, how are you?"]
-        assert metadata == [
-            {
-                "model": "test-model",
-                "index": 0,
-                "finish_reason": "stop",
-                "prompt_tokens": 4,
-                "completion_tokens": 5,
-                "total_tokens": 9,
-            }
-        ]
-
-
-@pytest.mark.unit
-def test_query_chat_model_fail():
-    with patch("haystack.preview.llm_backends.openai._helpers.requests.post") as mock_post:
-        response = Mock()
-        response.status_code = 500
-        mock_post.return_value = response
-        with pytest.raises(OpenAIError):
-            complete(url="test-url", headers={"header": "test-header"}, payload={"param": "test-param"})
-        mock_post.assert_called_with(
-            "test-url",
-            headers={"header": "test-header"},
-            data=json.dumps({"param": "test-param"}),
-            timeout=OPENAI_TIMEOUT,
-        )
-        mock_post.call_count == OPENAI_MAX_RETRIES
-
-
-def mock_chat_completion_stream(model="test-model", index=0, token="test", finish_reason="stop"):
-    return Mock(
-        data=f"""{{
-            "model": "{model}",
-            "choices": [
-                {{
-                    "index": {index},
-                    "delta": {{"content": "{token}"}},
-                    "finish_reason": "{finish_reason}"
-                }}
-            ]
-        }}"""
-    )
-
-
-@pytest.mark.unit
-def test_query_chat_model_stream():
-    with patch("haystack.preview.llm_backends.openai._helpers.requests.post") as mock_post:
-        with patch("haystack.preview.llm_backends.openai._helpers.sseclient.SSEClient") as mock_sseclient:
-            callback = lambda token, event_data: f"|{token}|"
-            response = Mock()
-            response.status_code = 200
-
-            mock_sseclient.return_value.events.return_value = [
-                mock_chat_completion_stream(token="Hello"),
-                mock_chat_completion_stream(token=","),
-                mock_chat_completion_stream(token=" how"),
-                mock_chat_completion_stream(token=" are"),
-                mock_chat_completion_stream(token=" you"),
-                mock_chat_completion_stream(token="?"),
-                Mock(data="[DONE]"),
-                mock_chat_completion_stream(token="discarded tokens"),
-            ]
-
-            mock_post.return_value = response
-            replies, metadata = complete_stream(
-                url="test-url", headers={"header": "test-header"}, payload={"param": "test-param"}, callback=callback
-            )
-            mock_post.assert_called_once_with(
-                "test-url",
-                headers={"header": "test-header"},
-                data=json.dumps({"param": "test-param"}),
-                timeout=OPENAI_TIMEOUT,
-                stream=True,
-            )
-            assert replies == ["|Hello||,|| how|| are|| you||?|"]
-            assert metadata == [{"model": "test-model", "index": 0, "finish_reason": "stop"}]
-
-
-@pytest.mark.unit
-def test_query_chat_model_stream_fail():
-    with patch("haystack.preview.llm_backends.openai._helpers.requests.post") as mock_post:
-        callback = Mock()
-        response = Mock()
-        response.status_code = 500
-        mock_post.return_value = response
-        with pytest.raises(OpenAIError):
-            complete_stream(
-                url="test-url", headers={"header": "test-header"}, payload={"param": "test-param"}, callback=callback
-            )
-        mock_post.assert_called_with(
- "test-url", - headers={"header": "test-header"}, - data=json.dumps({"param": "test-param"}), - timeout=OPENAI_TIMEOUT, - ) - mock_post.call_count == OPENAI_MAX_RETRIES - - -@pytest.mark.unit -def test_enforce_token_limit_above_limit(caplog, mock_tokenizer): - prompt = enforce_token_limit("This is a test prompt.", tokenizer=mock_tokenizer, max_tokens_limit=3) - assert prompt == "This is a" - assert caplog.records[0].message == ( - "The prompt has been truncated from 5 tokens to 3 tokens to fit within the max token " - "limit. Reduce the length of the prompt to prevent it from being cut off." - ) - - -@pytest.mark.unit -def test_enforce_token_limit_below_limit(caplog, mock_tokenizer): - prompt = enforce_token_limit("This is a test prompt.", tokenizer=mock_tokenizer, max_tokens_limit=100) - assert prompt == "This is a test prompt." - assert not caplog.records - - -@pytest.mark.unit -def test_enforce_token_limit_chat_above_limit(caplog, mock_tokenizer): - prompts = enforce_token_limit_chat( - [ - ChatMessage(content="System Prompt", role="system"), - ChatMessage(content="This is a test prompt.", role="user"), - ], - tokenizer=mock_tokenizer, - max_tokens_limit=7, - tokens_per_message_overhead=2, - ) - assert prompts == [ - ChatMessage(content="System Prompt", role="system"), - ChatMessage(content="This is a", role="user"), - ] - assert caplog.records[0].message == ( - "The chat have been truncated from 11 tokens to 7 tokens to fit within the max token limit. " - "Reduce the length of the chat to prevent it from being cut off." - ) - - -@pytest.mark.unit -def test_enforce_token_limit_chat_below_limit(caplog, mock_tokenizer): - prompts = enforce_token_limit_chat( - [ - ChatMessage(content="System Prompt", role="system"), - ChatMessage(content="This is a test prompt.", role="user"), - ], - tokenizer=mock_tokenizer, - max_tokens_limit=100, - tokens_per_message_overhead=2, - ) - assert prompts == [ - ChatMessage(content="System Prompt", role="system"), - ChatMessage(content="This is a test prompt.", role="user"), - ] - assert not caplog.records