Skip to content

Commit

Permalink
feat: allow Generators to run with a system prompt defined at run time (#8423)
Browse files Browse the repository at this point in the history

* initial import

* Update haystack/components/generators/openai.py

Co-authored-by: Sebastian Husch Lee <[email protected]>

* docs: fixing

* supporting the three use cases: no system prompt, using system prompt defined at init, using system prompt defined at run time

* renaming 'run_time_system_prompt' to 'system_prompt'

* adding tests, converting methods to static

---------

Co-authored-by: Sebastian Husch Lee <[email protected]>
  • Loading branch information
davidsbatista and sjrl authored Oct 22, 2024
1 parent f6935d1 commit 3a50d35
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 22 deletions.
20 changes: 15 additions & 5 deletions haystack/components/generators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenAIGenerator":
def run(
self,
prompt: str,
system_prompt: Optional[str] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
):
Expand All @@ -178,6 +179,9 @@ def run(
:param prompt:
The string prompt to use for text generation.
:param system_prompt:
The system prompt to use for text generation. If this run time system prompt is omitted, the system
prompt, if defined at initialisation time, is used.
:param streaming_callback:
A callback function that is called when a new token is received from the stream.
:param generation_kwargs:
Expand All @@ -189,7 +193,9 @@ def run(
for each response.
"""
message = ChatMessage.from_user(prompt)
if self.system_prompt:
if system_prompt is not None:
messages = [ChatMessage.from_system(system_prompt), message]
elif self.system_prompt:
messages = [ChatMessage.from_system(self.system_prompt), message]
else:
messages = [message]
Expand Down Expand Up @@ -237,7 +243,8 @@ def run(
"meta": [message.meta for message in completions],
}

def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
@staticmethod
def _connect_chunks(chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
"""
Connects the streaming chunks into a single ChatMessage.
"""
Expand All @@ -252,7 +259,8 @@ def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessa
)
return complete_response

def _build_message(self, completion: Any, choice: Any) -> ChatMessage:
@staticmethod
def _build_message(completion: Any, choice: Any) -> ChatMessage:
"""
Converts the response from the OpenAI API to a ChatMessage.
Expand All @@ -276,7 +284,8 @@ def _build_message(self, completion: Any, choice: Any) -> ChatMessage:
)
return chat_message

def _build_chunk(self, chunk: Any) -> StreamingChunk:
@staticmethod
def _build_chunk(chunk: Any) -> StreamingChunk:
"""
Converts the response from the OpenAI API to a StreamingChunk.
Expand All @@ -293,7 +302,8 @@ def _build_chunk(self, chunk: Any) -> StreamingChunk:
chunk_message.meta.update({"model": chunk.model, "index": choice.index, "finish_reason": choice.finish_reason})
return chunk_message

def _check_finish_reason(self, message: ChatMessage) -> None:
@staticmethod
def _check_finish_reason(message: ChatMessage) -> None:
"""
Check the `finish_reason` returned with the OpenAI completions.
Expand Down
36 changes: 19 additions & 17 deletions test/components/generators/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,6 @@ def test_init_with_parameters(self, monkeypatch):
assert component.client.timeout == 40.0
assert component.client.max_retries == 1

def test_init_with_parameters(self, monkeypatch):
    """Constructor honours explicit arguments plus OPENAI_TIMEOUT / OPENAI_MAX_RETRIES env vars."""
    monkeypatch.setenv("OPENAI_TIMEOUT", "100")
    monkeypatch.setenv("OPENAI_MAX_RETRIES", "10")

    kwargs = {"max_tokens": 10, "some_test_param": "test-params"}
    generator = OpenAIGenerator(
        api_key=Secret.from_token("test-api-key"),
        model="gpt-4o-mini",
        streaming_callback=print_streaming_chunk,
        api_base_url="test-base-url",
        generation_kwargs=kwargs,
    )

    # Explicit init arguments are stored as-is on the component / underlying client.
    assert generator.client.api_key == "test-api-key"
    assert generator.model == "gpt-4o-mini"
    assert generator.streaming_callback is print_streaming_chunk
    assert generator.generation_kwargs == kwargs
    # Env-var driven client settings are parsed into numbers.
    assert generator.client.timeout == 100.0
    assert generator.client.max_retries == 10

def test_to_dict_default(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = OpenAIGenerator()
Expand Down Expand Up @@ -331,3 +314,22 @@ def __call__(self, chunk: StreamingChunk) -> None:

assert callback.counter > 1
assert "Paris" in callback.responses

@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
@pytest.mark.integration
def test_run_with_system_prompt(self):
    """The run-time `system_prompt` argument overrides the one given at init time.

    First call: no run-time prompt, so the init-time (Portuguese) prompt applies.
    Second call: a run-time (German) prompt is passed and must take precedence.
    """
    generator = OpenAIGenerator(
        model="gpt-4o-mini",
        system_prompt="You answer in Portuguese, regardless of the language on which a question is asked",
    )
    # Fixed typo in the question ("Pitagoras therom" -> "Pythagoras theorem") so the
    # model reliably recognises the topic being asked about.
    result = generator.run("Can you explain the Pythagoras theorem?")
    # Portuguese reply should contain the Portuguese word for "theorem".
    assert "teorema" in result["replies"][0]

    result = generator.run(
        "Can you explain the Pythagoras theorem?",
        system_prompt="You answer in German, regardless of the language on which a question is asked.",
    )
    # German reply keeps the proper noun "Pythagoras".
    assert "Pythagoras" in result["replies"][0]

0 comments on commit 3a50d35

Please sign in to comment.