Commit
add chatgpt generator
ZanSara committed Sep 4, 2023
1 parent 853f29d commit f0c5a8d
Showing 7 changed files with 416 additions and 296 deletions.
64 changes: 64 additions & 0 deletions e2e/preview/components/test_chatgpt_generator.py
@@ -0,0 +1,64 @@
import os
import pytest
from haystack.preview.components.generators.openai.chatgpt import ChatGPTGenerator


@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_chatgpt_generator_run():
    component = ChatGPTGenerator(api_key=os.environ.get("OPENAI_API_KEY"))
    results = component.run(
        prompts=["What's the capital of France?", "What's the capital of Germany?"], model_parameters={"n": 1}
    )

    assert len(results["replies"]) == 2
    assert len(results["replies"][0]) == 1
    assert "Paris" in results["replies"][0][0]
    assert len(results["replies"][1]) == 1
    assert "Berlin" in results["replies"][1][0]

    assert len(results["metadata"]) == 2
    assert len(results["metadata"][0]) == 1
    assert "gpt-3.5-turbo" in results["metadata"][0][0]["model"]
    assert "stop" == results["metadata"][0][0]["finish_reason"]
    assert len(results["metadata"][1]) == 1
    assert "gpt-3.5-turbo" in results["metadata"][1][0]["model"]
    assert "stop" == results["metadata"][1][0]["finish_reason"]


@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_chatgpt_generator_run_streaming():
    class Callback:
        def __init__(self):
            self.responses = ""

        def __call__(self, token, event_data):
            self.responses += token
            return token

    callback = Callback()
    component = ChatGPTGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback)
    results = component.run(
        prompts=["What's the capital of France?", "What's the capital of Germany?"], model_parameters={"n": 1}
    )

    assert len(results["replies"]) == 2
    assert len(results["replies"][0]) == 1
    assert "Paris" in results["replies"][0][0]
    assert len(results["replies"][1]) == 1
    assert "Berlin" in results["replies"][1][0]

    assert callback.responses == results["replies"][0][0] + results["replies"][1][0]

    assert len(results["metadata"]) == 2
    assert len(results["metadata"][0]) == 1
    assert "gpt-3.5-turbo" in results["metadata"][0][0]["model"]
    assert "stop" == results["metadata"][0][0]["finish_reason"]
    assert len(results["metadata"][1]) == 1
    assert "gpt-3.5-turbo" in results["metadata"][1][0]["model"]
    assert "stop" == results["metadata"][1][0]["finish_reason"]
33 changes: 0 additions & 33 deletions haystack/preview/components/generators/openai/_helpers.py

This file was deleted.

201 changes: 201 additions & 0 deletions haystack/preview/components/generators/openai/chatgpt.py
@@ -0,0 +1,201 @@
from typing import Optional, List, Callable, Dict, Any

import sys
import builtins
import logging

from haystack.preview import component, default_from_dict, default_to_dict, DeserializationError
from haystack.preview.llm_backends.openai.chatgpt import ChatGPTBackend
from haystack.preview.llm_backends.chat_message import ChatMessage


logger = logging.getLogger(__name__)


TOKENS_PER_MESSAGE_OVERHEAD = 4


def default_streaming_callback(token: str, **kwargs):
    """
    Default callback function for streaming responses from OpenAI API.
    Prints the tokens to stdout as soon as they are received and returns them.
    """
    print(token, flush=True, end="")
    return token


@component
class ChatGPTGenerator:
"""
ChatGPT LLM Generator.
Queries ChatGPT using OpenAI's GPT-3 ChatGPT API. Invocations are made using REST API.
See [OpenAI ChatGPT API](https://platform.openai.com/docs/guides/chat) for more details.
"""

    # TODO support function calling!

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "gpt-3.5-turbo",
        system_prompt: Optional[str] = None,
        model_parameters: Optional[Dict[str, Any]] = None,
        streaming_callback: Optional[Callable] = None,
        api_base_url: str = "https://api.openai.com/v1",
    ):
        """
        Creates an instance of ChatGPTGenerator for OpenAI's GPT-3.5 model.
        :param api_key: The OpenAI API key.
        :param model_name: The name of the model to use.
        :param system_prompt: The prompt to be prepended to the user prompt.
        :param streaming_callback: A callback function that is called when a new token is received from the stream.
            The callback function should accept two parameters: the token received from the stream and **kwargs.
            The callback function should return the token to be sent to the stream. If the callback function is not
            provided, the token is printed to stdout.
        :param api_base_url: The OpenAI API base URL, defaults to `https://api.openai.com/v1`.
        :param model_parameters: A dictionary of parameters to use for the model. See the OpenAI
            [documentation](https://platform.openai.com/docs/api-reference/chat) for more details. Some of the supported
            parameters:
            - `max_tokens`: The maximum number of tokens the output text can have.
            - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
                Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
                considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
                comprising the top 10% probability mass are considered.
            - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
                it will generate two completions for each of the three prompts, ending up with 6 completions in total.
            - `stop`: One or more sequences after which the LLM should stop generating tokens.
            - `presence_penalty`: What penalty to apply if a token is already present at all. Bigger values mean
                the model will be less likely to repeat the same token in the text.
            - `frequency_penalty`: What penalty to apply if a token has already been generated in the text.
                Bigger values mean the model will be less likely to repeat the same token in the text.
            - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens and the
                values are the bias to add to that token.
            - `openai_organization`: The OpenAI organization ID.
        """
        self.llm = ChatGPTBackend(
            api_key=api_key,
            model_name=model_name,
            model_parameters=model_parameters,
            streaming_callback=streaming_callback,
            api_base_url=api_base_url,
        )
        self.system_prompt = system_prompt

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        """
        if self.llm.streaming_callback:
            module = sys.modules.get(self.llm.streaming_callback.__module__)
            if not module:
                raise ValueError("Could not locate the import module.")
            if module == builtins:
                callback_name = self.llm.streaming_callback.__name__
            else:
                callback_name = f"{module.__name__}.{self.llm.streaming_callback.__name__}"
        else:
            callback_name = None

        return default_to_dict(
            self,
            api_key=self.llm.api_key,
            model_name=self.llm.model_name,
            model_parameters=self.llm.model_parameters,
            system_prompt=self.system_prompt,
            streaming_callback=callback_name,
            api_base_url=self.llm.api_base_url,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ChatGPTGenerator":
        """
        Deserialize this component from a dictionary.
        """
        init_params = data.get("init_parameters", {})
        streaming_callback = None
        if "streaming_callback" in init_params:
            parts = init_params["streaming_callback"].split(".")
            module_name = ".".join(parts[:-1])
            function_name = parts[-1]
            module = sys.modules.get(module_name, None)
            if not module:
                raise DeserializationError(f"Could not locate the module of the streaming callback: {module_name}")
            streaming_callback = getattr(module, function_name, None)
            if not streaming_callback:
                raise DeserializationError(f"Could not locate the streaming callback: {function_name}")
            data["init_parameters"]["streaming_callback"] = streaming_callback
        return default_from_dict(cls, data)

    @component.output_types(replies=List[List[str]], metadata=List[Dict[str, Any]])
    def run(
        self,
        prompts: List[str],
        api_key: Optional[str] = None,
        model_name: str = "gpt-3.5-turbo",
        system_prompt: Optional[str] = None,
        model_parameters: Optional[Dict[str, Any]] = None,
        streaming_callback: Optional[Callable] = None,
        api_base_url: str = "https://api.openai.com/v1",
    ):
        """
        Queries the LLM with the prompts to produce replies.
        :param prompts: The prompts to be sent to the generative model.
        :param api_key: The OpenAI API key.
        :param model_name: The name of the model to use.
        :param system_prompt: The prompt to be prepended to the user prompt.
        :param streaming_callback: A callback function that is called when a new token is received from the stream.
            The callback function should accept two parameters: the token received from the stream and **kwargs.
            The callback function should return the token to be sent to the stream. If the callback function is not
            provided, the token is printed to stdout.
        :param api_base_url: The OpenAI API base URL, defaults to `https://api.openai.com/v1`.
        :param model_parameters: A dictionary of parameters to use for the model. See the OpenAI
            [documentation](https://platform.openai.com/docs/api-reference/chat) for more details. Some of the supported
            parameters:
            - `max_tokens`: The maximum number of tokens the output text can have.
            - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
                Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
                considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
                comprising the top 10% probability mass are considered.
            - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
                it will generate two completions for each of the three prompts, ending up with 6 completions in total.
            - `stop`: One or more sequences after which the LLM should stop generating tokens.
            - `presence_penalty`: What penalty to apply if a token is already present at all. Bigger values mean
                the model will be less likely to repeat the same token in the text.
            - `frequency_penalty`: What penalty to apply if a token has already been generated in the text.
                Bigger values mean the model will be less likely to repeat the same token in the text.
            - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens and the
                values are the bias to add to that token.
            - `openai_organization`: The OpenAI organization ID.
            See the OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for more details.
        """
        system_prompt = system_prompt if system_prompt is not None else self.system_prompt
        if system_prompt:
            system_message = ChatMessage(content=system_prompt, role="system")
        chats = []
        for prompt in prompts:
            message = ChatMessage(content=prompt, role="user")
            if system_prompt:
                chats.append([system_message, message])
            else:
                chats.append([message])

        replies, metadata = [], []
        for chat in chats:
            reply, meta = self.llm.complete(
                chat=chat,
                api_key=api_key,
                model_name=model_name,
                model_parameters=model_parameters,
                streaming_callback=streaming_callback,
                api_base_url=api_base_url,
            )
            replies.append(reply)
            metadata.append(meta)

        return {"replies": replies, "metadata": metadata}
9 changes: 0 additions & 9 deletions haystack/preview/llm_backends/openai/_helpers.py
@@ -47,15 +47,6 @@
)


def default_streaming_callback(token: str, **kwargs):
    """
    Default callback function for streaming responses from OpenAI API.
    Prints the tokens to stdout as soon as they are received and returns them.
    """
    print(token, flush=True, end="")
    return token


@openai_retry
def complete(url: str, headers: Dict[str, str], payload: Dict[str, Any]) -> Tuple[List[str], List[Dict[str, Any]]]:
"""
4 changes: 2 additions & 2 deletions haystack/preview/llm_backends/openai/chatgpt.py
@@ -58,7 +58,7 @@ def __init__(
        Creates an instance of ChatGPTGenerator for OpenAI's GPT-3.5 model.
        :param api_key: The OpenAI API key.
        :param model_name: The name or path of the underlying model.
        :param model_name: The name of the model to use.
        :param streaming_callback: A callback function that is called when a new token is received from the stream.
            The callback function should accept two parameters: the token received from the stream and **kwargs.
            The callback function should return the token to be sent to the stream. If the callback function is not
@@ -126,7 +126,7 @@ def complete(
        :param chat: The chat to be sent to the generative model.
        :param api_key: The OpenAI API key.
        :param model_name: The name or path of the underlying model.
        :param model_name: The name of the model to use.
        :param streaming_callback: A callback function that is called when a new token is received from the stream.
            The callback function should accept two parameters: the token received from the stream and **kwargs.
            The callback function should return the token to be sent to the stream. If the callback function is not
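
The updated docstrings describe the streaming callback contract: any callable that accepts the token plus keyword arguments and returns the token can be passed as `streaming_callback`. Below is a minimal sketch of a custom callback that collects the streamed text, assuming the same preview import path and API key environment variable as above (illustration only, not part of the diff):

import os

from haystack.preview.components.generators.openai.chatgpt import ChatGPTGenerator

collected_tokens = []


def collecting_callback(token: str, **kwargs) -> str:
    # Called once per streamed token; return the token so it still reaches the reply.
    collected_tokens.append(token)
    return token


generator = ChatGPTGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=collecting_callback)
results = generator.run(prompts=["What's the capital of Germany?"], model_parameters={"n": 1})
assert "".join(collected_tokens) == results["replies"][0][0]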
