Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ChatGPTGenerator #5710

Closed
wants to merge 34 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
0fc2bac
add generators module
ZanSara Aug 30, 2023
7f6325c
add tests for module helper
ZanSara Aug 30, 2023
47b6799
add chatgpt generator
ZanSara Aug 30, 2023
4e8fcb3
add init and serialization tests
ZanSara Aug 30, 2023
cbf7701
test component
ZanSara Aug 30, 2023
419f615
reno
ZanSara Aug 30, 2023
49ff654
Merge branch 'main' into generators-module
ZanSara Aug 30, 2023
4edeb8e
Merge branch 'generators-module' into chatgpt-generator
ZanSara Aug 30, 2023
08e9c62
reno
ZanSara Aug 30, 2023
a984e67
more tests
ZanSara Aug 30, 2023
612876a
add another test
ZanSara Aug 31, 2023
ec8e14a
Merge branch 'generators-module' of github.com:deepset-ai/haystack in…
ZanSara Aug 31, 2023
366b0ff
Merge branch 'generators-module' into chatgpt-generator
ZanSara Aug 31, 2023
e9c3de7
chat token limit
ZanSara Aug 31, 2023
725fabe
move into openai
ZanSara Aug 31, 2023
4d4f9d4
Merge branch 'generators-module' into chatgpt-generator
ZanSara Aug 31, 2023
c3bef8f
fix test
ZanSara Aug 31, 2023
c1a7696
improve tests
ZanSara Aug 31, 2023
246ca63
Merge branch 'generators-module' into chatgpt-generator
ZanSara Aug 31, 2023
ec809e4
add e2e test and small fixes
ZanSara Aug 31, 2023
5d946f8
linting
ZanSara Aug 31, 2023
aa9ce33
Add ChatGPTGenerator example
vblagoje Aug 31, 2023
9310057
review feedback
ZanSara Aug 31, 2023
7c36db1
Merge branch 'chatgpt-generator' of github.com:deepset-ai/haystack in…
ZanSara Aug 31, 2023
b2e421d
support for metadata
ZanSara Aug 31, 2023
6d81d79
Merge branch 'main' into chatgpt-generator
ZanSara Aug 31, 2023
2895697
mypy
ZanSara Aug 31, 2023
1538d61
mypy
ZanSara Sep 1, 2023
02cd61f
extract backend from generator and make it accept chats
ZanSara Sep 1, 2023
84332c6
fix tests
ZanSara Sep 1, 2023
329b54d
mypy
ZanSara Sep 4, 2023
5ee2aac
query->complete
ZanSara Sep 4, 2023
429a3ae
mypy
ZanSara Sep 4, 2023
c0b237d
Merge branch 'main' into chatgpt-generator
ZanSara Sep 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions e2e/preview/components/test_chatgpt_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import os
import pytest
from haystack.preview.components.generators.openai.chatgpt import ChatGPTGenerator


@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_chatgpt_generator_run():
component = ChatGPTGenerator(api_key=os.environ.get("OPENAI_API_KEY"))
results = component.run(prompts=["What's the capital of France?", "What's the capital of Germany?"], n=1)

assert len(results["replies"]) == 2
assert len(results["replies"][0]) == 1
assert "Paris" in results["replies"][0][0]
assert len(results["replies"][1]) == 1
assert "Berlin" in results["replies"][1][0]

assert len(results["metadata"]) == 2
assert len(results["metadata"][0]) == 1
assert "gpt-3.5-turbo" in results["metadata"][0][0]["model"]
assert "stop" == results["metadata"][0][0]["finish_reason"]
assert len(results["metadata"][1]) == 1
assert "gpt-3.5-turbo" in results["metadata"][1][0]["model"]
assert "stop" == results["metadata"][1][0]["finish_reason"]


@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_chatgpt_generator_run_streaming():
class Callback:
def __init__(self):
self.responses = ""

def __call__(self, token, event_data):
self.responses += token
return token

callback = Callback()
component = ChatGPTGenerator(os.environ.get("OPENAI_API_KEY"), stream=True, streaming_callback=callback)
results = component.run(prompts=["What's the capital of France?", "What's the capital of Germany?"], n=1)

assert len(results["replies"]) == 2
assert len(results["replies"][0]) == 1
assert "Paris" in results["replies"][0][0]
assert len(results["replies"][1]) == 1
assert "Berlin" in results["replies"][1][0]

assert callback.responses == results["replies"][0][0] + results["replies"][1][0]

assert len(results["metadata"]) == 2
assert len(results["metadata"][0]) == 1
assert "gpt-3.5-turbo" in results["metadata"][0][0]["model"]
assert "stop" == results["metadata"][0][0]["finish_reason"]
assert len(results["metadata"][1]) == 1
assert "gpt-3.5-turbo" in results["metadata"][1][0]["model"]
assert "stop" == results["metadata"][1][0]["finish_reason"]
2 changes: 1 addition & 1 deletion haystack/preview/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from canals import component, Pipeline
from canals.serialization import default_from_dict, default_to_dict
from canals.errors import DeserializationError
from canals.errors import DeserializationError, ComponentError
from haystack.preview.dataclasses import *
33 changes: 0 additions & 33 deletions haystack/preview/components/generators/openai/_helpers.py

This file was deleted.

194 changes: 194 additions & 0 deletions haystack/preview/components/generators/openai/chatgpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
from typing import Optional, List, Callable, Dict, Any

import logging

from haystack.preview import component, default_from_dict, default_to_dict
from haystack.preview.llm_backends.openai.chatgpt import ChatGPTBackend
from haystack.preview.llm_backends.chat_message import ChatMessage
from haystack.preview.llm_backends.openai._helpers import default_streaming_callback


logger = logging.getLogger(__name__)


TOKENS_PER_MESSAGE_OVERHEAD = 4


@component
class ChatGPTGenerator:
"""
ChatGPT LLM Generator.

Queries ChatGPT using OpenAI's GPT-3 ChatGPT API. Invocations are made using REST API.
See [OpenAI ChatGPT API](https://platform.openai.com/docs/guides/chat) for more details.
"""

# TODO support function calling!

def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "gpt-3.5-turbo",
system_prompt: Optional[str] = None,
max_tokens: Optional[int] = 500,
temperature: Optional[float] = 0.7,
top_p: Optional[float] = 1,
n: Optional[int] = 1,
stop: Optional[List[str]] = None,
presence_penalty: Optional[float] = 0,
frequency_penalty: Optional[float] = 0,
logit_bias: Optional[Dict[str, float]] = None,
stream: bool = False,
streaming_callback: Optional[Callable] = default_streaming_callback,
api_base_url: str = "https://api.openai.com/v1",
openai_organization: Optional[str] = None,
):
"""
Creates an instance of ChatGPTGenerator for OpenAI's GPT-3.5 model.

:param api_key: The OpenAI API key.
:param model_name: The name or path of the underlying model.
:param system_prompt: The prompt to be prepended to the user prompt.
:param max_tokens: The maximum number of tokens the output text can have.
:param temperature: What sampling temperature to use. Higher values means the model will take more risks.
Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
:param top_p: An alternative to sampling with temperature, called nucleus sampling, where the model
considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
comprising the top 10% probability mass are considered.
:param n: How many completions to generate for each prompt.
:param stop: One or more sequences where the API will stop generating further tokens.
:param presence_penalty: What penalty to apply if a token is already present at all. Bigger values mean
the model will be less likely to repeat the same token in the text.
:param frequency_penalty: What penalty to apply if a token has already been generated in the text.
Bigger values mean the model will be less likely to repeat the same token in the text.
:param logit_bias: Add a logit bias to specific tokens. The keys of the dictionary are tokens and the
values are the bias to add to that token.
:param stream: If set to True, the API will stream the response. The streaming_callback parameter
is used to process the stream. If set to False, the response will be returned as a string.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function should accept two parameters: the token received from the stream and **kwargs.
The callback function should return the token to be sent to the stream. If the callback function is not
provided, the token is printed to stdout.
:param api_base_url: The OpenAI API Base url, defaults to `https://api.openai.com/v1`.
:param openai_organization: The OpenAI organization ID.

See OpenAI documentation](https://platform.openai.com/docs/api-reference/chat) for more details.
"""
self.llm = ChatGPTBackend(
api_key=api_key,
model_name=model_name,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
n=n,
stop=stop,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
stream=stream,
streaming_callback=streaming_callback,
api_base_url=api_base_url,
openai_organization=openai_organization,
)
self.system_prompt = system_prompt

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
return default_to_dict(self, system_prompt=self.system_prompt, **self.llm.to_dict())

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ChatGPTGenerator":
"""
Deserialize this component from a dictionary.
"""
# FIXME how to deserialize the streaming callback?
return default_from_dict(cls, data)

@component.output_types(replies=List[List[str]], metadata=List[Dict[str, Any]])
def run(
self,
prompts: List[str],
api_key: Optional[str] = None,
model_name: Optional[str] = None,
system_prompt: Optional[str] = None,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
stop: Optional[List[str]] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
stream: Optional[bool] = None,
streaming_callback: Optional[Callable] = None,
api_base_url: Optional[str] = None,
openai_organization: Optional[str] = None,
):
"""
Queries the LLM with the prompts to produce replies.

:param prompts: The prompts to be sent to the generative model.
:param api_key: The OpenAI API key.
:param model_name: The name or path of the underlying model.
:param system_prompt: The prompt to be prepended to the user prompt.
:param max_tokens: The maximum number of tokens the output text can have.
:param temperature: What sampling temperature to use. Higher values means the model will take more risks.
Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
:param top_p: An alternative to sampling with temperature, called nucleus sampling, where the model
considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
comprising the top 10% probability mass are considered.
:param n: How many completions to generate for each prompt.
:param stop: One or more sequences where the API will stop generating further tokens.
:param presence_penalty: What penalty to apply if a token is already present at all. Bigger values mean
the model will be less likely to repeat the same token in the text.
:param frequency_penalty: What penalty to apply if a token has already been generated in the text.
Bigger values mean the model will be less likely to repeat the same token in the text.
:param logit_bias: Add a logit bias to specific tokens. The keys of the dictionary are tokens and the
values are the bias to add to that token.
:param stream: If set to True, the API will stream the response. The streaming_callback parameter
is used to process the stream. If set to False, the response will be returned as a string.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function should accept two parameters: the token received from the stream and **kwargs.
The callback function should return the token to be sent to the stream. If the callback function is not
provided, the token is printed to stdout.
:param api_base_url: The OpenAI API Base url, defaults to `https://api.openai.com/v1`.
:param openai_organization: The OpenAI organization ID.

See OpenAI documentation](https://platform.openai.com/docs/api-reference/chat) for more details.
"""
system_prompt = system_prompt if system_prompt is not None else self.system_prompt
if system_prompt:
system_message = ChatMessage(content=system_prompt, role="system")
chats = []
for prompt in prompts:
message = ChatMessage(content=prompt, role="user")
if system_prompt:
chats.append([system_message, message])
else:
chats.append([message])

replies, metadata = [], []
for chat in chats:
reply, meta = self.llm.complete(
chat=chat,
api_key=api_key,
model_name=model_name,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
n=n,
stop=stop,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
api_base_url=api_base_url,
openai_organization=openai_organization,
stream=stream,
streaming_callback=streaming_callback,
)
replies.append(reply)
metadata.append(meta)

return {"replies": replies, "metadata": metadata}
Empty file.
13 changes: 13 additions & 0 deletions haystack/preview/examples/chat_gpt_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import os

from haystack.preview.components.generators.openai.chatgpt import ChatGPTGenerator

stream_response = False

llm = ChatGPTGenerator(
api_key=os.environ.get("OPENAI_API_KEY"), model_name="gpt-3.5-turbo", max_tokens=256, stream=stream_response
)

responses = llm.run(prompts=["What is the meaning of life?"])
if not stream_response:
print(responses)
Empty file.
7 changes: 7 additions & 0 deletions haystack/preview/llm_backends/chat_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from dataclasses import dataclass


@dataclass
class ChatMessage:
content: str
role: str
Empty file.
Loading
Loading