diff --git a/e2e/preview/components/test_gpt35_generator.py b/e2e/preview/components/test_gpt35_generator.py new file mode 100644 index 0000000000..c70b8033f5 --- /dev/null +++ b/e2e/preview/components/test_gpt35_generator.py @@ -0,0 +1,86 @@ +import os +import pytest +import openai +from haystack.preview.components.generators.openai.gpt35 import GPT35Generator + + +@pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY", None), + reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", +) +def test_gpt35_generator_run(): + component = GPT35Generator(api_key=os.environ.get("OPENAI_API_KEY"), n=1) + results = component.run(prompts=["What's the capital of France?", "What's the capital of Germany?"]) + + assert len(results["replies"]) == 2 + assert len(results["replies"][0]) == 1 + assert "Paris" in results["replies"][0][0] + assert len(results["replies"][1]) == 1 + assert "Berlin" in results["replies"][1][0] + + assert len(results["metadata"]) == 2 + assert len(results["metadata"][0]) == 1 + assert "gpt-3.5-turbo" in results["metadata"][0][0]["model"] + assert "stop" == results["metadata"][0][0]["finish_reason"] + assert len(results["metadata"][1]) == 1 + assert "gpt-3.5-turbo" in results["metadata"][1][0]["model"] + assert "stop" == results["metadata"][1][0]["finish_reason"] + + +@pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY", None), + reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", +) +def test_gpt35_generator_run_wrong_model_name(): + component = GPT35Generator(model_name="something-obviously-wrong", api_key=os.environ.get("OPENAI_API_KEY"), n=1) + with pytest.raises(openai.InvalidRequestError, match="The model `something-obviously-wrong` does not exist"): + component.run(prompts=["What's the capital of France?"]) + + +@pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY", None), + reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", +) +def test_gpt35_generator_run_above_context_length(): + component = GPT35Generator(api_key=os.environ.get("OPENAI_API_KEY"), n=1) + with pytest.raises( + openai.InvalidRequestError, + match="This model's maximum context length is 4097 tokens. However, your messages resulted in 70008 tokens. " + "Please reduce the length of the messages.", + ): + component.run(prompts=["What's the capital of France? 
" * 10_000]) + + +@pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY", None), + reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", +) +def test_gpt35_generator_run_streaming(): + class Callback: + def __init__(self): + self.responses = "" + + def __call__(self, chunk): + self.responses += chunk.choices[0].delta.content if chunk.choices[0].delta else "" + return chunk + + callback = Callback() + component = GPT35Generator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback, n=1) + results = component.run(prompts=["What's the capital of France?", "What's the capital of Germany?"]) + + assert len(results["replies"]) == 2 + assert len(results["replies"][0]) == 1 + assert "Paris" in results["replies"][0][0] + assert len(results["replies"][1]) == 1 + assert "Berlin" in results["replies"][1][0] + + assert callback.responses == results["replies"][0][0] + results["replies"][1][0] + + assert len(results["metadata"]) == 2 + assert len(results["metadata"][0]) == 1 + + assert "gpt-3.5-turbo" in results["metadata"][0][0]["model"] + assert "stop" == results["metadata"][0][0]["finish_reason"] + assert len(results["metadata"][1]) == 1 + assert "gpt-3.5-turbo" in results["metadata"][1][0]["model"] + assert "stop" == results["metadata"][1][0]["finish_reason"] diff --git a/haystack/preview/components/generators/openai/_helpers.py b/haystack/preview/components/generators/openai/_helpers.py deleted file mode 100644 index 946901b644..0000000000 --- a/haystack/preview/components/generators/openai/_helpers.py +++ /dev/null @@ -1,33 +0,0 @@ -import logging - -from haystack.preview.lazy_imports import LazyImport - -with LazyImport("Run 'pip install tiktoken'") as tiktoken_import: - import tiktoken - - -logger = logging.getLogger(__name__) - - -def enforce_token_limit(prompt: str, tokenizer: "tiktoken.Encoding", max_tokens_limit: int) -> str: - """ - Ensure that the length of the prompt is within the max tokens limit of the model. - If needed, truncate the prompt text so that it fits within the limit. - - :param prompt: Prompt text to be sent to the generative model. - :param tokenizer: The tokenizer used to encode the prompt. - :param max_tokens_limit: The max tokens limit of the model. - :return: The prompt text that fits within the max tokens limit of the model. - """ - tiktoken_import.check() - tokens = tokenizer.encode(prompt) - tokens_count = len(tokens) - if tokens_count > max_tokens_limit: - logger.warning( - "The prompt has been truncated from %s tokens to %s tokens to fit within the max token limit. " - "Reduce the length of the prompt to prevent it from being cut off.", - tokens_count, - max_tokens_limit, - ) - prompt = tokenizer.decode(tokens[:max_tokens_limit]) - return prompt diff --git a/haystack/preview/components/generators/openai/gpt35.py b/haystack/preview/components/generators/openai/gpt35.py new file mode 100644 index 0000000000..e58b1ffa05 --- /dev/null +++ b/haystack/preview/components/generators/openai/gpt35.py @@ -0,0 +1,213 @@ +from typing import Optional, List, Callable, Dict, Any + +import sys +import logging +from dataclasses import dataclass, asdict + +import openai + +from haystack.preview import component, default_from_dict, default_to_dict, DeserializationError + + +logger = logging.getLogger(__name__) + + +@dataclass +class _ChatMessage: + content: str + role: str + + +def default_streaming_callback(chunk): + """ + Default callback function for streaming responses from OpenAI API. 
+ Prints the tokens of the first completion to stdout as soon as they are received and returns the chunk unchanged. + """ + if hasattr(chunk.choices[0].delta, "content"): + print(chunk.choices[0].delta.content, flush=True, end="") + return chunk + + +@component +class GPT35Generator: + """ + LLM Generator compatible with GPT3.5 (ChatGPT) large language models. + + Queries the LLM using OpenAI's API. Invocations are made using OpenAI SDK ('openai' package) + See [OpenAI GPT3.5 API](https://platform.openai.com/docs/guides/chat) for more details. + """ + + def __init__( + self, + api_key: str, + model_name: str = "gpt-3.5-turbo", + system_prompt: Optional[str] = None, + streaming_callback: Optional[Callable] = None, + api_base_url: str = "https://api.openai.com/v1", + **kwargs, + ): + """ + Creates an instance of GPT35Generator for OpenAI's GPT-3.5 model. + + :param api_key: The OpenAI API key. + :param model_name: The name of the model to use. + :param system_prompt: An additional message to be sent to the LLM at the beginning of each conversation. + Typically, a conversation is formatted with a system message first, followed by alternating messages from + the 'user' (the "queries") and the 'assistant' (the "responses"). The system message helps set the behavior + of the assistant. For example, you can modify the personality of the assistant or provide specific + instructions about how it should behave throughout the conversation. + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function should accept two parameters: the token received from the stream and **kwargs. + The callback function should return the token to be sent to the stream. If the callback function is not + provided, the token is printed to stdout. + :param api_base_url: The OpenAI API Base url, defaults to `https://api.openai.com/v1`. + :param kwargs: Other parameters to use for the model. These parameters are all sent directly to the OpenAI + endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for more details. + Some of the supported parameters: + - `max_tokens`: The maximum number of tokens the output text can have. + - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks. + Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. + - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model + considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens + comprising the top 10% probability mass are considered. + - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2, + it will generate two completions for each of the three prompts, ending up with 6 completions in total. + - `stop`: One or more sequences after which the LLM should stop generating tokens. + - `presence_penalty`: What penalty to apply if a token is already present at all. Bigger values mean + the model will be less likely to repeat the same token in the text. + - `frequency_penalty`: What penalty to apply if a token has already been generated in the text. + Bigger values mean the model will be less likely to repeat the same token in the text. + - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens and the + values are the bias to add to that token. 
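+            For example, passing `max_tokens=64` and `temperature=0` here forwards both values unchanged to every
+            chat completion request made by this component (the specific values are illustrative).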
+ """ + self.api_key = api_key + self.model_name = model_name + self.system_prompt = system_prompt + self.model_parameters = kwargs + self.streaming_callback = streaming_callback + self.api_base_url = api_base_url + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + """ + if self.streaming_callback: + module = self.streaming_callback.__module__ + if module == "builtins": + callback_name = self.streaming_callback.__name__ + else: + callback_name = f"{module}.{self.streaming_callback.__name__}" + else: + callback_name = None + + return default_to_dict( + self, + api_key=self.api_key, + model_name=self.model_name, + system_prompt=self.system_prompt, + streaming_callback=callback_name, + api_base_url=self.api_base_url, + **self.model_parameters, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GPT35Generator": + """ + Deserialize this component from a dictionary. + """ + init_params = data.get("init_parameters", {}) + streaming_callback = None + if "streaming_callback" in init_params: + parts = init_params["streaming_callback"].split(".") + module_name = ".".join(parts[:-1]) + function_name = parts[-1] + module = sys.modules.get(module_name, None) + if not module: + raise DeserializationError(f"Could not locate the module of the streaming callback: {module_name}") + streaming_callback = getattr(module, function_name, None) + if not streaming_callback: + raise DeserializationError(f"Could not locate the streaming callback: {function_name}") + data["init_parameters"]["streaming_callback"] = streaming_callback + return default_from_dict(cls, data) + + @component.output_types(replies=List[List[str]], metadata=List[Dict[str, Any]]) + def run(self, prompts: List[str]): + """ + Queries the LLM with the prompts to produce replies. + + :param prompts: The prompts to be sent to the generative model. 
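+        :return: A dictionary with two keys: `replies`, a list containing, for each prompt, the list of generated
+            completions, and `metadata`, a matching list of lists of dictionaries with details such as the model
+            name and the finish reason of each completion.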
+ """ + chats = [] + for prompt in prompts: + message = _ChatMessage(content=prompt, role="user") + if self.system_prompt: + chats.append([_ChatMessage(content=self.system_prompt, role="system"), message]) + else: + chats.append([message]) + + all_replies, all_metadata = [], [] + for chat in chats: + completion = openai.ChatCompletion.create( + model=self.model_name, + api_key=self.api_key, + messages=[asdict(message) for message in chat], + stream=self.streaming_callback is not None, + **self.model_parameters, + ) + + replies: List[str] + metadata: List[Dict[str, Any]] + if self.streaming_callback: + replies_dict = {} + metadata_dict: Dict[str, Dict[str, Any]] = {} + for chunk in completion: + chunk = self.streaming_callback(chunk) + for choice in chunk.choices: + if choice.index not in replies_dict: + replies_dict[choice.index] = "" + metadata_dict[choice.index] = {} + + if hasattr(choice.delta, "content"): + replies_dict[choice.index] += choice.delta.content + metadata_dict[choice.index] = { + "model": chunk.model, + "index": choice.index, + "finish_reason": choice.finish_reason, + } + all_replies.append(list(replies_dict.values())) + all_metadata.append(list(metadata_dict.values())) + self._check_truncated_answers(list(metadata_dict.values())) + + else: + metadata = [ + { + "model": completion.model, + "index": choice.index, + "finish_reason": choice.finish_reason, + "usage": dict(completion.usage.items()), + } + for choice in completion.choices + ] + replies = [choice.message.content.strip() for choice in completion.choices] + all_replies.append(replies) + all_metadata.append(metadata) + self._check_truncated_answers(metadata) + + return {"replies": all_replies, "metadata": all_metadata} + + def _check_truncated_answers(self, metadata: List[Dict[str, Any]]): + """ + Check the `finish_reason` returned with the OpenAI completions. + If the `finish_reason` is `length`, log a warning to the user. + + :param result: The result returned from the OpenAI API. + :param payload: The payload sent to the OpenAI API. + """ + truncated_completions = sum(1 for meta in metadata if meta.get("finish_reason") != "stop") + if truncated_completions > 0: + logger.warning( + "%s out of the %s completions have been truncated before reaching a natural stopping point. " + "Increase the max_tokens parameter to allow for longer completions.", + truncated_completions, + len(metadata), + ) diff --git a/pyproject.toml b/pyproject.toml index cd3f2068a3..1b2a577eba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,7 @@ dependencies = [ # Preview "canals==0.8.0", + "openai", "Jinja2", "openai-whisper", # FIXME https://github.com/deepset-ai/haystack/issues/5731 diff --git a/releasenotes/notes/chatgpt-llm-generator-d043532654efe684.yaml b/releasenotes/notes/chatgpt-llm-generator-d043532654efe684.yaml new file mode 100644 index 0000000000..13d9491a1e --- /dev/null +++ b/releasenotes/notes/chatgpt-llm-generator-d043532654efe684.yaml @@ -0,0 +1,2 @@ +preview: + - Introduce `GPT35Generator`, a class that can generate completions using OpenAI Chat models like GPT3.5 and GPT4. 
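A minimal usage sketch of the new component (illustrative only, not part of the diff itself; it assumes a valid key is exported as OPENAI_API_KEY and relies on the import path and run() contract introduced above):

    import os
    from haystack.preview.components.generators.openai.gpt35 import GPT35Generator

    # n=1 requests a single completion per prompt; any extra kwargs are forwarded to the OpenAI API as-is.
    generator = GPT35Generator(api_key=os.environ["OPENAI_API_KEY"], n=1)
    result = generator.run(prompts=["What's the capital of France?"])

    print(result["replies"][0][0])   # first (and only) completion for the first prompt
    print(result["metadata"][0][0])  # model, index, finish_reason and token usage for that completion

Passing a streaming_callback (for example the module's default_streaming_callback) switches the underlying call to streaming mode, which the streaming tests below exercise.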
diff --git a/test/preview/components/generators/openai/test_gpt35_generator.py b/test/preview/components/generators/openai/test_gpt35_generator.py new file mode 100644 index 0000000000..c4bc9c512f --- /dev/null +++ b/test/preview/components/generators/openai/test_gpt35_generator.py @@ -0,0 +1,332 @@ +from unittest.mock import patch, Mock +from copy import deepcopy + +import pytest +import openai +from openai.util import convert_to_openai_object + +from haystack.preview.components.generators.openai.gpt35 import GPT35Generator +from haystack.preview.components.generators.openai.gpt35 import default_streaming_callback + + +def mock_openai_response(messages: str, model: str = "gpt-3.5-turbo-0301", **kwargs) -> openai.ChatCompletion: + response = f"response for these messages --> {' - '.join(msg['role']+': '+msg['content'] for msg in messages)}" + base_dict = { + "id": "chatcmpl-7NaPEA6sgX7LnNPyKPbRlsyqLbr5V", + "object": "chat.completion", + "created": 1685855844, + "model": model, + "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + } + base_dict["choices"] = [ + {"message": {"role": "assistant", "content": response}, "finish_reason": "stop", "index": "0"} + ] + return convert_to_openai_object(deepcopy(base_dict)) + + +def mock_openai_stream_response(messages: str, model: str = "gpt-3.5-turbo-0301", **kwargs) -> openai.ChatCompletion: + response = f"response for these messages --> {' - '.join(msg['role']+': '+msg['content'] for msg in messages)}" + base_dict = { + "id": "chatcmpl-7NaPEA6sgX7LnNPyKPbRlsyqLbr5V", + "object": "chat.completion", + "created": 1685855844, + "model": model, + } + base_dict["choices"] = [{"delta": {"role": "assistant"}, "finish_reason": None, "index": "0"}] + yield convert_to_openai_object(base_dict) + for token in response.split(): + base_dict["choices"][0]["delta"] = {"content": token + " "} + yield convert_to_openai_object(base_dict) + base_dict["choices"] = [{"delta": {"content": ""}, "finish_reason": "stop", "index": "0"}] + yield convert_to_openai_object(base_dict) + + +class TestGPT35Generator: + @pytest.mark.unit + def test_init_default(self): + component = GPT35Generator(api_key="test-api-key") + assert component.system_prompt is None + assert component.api_key == "test-api-key" + assert component.model_name == "gpt-3.5-turbo" + assert component.streaming_callback is None + assert component.api_base_url == "https://api.openai.com/v1" + assert component.model_parameters == {} + + @pytest.mark.unit + def test_init_with_parameters(self): + callback = lambda x: x + component = GPT35Generator( + api_key="test-api-key", + model_name="gpt-4", + system_prompt="test-system-prompt", + max_tokens=10, + some_test_param="test-params", + streaming_callback=callback, + api_base_url="test-base-url", + ) + assert component.system_prompt == "test-system-prompt" + assert component.api_key == "test-api-key" + assert component.model_name == "gpt-4" + assert component.streaming_callback == callback + assert component.api_base_url == "test-base-url" + assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"} + + @pytest.mark.unit + def test_to_dict_default(self): + component = GPT35Generator(api_key="test-api-key") + data = component.to_dict() + assert data == { + "type": "GPT35Generator", + "init_parameters": { + "api_key": "test-api-key", + "model_name": "gpt-3.5-turbo", + "system_prompt": None, + "streaming_callback": None, + "api_base_url": "https://api.openai.com/v1", + }, + } + + @pytest.mark.unit + def 
test_to_dict_with_parameters(self): + component = GPT35Generator( + api_key="test-api-key", + model_name="gpt-4", + system_prompt="test-system-prompt", + max_tokens=10, + some_test_param="test-params", + streaming_callback=default_streaming_callback, + api_base_url="test-base-url", + ) + data = component.to_dict() + assert data == { + "type": "GPT35Generator", + "init_parameters": { + "api_key": "test-api-key", + "model_name": "gpt-4", + "system_prompt": "test-system-prompt", + "max_tokens": 10, + "some_test_param": "test-params", + "api_base_url": "test-base-url", + "streaming_callback": "haystack.preview.components.generators.openai.gpt35.default_streaming_callback", + }, + } + + @pytest.mark.unit + def test_to_dict_with_lambda_streaming_callback(self): + component = GPT35Generator( + api_key="test-api-key", + model_name="gpt-4", + system_prompt="test-system-prompt", + max_tokens=10, + some_test_param="test-params", + streaming_callback=lambda x: x, + api_base_url="test-base-url", + ) + data = component.to_dict() + assert data == { + "type": "GPT35Generator", + "init_parameters": { + "api_key": "test-api-key", + "model_name": "gpt-4", + "system_prompt": "test-system-prompt", + "max_tokens": 10, + "some_test_param": "test-params", + "api_base_url": "test-base-url", + "streaming_callback": "test_gpt35_generator.", + }, + } + + @pytest.mark.unit + def test_from_dict(self): + data = { + "type": "GPT35Generator", + "init_parameters": { + "api_key": "test-api-key", + "model_name": "gpt-4", + "system_prompt": "test-system-prompt", + "max_tokens": 10, + "some_test_param": "test-params", + "api_base_url": "test-base-url", + "streaming_callback": "haystack.preview.components.generators.openai.gpt35.default_streaming_callback", + }, + } + component = GPT35Generator.from_dict(data) + assert component.system_prompt == "test-system-prompt" + assert component.api_key == "test-api-key" + assert component.model_name == "gpt-4" + assert component.streaming_callback == default_streaming_callback + assert component.api_base_url == "test-base-url" + assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"} + + @pytest.mark.unit + def test_run_no_system_prompt(self): + with patch("haystack.preview.components.generators.openai.gpt35.openai.ChatCompletion") as gpt35_patch: + gpt35_patch.create.side_effect = mock_openai_response + component = GPT35Generator(api_key="test-api-key") + results = component.run(prompts=["test-prompt-1", "test-prompt-2"]) + assert results == { + "replies": [ + ["response for these messages --> user: test-prompt-1"], + ["response for these messages --> user: test-prompt-2"], + ], + "metadata": [ + [ + { + "model": "gpt-3.5-turbo", + "index": "0", + "finish_reason": "stop", + "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + } + ], + [ + { + "model": "gpt-3.5-turbo", + "index": "0", + "finish_reason": "stop", + "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + } + ], + ], + } + assert gpt35_patch.create.call_count == 2 + gpt35_patch.create.assert_any_call( + model="gpt-3.5-turbo", + api_key="test-api-key", + messages=[{"role": "user", "content": "test-prompt-1"}], + stream=False, + ) + gpt35_patch.create.assert_any_call( + model="gpt-3.5-turbo", + api_key="test-api-key", + messages=[{"role": "user", "content": "test-prompt-2"}], + stream=False, + ) + + @pytest.mark.unit + def test_run_with_system_prompt(self): + with patch("haystack.preview.components.generators.openai.gpt35.openai.ChatCompletion") 
as gpt35_patch: + gpt35_patch.create.side_effect = mock_openai_response + component = GPT35Generator(api_key="test-api-key", system_prompt="test-system-prompt") + results = component.run(prompts=["test-prompt-1", "test-prompt-2"]) + assert results == { + "replies": [ + ["response for these messages --> system: test-system-prompt - user: test-prompt-1"], + ["response for these messages --> system: test-system-prompt - user: test-prompt-2"], + ], + "metadata": [ + [ + { + "model": "gpt-3.5-turbo", + "index": "0", + "finish_reason": "stop", + "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + } + ], + [ + { + "model": "gpt-3.5-turbo", + "index": "0", + "finish_reason": "stop", + "usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97}, + } + ], + ], + } + assert gpt35_patch.create.call_count == 2 + gpt35_patch.create.assert_any_call( + model="gpt-3.5-turbo", + api_key="test-api-key", + messages=[ + {"role": "system", "content": "test-system-prompt"}, + {"role": "user", "content": "test-prompt-1"}, + ], + stream=False, + ) + gpt35_patch.create.assert_any_call( + model="gpt-3.5-turbo", + api_key="test-api-key", + messages=[ + {"role": "system", "content": "test-system-prompt"}, + {"role": "user", "content": "test-prompt-2"}, + ], + stream=False, + ) + + @pytest.mark.unit + def test_run_with_parameters(self): + with patch("haystack.preview.components.generators.openai.gpt35.openai.ChatCompletion") as gpt35_patch: + gpt35_patch.create.side_effect = mock_openai_response + component = GPT35Generator(api_key="test-api-key", max_tokens=10) + component.run(prompts=["test-prompt-1", "test-prompt-2"]) + assert gpt35_patch.create.call_count == 2 + gpt35_patch.create.assert_any_call( + model="gpt-3.5-turbo", + api_key="test-api-key", + messages=[{"role": "user", "content": "test-prompt-1"}], + stream=False, + max_tokens=10, + ) + gpt35_patch.create.assert_any_call( + model="gpt-3.5-turbo", + api_key="test-api-key", + messages=[{"role": "user", "content": "test-prompt-2"}], + stream=False, + max_tokens=10, + ) + + @pytest.mark.unit + def test_run_stream(self): + with patch("haystack.preview.components.generators.openai.gpt35.openai.ChatCompletion") as gpt35_patch: + mock_callback = Mock() + mock_callback.side_effect = default_streaming_callback + gpt35_patch.create.side_effect = mock_openai_stream_response + component = GPT35Generator( + api_key="test-api-key", system_prompt="test-system-prompt", streaming_callback=mock_callback + ) + results = component.run(prompts=["test-prompt-1", "test-prompt-2"]) + assert results == { + "replies": [ + ["response for these messages --> system: test-system-prompt - user: test-prompt-1 "], + ["response for these messages --> system: test-system-prompt - user: test-prompt-2 "], + ], + "metadata": [ + [{"model": "gpt-3.5-turbo", "index": "0", "finish_reason": "stop"}], + [{"model": "gpt-3.5-turbo", "index": "0", "finish_reason": "stop"}], + ], + } + # Calls count: (10 tokens per prompt + 1 token for the role + 1 empty termination token) * 2 prompts + assert mock_callback.call_count == 24 + assert gpt35_patch.create.call_count == 2 + gpt35_patch.create.assert_any_call( + model="gpt-3.5-turbo", + api_key="test-api-key", + messages=[ + {"role": "system", "content": "test-system-prompt"}, + {"role": "user", "content": "test-prompt-1"}, + ], + stream=True, + ) + gpt35_patch.create.assert_any_call( + model="gpt-3.5-turbo", + api_key="test-api-key", + messages=[ + {"role": "system", "content": "test-system-prompt"}, + {"role": 
"user", "content": "test-prompt-2"}, + ], + stream=True, + ) + + @pytest.mark.unit + def test_check_truncated_answers(self, caplog): + component = GPT35Generator(api_key="test-api-key") + metadata = [ + {"finish_reason": "stop"}, + {"finish_reason": "content_filter"}, + {"finish_reason": "length"}, + {"finish_reason": "stop"}, + ] + component._check_truncated_answers(metadata) + assert caplog.records[0].message == ( + "2 out of the 4 completions have been truncated before reaching a natural " + "stopping point. Increase the max_tokens parameter to allow for longer completions." + ) diff --git a/test/preview/components/generators/openai/test_openai_helpers.py b/test/preview/components/generators/openai/test_openai_helpers.py deleted file mode 100644 index 23a66117d1..0000000000 --- a/test/preview/components/generators/openai/test_openai_helpers.py +++ /dev/null @@ -1,20 +0,0 @@ -import pytest - -from haystack.preview.components.generators.openai._helpers import enforce_token_limit - - -@pytest.mark.unit -def test_enforce_token_limit_above_limit(caplog, mock_tokenizer): - prompt = enforce_token_limit("This is a test prompt.", tokenizer=mock_tokenizer, max_tokens_limit=3) - assert prompt == "This is a" - assert caplog.records[0].message == ( - "The prompt has been truncated from 5 tokens to 3 tokens to fit within the max token " - "limit. Reduce the length of the prompt to prevent it from being cut off." - ) - - -@pytest.mark.unit -def test_enforce_token_limit_below_limit(caplog, mock_tokenizer): - prompt = enforce_token_limit("This is a test prompt.", tokenizer=mock_tokenizer, max_tokens_limit=100) - assert prompt == "This is a test prompt." - assert not caplog.records diff --git a/test/preview/conftest.py b/test/preview/conftest.py index b8abfa41a6..d8882ea230 100644 --- a/test/preview/conftest.py +++ b/test/preview/conftest.py @@ -1,4 +1,4 @@ -from unittest.mock import Mock +from unittest.mock import Mock, patch import pytest