Showing 7 changed files with 416 additions and 296 deletions.
@@ -0,0 +1,64 @@
import os
import pytest
from haystack.preview.components.generators.openai.chatgpt import ChatGPTGenerator


@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_chatgpt_generator_run():
    component = ChatGPTGenerator(api_key=os.environ.get("OPENAI_API_KEY"))
    results = component.run(
        prompts=["What's the capital of France?", "What's the capital of Germany?"], model_parameters={"n": 1}
    )

    assert len(results["replies"]) == 2
    assert len(results["replies"][0]) == 1
    assert "Paris" in results["replies"][0][0]
    assert len(results["replies"][1]) == 1
    assert "Berlin" in results["replies"][1][0]

    assert len(results["metadata"]) == 2
    assert len(results["metadata"][0]) == 1
    assert "gpt-3.5-turbo" in results["metadata"][0][0]["model"]
    assert "stop" == results["metadata"][0][0]["finish_reason"]
    assert len(results["metadata"][1]) == 1
    assert "gpt-3.5-turbo" in results["metadata"][1][0]["model"]
    assert "stop" == results["metadata"][1][0]["finish_reason"]


@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_chatgpt_generator_run_streaming():
    class Callback:
        def __init__(self):
            self.responses = ""

        def __call__(self, token, event_data):
            self.responses += token
            return token

    callback = Callback()
    component = ChatGPTGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback)
    results = component.run(
        prompts=["What's the capital of France?", "What's the capital of Germany?"], model_parameters={"n": 1}
    )

    assert len(results["replies"]) == 2
    assert len(results["replies"][0]) == 1
    assert "Paris" in results["replies"][0][0]
    assert len(results["replies"][1]) == 1
    assert "Berlin" in results["replies"][1][0]

    assert callback.responses == results["replies"][0][0] + results["replies"][1][0]

    assert len(results["metadata"]) == 2
    assert len(results["metadata"][0]) == 1
    assert "gpt-3.5-turbo" in results["metadata"][0][0]["model"]
    assert "stop" == results["metadata"][0][0]["finish_reason"]
    assert len(results["metadata"][1]) == 1
    assert "gpt-3.5-turbo" in results["metadata"][1][0]["model"]
    assert "stop" == results["metadata"][1][0]["finish_reason"]
This file was deleted.
haystack/preview/components/generators/openai/chatgpt.py
201 changes: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
from typing import Optional, List, Callable, Dict, Any

import sys
import builtins
import logging

from haystack.preview import component, default_from_dict, default_to_dict, DeserializationError
from haystack.preview.llm_backends.openai.chatgpt import ChatGPTBackend
from haystack.preview.llm_backends.chat_message import ChatMessage


logger = logging.getLogger(__name__)


TOKENS_PER_MESSAGE_OVERHEAD = 4


def default_streaming_callback(token: str, **kwargs):
    """
    Default callback function for streaming responses from the OpenAI API.
    Prints the tokens to stdout as soon as they are received and returns them.
    """
    print(token, flush=True, end="")
    return token


@component
class ChatGPTGenerator:
    """
    ChatGPT LLM Generator.
    Queries ChatGPT using OpenAI's ChatGPT API. Invocations are made via the REST API.
    See [OpenAI ChatGPT API](https://platform.openai.com/docs/guides/chat) for more details.
    """

    # TODO support function calling!

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "gpt-3.5-turbo",
        system_prompt: Optional[str] = None,
        model_parameters: Optional[Dict[str, Any]] = None,
        streaming_callback: Optional[Callable] = None,
        api_base_url: str = "https://api.openai.com/v1",
    ):
        """
        Creates an instance of ChatGPTGenerator for OpenAI's GPT-3.5 model.
        :param api_key: The OpenAI API key.
        :param model_name: The name of the model to use.
        :param system_prompt: The prompt to be prepended to the user prompt.
        :param streaming_callback: A callback function that is called when a new token is received from the stream.
            The callback function should accept two parameters: the token received from the stream and **kwargs.
            The callback function should return the token to be sent to the stream. If the callback function is not
            provided, the token is printed to stdout.
        :param api_base_url: The OpenAI API base URL, defaults to `https://api.openai.com/v1`.
        :param model_parameters: A dictionary of parameters to use for the model. See the OpenAI
            [documentation](https://platform.openai.com/docs/api-reference/chat) for more details. Some of the
            supported parameters:
            - `max_tokens`: The maximum number of tokens the output text can have.
            - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
                Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
                considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
                comprising the top 10% probability mass are considered.
            - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
                it will generate two completions for each of the three prompts, ending up with 6 completions in total.
            - `stop`: One or more sequences after which the LLM should stop generating tokens.
            - `presence_penalty`: What penalty to apply if a token is already present at all. Bigger values mean
                the model will be less likely to repeat the same token in the text.
            - `frequency_penalty`: What penalty to apply if a token has already been generated in the text.
                Bigger values mean the model will be less likely to repeat the same token in the text.
            - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
                values are the bias to add to that token.
            - `openai_organization`: The OpenAI organization ID.
        """
        self.llm = ChatGPTBackend(
            api_key=api_key,
            model_name=model_name,
            model_parameters=model_parameters,
            streaming_callback=streaming_callback,
            api_base_url=api_base_url,
        )
        self.system_prompt = system_prompt

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        """
        if self.llm.streaming_callback:
            # Store the callback by its fully qualified dotted name so it can be re-imported at deserialization time.
            module = sys.modules.get(self.llm.streaming_callback.__module__)
            if not module:
                raise ValueError("Could not locate the import module.")
            if module == builtins:
                callback_name = self.llm.streaming_callback.__name__
            else:
                callback_name = f"{module.__name__}.{self.llm.streaming_callback.__name__}"
        else:
            callback_name = None

        return default_to_dict(
            self,
            api_key=self.llm.api_key,
            model_name=self.llm.model_name,
            model_parameters=self.llm.model_parameters,
            system_prompt=self.system_prompt,
            streaming_callback=callback_name,
            api_base_url=self.llm.api_base_url,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ChatGPTGenerator":
        """
        Deserialize this component from a dictionary.
        """
        init_params = data.get("init_parameters", {})
        streaming_callback = None
        # A missing key or a None value means no callback was serialized, so there is nothing to resolve.
        if init_params.get("streaming_callback"):
            parts = init_params["streaming_callback"].split(".")
            module_name = ".".join(parts[:-1])
            function_name = parts[-1]
            module = sys.modules.get(module_name, None)
            if not module:
                raise DeserializationError(f"Could not locate the module of the streaming callback: {module_name}")
            streaming_callback = getattr(module, function_name, None)
            if not streaming_callback:
                raise DeserializationError(f"Could not locate the streaming callback: {function_name}")
            data["init_parameters"]["streaming_callback"] = streaming_callback
        return default_from_dict(cls, data)

    @component.output_types(replies=List[List[str]], metadata=List[Dict[str, Any]])
    def run(
        self,
        prompts: List[str],
        api_key: Optional[str] = None,
        model_name: str = "gpt-3.5-turbo",
        system_prompt: Optional[str] = None,
        model_parameters: Optional[Dict[str, Any]] = None,
        streaming_callback: Optional[Callable] = None,
        api_base_url: str = "https://api.openai.com/v1",
    ):
        """
        Queries the LLM with the prompts to produce replies.
        :param prompts: The prompts to be sent to the generative model.
        :param api_key: The OpenAI API key.
        :param model_name: The name of the model to use.
        :param system_prompt: The prompt to be prepended to the user prompt.
        :param streaming_callback: A callback function that is called when a new token is received from the stream.
            The callback function should accept two parameters: the token received from the stream and **kwargs.
            The callback function should return the token to be sent to the stream. If the callback function is not
            provided, the token is printed to stdout.
        :param api_base_url: The OpenAI API base URL, defaults to `https://api.openai.com/v1`.
        :param model_parameters: A dictionary of parameters to use for the model. See the OpenAI
            [documentation](https://platform.openai.com/docs/api-reference/chat) for more details. Some of the
            supported parameters:
            - `max_tokens`: The maximum number of tokens the output text can have.
            - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
                Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
                considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
                comprising the top 10% probability mass are considered.
            - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
                it will generate two completions for each of the three prompts, ending up with 6 completions in total.
            - `stop`: One or more sequences after which the LLM should stop generating tokens.
            - `presence_penalty`: What penalty to apply if a token is already present at all. Bigger values mean
                the model will be less likely to repeat the same token in the text.
            - `frequency_penalty`: What penalty to apply if a token has already been generated in the text.
                Bigger values mean the model will be less likely to repeat the same token in the text.
            - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
                values are the bias to add to that token.
            - `openai_organization`: The OpenAI organization ID.
            See the [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat) for more details.
        """
        system_prompt = system_prompt if system_prompt is not None else self.system_prompt
        if system_prompt:
            system_message = ChatMessage(content=system_prompt, role="system")
        chats = []
        for prompt in prompts:
            message = ChatMessage(content=prompt, role="user")
            if system_prompt:
                chats.append([system_message, message])
            else:
                chats.append([message])

        replies, metadata = [], []
        for chat in chats:
            reply, meta = self.llm.complete(
                chat=chat,
                api_key=api_key,
                model_name=model_name,
                model_parameters=model_parameters,
                streaming_callback=streaming_callback,
                api_base_url=api_base_url,
            )
            replies.append(reply)
            metadata.append(meta)

        return {"replies": replies, "metadata": metadata}