
feat: GPT35Generator #5714

Merged: 44 commits from chatgpt-llm-generator into main, merged Sep 7, 2023
Changes from 18 commits
Commits (44)
a243cae  chatgpt backend (ZanSara, Sep 4, 2023)
f59abe8  fix tests (ZanSara, Sep 4, 2023)
5f70a65  reno (ZanSara, Sep 4, 2023)
ffb1a8f  remove print (ZanSara, Sep 4, 2023)
853f29d  helpers tests (ZanSara, Sep 4, 2023)
f0c5a8d  add chatgpt generator (ZanSara, Sep 4, 2023)
0a25414  use openai sdk (ZanSara, Sep 4, 2023)
5105ae8  remove backend (ZanSara, Sep 4, 2023)
7d0c8e6  tests are broken (ZanSara, Sep 4, 2023)
de46d10  fix tests (ZanSara, Sep 5, 2023)
28d83f4  stray param (ZanSara, Sep 5, 2023)
30b4bc3  move _check_troncated_answers into the class (ZanSara, Sep 5, 2023)
1b744e4  wrong import (ZanSara, Sep 5, 2023)
ab0e45c  rename function (ZanSara, Sep 5, 2023)
fc7dc05  typo in test (ZanSara, Sep 5, 2023)
3e43dcd  add openai deps (ZanSara, Sep 5, 2023)
c3381e3  mypy (ZanSara, Sep 5, 2023)
a204d14  Merge branch 'main' into chatgpt-llm-generator (ZanSara, Sep 5, 2023)
8d6f134  improve system prompt docstring (ZanSara, Sep 5, 2023)
8e0c1c6  Merge branch 'chatgpt-llm-generator' of github.com:deepset-ai/haystac… (ZanSara, Sep 5, 2023)
e1652f8  typos update (dfokina, Sep 5, 2023)
2a256b2  Update haystack/preview/components/generators/openai/chatgpt.py (ZanSara, Sep 5, 2023)
7178f23  pylint (ZanSara, Sep 5, 2023)
9eb7900  Merge branch 'chatgpt-llm-generator' of github.com:deepset-ai/haystac… (ZanSara, Sep 5, 2023)
13104de  Merge branch 'main' into chatgpt-llm-generator (ZanSara, Sep 5, 2023)
155485f  Update haystack/preview/components/generators/openai/chatgpt.py (ZanSara, Sep 5, 2023)
b2187c3  Update haystack/preview/components/generators/openai/chatgpt.py (ZanSara, Sep 5, 2023)
ed08e34  Update haystack/preview/components/generators/openai/chatgpt.py (ZanSara, Sep 5, 2023)
cc0bb7d  review feedback (ZanSara, Sep 5, 2023)
c58ab26  fix tests (ZanSara, Sep 5, 2023)
835fd0c  freview feedback (ZanSara, Sep 5, 2023)
0eb43f9  reno (ZanSara, Sep 5, 2023)
e8d92dd  remove tenacity mock (ZanSara, Sep 6, 2023)
0aeb875  gpt35generator (ZanSara, Sep 6, 2023)
9167e05  fix naming (ZanSara, Sep 6, 2023)
941cc66  remove stray references to chatgpt (ZanSara, Sep 6, 2023)
04ec229  fix e2e (ZanSara, Sep 6, 2023)
4eece1e  Merge branch 'main' into chatgpt-llm-generator (ZanSara, Sep 6, 2023)
8fb06ae  Update releasenotes/notes/chatgpt-llm-generator-d043532654efe684.yaml (ZanSara, Sep 6, 2023)
46385ac  add another test (ZanSara, Sep 6, 2023)
812e8b9  Merge branch 'main' into chatgpt-llm-generator (ZanSara, Sep 6, 2023)
3ca3f73  test wrong model name (ZanSara, Sep 6, 2023)
1015424  review feedback (ZanSara, Sep 6, 2023)
b79c7c1  Merge branch 'main' into chatgpt-llm-generator (ZanSara, Sep 6, 2023)
67 changes: 67 additions & 0 deletions e2e/preview/components/test_chatgpt_generator.py
@@ -0,0 +1,67 @@
import os
import pytest
from haystack.preview.components.generators.openai.chatgpt import ChatGPTGenerator


@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_chatgpt_generator_run():
    component = ChatGPTGenerator(api_key=os.environ.get("OPENAI_API_KEY"))
    results = component.run(
        prompts=["What's the capital of France?", "What's the capital of Germany?"], model_parameters={"n": 1}
    )

    assert len(results["replies"]) == 2
    assert len(results["replies"][0]) == 1
    assert "Paris" in results["replies"][0][0]
    assert len(results["replies"][1]) == 1
    assert "Berlin" in results["replies"][1][0]

    assert len(results["metadata"]) == 2
    assert len(results["metadata"][0]) == 1
    assert "gpt-3.5-turbo" in results["metadata"][0][0]["model"]
    assert "stop" == results["metadata"][0][0]["finish_reason"]
    assert len(results["metadata"][1]) == 1
    assert "gpt-3.5-turbo" in results["metadata"][1][0]["model"]
    assert "stop" == results["metadata"][1][0]["finish_reason"]


@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_chatgpt_generator_run_streaming():
    class Callback:
        def __init__(self):
            self.responses = ""

        def __call__(self, chunk):
            self.responses += chunk.choices[0].delta.content if chunk.choices[0].delta else ""
            return chunk

    callback = Callback()
    component = ChatGPTGenerator(os.environ.get("OPENAI_API_KEY"), streaming_callback=callback)
    results = component.run(
        prompts=["What's the capital of France?", "What's the capital of Germany?"], model_parameters={"n": 1}
    )

    assert len(results["replies"]) == 2
    assert len(results["replies"][0]) == 1
    assert "Paris" in results["replies"][0][0]
    assert len(results["replies"][1]) == 1
    assert "Berlin" in results["replies"][1][0]

    assert callback.responses == results["replies"][0][0] + results["replies"][1][0]

    assert len(results["metadata"]) == 2
    assert len(results["metadata"][0]) == 1
    assert "gpt-3.5-turbo" in results["metadata"][0][0]["model"]
    assert "stop" == results["metadata"][0][0]["finish_reason"]
    assert len(results["metadata"][1]) == 1
    assert "gpt-3.5-turbo" in results["metadata"][1][0]["model"]
    assert "stop" == results["metadata"][1][0]["finish_reason"]
2 changes: 1 addition & 1 deletion haystack/preview/__init__.py
@@ -1,4 +1,4 @@
 from canals import component, Pipeline
 from canals.serialization import default_from_dict, default_to_dict
-from canals.errors import DeserializationError
+from canals.errors import DeserializationError, ComponentError
 from haystack.preview.dataclasses import *
33 changes: 0 additions & 33 deletions haystack/preview/components/generators/openai/_helpers.py

This file was deleted.

268 changes: 268 additions & 0 deletions haystack/preview/components/generators/openai/chatgpt.py
@@ -0,0 +1,268 @@
from typing import Optional, List, Callable, Dict, Any

import sys
import builtins
import logging
from dataclasses import asdict

import openai

from haystack.preview import component, default_from_dict, default_to_dict, DeserializationError
from haystack.preview.dataclasses.chat_message import ChatMessage


logger = logging.getLogger(__name__)


TOKENS_PER_MESSAGE_OVERHEAD = 4


def default_streaming_callback(chunk):
    """
    Default callback function for streaming responses from the OpenAI API.
    Prints the tokens of the first completion to stdout as soon as they are received and returns the chunk unchanged.
    """
    if hasattr(chunk.choices[0].delta, "content"):
        print(chunk.choices[0].delta.content, flush=True, end="")
    return chunk


@component
class ChatGPTGenerator:
    """
    ChatGPT LLM Generator.

    Queries ChatGPT using OpenAI's Chat Completions API. Invocations are made over the REST API.
    See the [OpenAI ChatGPT API](https://platform.openai.com/docs/guides/chat) for more details.
    """

    # TODO support function calling!

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "gpt-3.5-turbo",
        system_prompt: Optional[str] = None,
        model_parameters: Optional[Dict[str, Any]] = None,
        streaming_callback: Optional[Callable] = None,
        api_base_url: str = "https://api.openai.com/v1",
    ):
"""
Creates an instance of ChatGPTGenerator for OpenAI's GPT-3.5 model.

:param api_key: The OpenAI API key.
:param model_name: The name of the model to use.
:param system_prompt: The prompt to be prepended to the user prompt.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function should accept two parameters: the token received from the stream and **kwargs.
The callback function should return the token to be sent to the stream. If the callback function is not
provided, the token is printed to stdout.
:param api_base_url: The OpenAI API Base url, defaults to `https://api.openai.com/v1`.
:param model_parameters: A dictionary of parameters to use for the model. See OpenAI
[documentation](https://platform.openai.com/docs/api-reference/chat) for more details. Some of the supported
parameters:
- `max_tokens`: The maximum number of tokens the output text can have.
- `temperature`: What sampling temperature to use. Higher values means the model will take more risks.
Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
- `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
comprising the top 10% probability mass are considered.
- `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
it will generate two completions for each of the three prompts, ending up with 6 completions in total.
- `stop`: One or more sequences after which the LLM should stop generating tokens.
- `presence_penalty`: What penalty to apply if a token is already present at all. Bigger values mean
the model will be less likely to repeat the same token in the text.
- `frequency_penalty`: What penalty to apply if a token has already been generated in the text.
Bigger values mean the model will be less likely to repeat the same token in the text.
- `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens and the
values are the bias to add to that token.
- `openai_organization`: The OpenAI organization ID.
"""
if not api_key:
logger.warning("OpenAI API key is missing. You need to provide an API key to Pipeline.run().")

self.api_key = api_key
self.model_name = model_name
self.system_prompt = system_prompt
self.model_parameters = model_parameters
self.streaming_callback = streaming_callback
self.api_base_url = api_base_url

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        """
        if self.streaming_callback:
            module = sys.modules.get(self.streaming_callback.__module__)
            if not module:
                raise ValueError("Could not locate the import module.")
            if module == builtins:
                callback_name = self.streaming_callback.__name__
            else:
                callback_name = f"{module.__name__}.{self.streaming_callback.__name__}"
        else:
            callback_name = None

        return default_to_dict(
            self,
            api_key=self.api_key,
            model_name=self.model_name,
            model_parameters=self.model_parameters,
            system_prompt=self.system_prompt,
            streaming_callback=callback_name,
            api_base_url=self.api_base_url,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ChatGPTGenerator":
        """
        Deserialize this component from a dictionary.
        """
        init_params = data.get("init_parameters", {})
        streaming_callback = None
        if init_params.get("streaming_callback"):
            parts = init_params["streaming_callback"].split(".")
            module_name = ".".join(parts[:-1])
            function_name = parts[-1]
            module = sys.modules.get(module_name, None)
            if not module:
                raise DeserializationError(f"Could not locate the module of the streaming callback: {module_name}")
            streaming_callback = getattr(module, function_name, None)
            if not streaming_callback:
                raise DeserializationError(f"Could not locate the streaming callback: {function_name}")
            data["init_parameters"]["streaming_callback"] = streaming_callback
        return default_from_dict(cls, data)

    @component.output_types(replies=List[List[str]], metadata=List[Dict[str, Any]])
    def run(
        self,
        prompts: List[str],
        api_key: Optional[str] = None,
        model_name: Optional[str] = None,
        system_prompt: Optional[str] = None,
        model_parameters: Optional[Dict[str, Any]] = None,
        streaming_callback: Optional[Callable] = None,
        api_base_url: Optional[str] = None,
    ):
        """
        Queries the LLM with the prompts to produce replies. Parameters passed here override the values
        given at initialization time.

        :param prompts: The prompts to be sent to the generative model.
        :param api_key: The OpenAI API key.
        :param model_name: The name of the model to use.
        :param system_prompt: The prompt to be prepended to the user prompt.
        :param streaming_callback: A callback function called when a new token is received from the stream.
            The callback function should accept two parameters: the token received from the stream and **kwargs.
            The callback function should return the token to be sent to the stream. If the callback function is
            not provided, the token is printed to stdout.
        :param api_base_url: The OpenAI API base URL, defaults to `https://api.openai.com/v1`.
        :param model_parameters: A dictionary of parameters to use for the model. Some of the supported
            parameters:
            - `max_tokens`: The maximum number of tokens the output text can have.
            - `temperature`: The sampling temperature to use. Higher values mean the model takes more risks.
                Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer.
            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
                considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
                comprising the top 10% probability mass are considered.
            - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
                it will generate two completions for each of the three prompts, ending up with 6 completions in total.
            - `stop`: One or more sequences after which the LLM should stop generating tokens.
            - `presence_penalty`: The penalty to apply if a token is already present. Higher values make the model
                less likely to repeat the same token in the text.
            - `frequency_penalty`: The penalty to apply if a token has already been generated in the text. Higher
                values make the model less likely to repeat the same token in the text.
            - `logit_bias`: Adds a logit bias to specific tokens. The keys of the dictionary are tokens, and the
                values are the bias to add to those tokens.
            - `openai_organization`: The OpenAI organization ID.

        See the [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat) for more details.
        """
        api_key = api_key if api_key is not None else self.api_key
        if not api_key:
            raise ValueError("OpenAI API key is missing. Please provide an API key.")

        model_name = model_name or self.model_name
        system_prompt = system_prompt if system_prompt is not None else self.system_prompt
        model_parameters = model_parameters if model_parameters is not None else self.model_parameters
        streaming_callback = streaming_callback or self.streaming_callback
        api_base_url = api_base_url or self.api_base_url

        if system_prompt:
            system_message = ChatMessage(content=system_prompt, role="system")
        chats = []
        for prompt in prompts:
            message = ChatMessage(content=prompt, role="user")
            if system_prompt:
                chats.append([system_message, message])
            else:
                chats.append([message])

        all_replies, all_metadata = [], []
        for chat in chats:
            completion = openai.ChatCompletion.create(
                model=model_name,
                api_key=api_key,
                api_base=api_base_url,
                messages=[asdict(message) for message in chat],
                stream=streaming_callback is not None,
                **(model_parameters or {}),
            )

            replies: List[str]
            metadata: List[Dict[str, Any]]
            if streaming_callback:
                replies_dict: Dict[int, str] = {}
                metadata_dict: Dict[int, Dict[str, Any]] = {}
                for chunk in completion:
                    chunk = streaming_callback(chunk)
                    for choice in chunk.choices:
                        if choice.index not in replies_dict:
                            replies_dict[choice.index] = ""
                            metadata_dict[choice.index] = {}
                        if hasattr(choice.delta, "content"):
                            replies_dict[choice.index] += choice.delta.content
                        metadata_dict[choice.index] = {
                            "model": chunk.model,
                            "index": choice.index,
                            "finish_reason": choice.finish_reason,
                        }
                all_replies.append(list(replies_dict.values()))
                all_metadata.append(list(metadata_dict.values()))
                self._check_truncated_answers(list(metadata_dict.values()))

            else:
                metadata = [
                    {
                        "model": completion.model,
                        "index": choice.index,
                        "finish_reason": choice.finish_reason,
                        "usage": dict(completion.usage.items()),
                    }
                    for choice in completion.choices
                ]
                replies = [choice.message.content.strip() for choice in completion.choices]
                all_replies.append(replies)
                all_metadata.append(metadata)
                self._check_truncated_answers(metadata)

        return {"replies": all_replies, "metadata": all_metadata}

    def _check_truncated_answers(self, metadata: List[Dict[str, Any]]):
        """
        Check the `finish_reason` of the answers returned by the OpenAI completions endpoint.
        If the `finish_reason` is `length`, log a warning to the user.

        :param metadata: The metadata of the completions returned by the OpenAI API.
        """
        truncated_completions = sum(1 for meta in metadata if meta.get("finish_reason") != "stop")
        if truncated_completions > 0:
            logger.warning(
                "%s out of the %s completions have been truncated before reaching a natural stopping point. "
                "Increase the max_tokens parameter to allow for longer completions.",
                truncated_completions,
                len(metadata),
            )
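
Because to_dict() stores the streaming callback as a dotted import path, a pipeline definition can reference any importable function. A short round-trip sketch (illustrative only, not part of the diff; it assumes the callback's module has already been imported, and "sk-..." is a placeholder key):

from haystack.preview.components.generators.openai.chatgpt import (
    ChatGPTGenerator,
    default_streaming_callback,
)

# Sketch of the callback serialization round trip.
generator = ChatGPTGenerator(api_key="sk-...", streaming_callback=default_streaming_callback)
data = generator.to_dict()
# data["init_parameters"]["streaming_callback"] now holds the dotted path
# "haystack.preview.components.generators.openai.chatgpt.default_streaming_callback".
restored = ChatGPTGenerator.from_dict(data)
assert restored.streaming_callback is default_streaming_callback
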
7 changes: 7 additions & 0 deletions haystack/preview/dataclasses/chat_message.py
@@ -0,0 +1,7 @@
from dataclasses import dataclass


@dataclass
class ChatMessage:
    content: str
    role: str
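
For context (not part of the diff): the generator turns these dataclasses into the plain dicts the Chat Completions API expects via dataclasses.asdict, so the payload shape looks like this:

from dataclasses import asdict

from haystack.preview.dataclasses.chat_message import ChatMessage

messages = [
    ChatMessage(content="You are a helpful assistant.", role="system"),
    ChatMessage(content="What's the capital of France?", role="user"),
]
# asdict() yields the {"content": ..., "role": ...} dicts sent to the OpenAI SDK.
payload = [asdict(m) for m in messages]
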
1 change: 1 addition & 0 deletions pyproject.toml
@@ -80,6 +80,7 @@ dependencies = [

   # Preview
   "canals==0.8.0",
+  "openai",
   "Jinja2",

   # Agent events
2 changes: 2 additions & 0 deletions releasenotes/notes/chatgpt-llm-backend-d043532654efe684.yaml
@@ -0,0 +1,2 @@
preview:
  - Introduce `ChatGPTBackend`, a class that will be used by LLM components to talk to OpenAI Chat models like ChatGPT and GPT4. Note that ChatGPTBackend itself is NOT a component.