feat: enable streaming for VertexAIGeminiChatGenerator #1014

Merged · 8 commits · Aug 27, 2024
```diff
@@ -1,14 +1,15 @@
 import logging
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
-import vertexai
 from haystack.core.component import component
 from haystack.core.serialization import default_from_dict, default_to_dict
+from haystack.dataclasses import StreamingChunk
 from haystack.dataclasses.byte_stream import ByteStream
 from haystack.dataclasses.chat_message import ChatMessage, ChatRole
+from haystack.utils import deserialize_callable, serialize_callable
+from vertexai import init as vertexai_init
 from vertexai.preview.generative_models import (
     Content,
     FunctionDeclaration,
     GenerationConfig,
     GenerativeModel,
     HarmBlockThreshold,
```
```diff
@@ -55,6 +56,7 @@ def __init__(
         generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
         safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
         tools: Optional[List[Tool]] = None,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
     ):
         """
         `VertexAIGeminiChatGenerator` enables chat completion using Google Gemini models.
@@ -76,10 +78,13 @@ def __init__(
         :param tools: List of tools to use when generating content. See the documentation for
             [Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.Tool)
             the list of supported arguments.
+        :param streaming_callback: A callback function that is called when a new token is received from
+            the stream. The callback function accepts StreamingChunk as an argument.
+
         """
 
         # Login to GCP. This will fail if user has not set up their gcloud SDK
-        vertexai.init(project=project_id, location=location)
+        vertexai_init(project=project_id, location=location)
 
         self._model_name = model
         self._project_id = project_id
```
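For context, constructing the generator with the new parameter would look roughly like this; a minimal sketch, assuming the integration's usual import path, with the model name and GCP project/location values as placeholders:

```python
from haystack.dataclasses import StreamingChunk
from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator


def print_chunk(chunk: StreamingChunk) -> None:
    # Print each streamed token without a trailing newline.
    print(chunk.content, end="", flush=True)


# Placeholder model/project/location values.
generator = VertexAIGeminiChatGenerator(
    model="gemini-1.5-flash",
    project_id="my-gcp-project",
    location="us-central1",
    streaming_callback=print_chunk,
)
```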
```diff
@@ -89,18 +94,7 @@ def __init__(
         self._generation_config = generation_config
         self._safety_settings = safety_settings
         self._tools = tools
-
-    def _function_to_dict(self, function: FunctionDeclaration) -> Dict[str, Any]:
-        return {
-            "name": function._raw_function_declaration.name,
-            "parameters": function._raw_function_declaration.parameters,
-            "description": function._raw_function_declaration.description,
-        }
-
-    def _tool_to_dict(self, tool: Tool) -> Dict[str, Any]:
-        return {
-            "function_declarations": [self._function_to_dict(f) for f in tool._raw_tool.function_declarations],
-        }
+        self._streaming_callback = streaming_callback
 
     def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]:
         if isinstance(config, dict):
```
```diff
@@ -121,6 +115,8 @@ def to_dict(self) -> Dict[str, Any]:
         :returns:
             Dictionary with serialized data.
         """
+        callback_name = serialize_callable(self._streaming_callback) if self._streaming_callback else None
+
         data = default_to_dict(
             self,
             model=self._model_name,
```
```diff
@@ -129,9 +125,10 @@
             generation_config=self._generation_config,
             safety_settings=self._safety_settings,
             tools=self._tools,
+            streaming_callback=callback_name,
         )
         if (tools := data["init_parameters"].get("tools")) is not None:
-            data["init_parameters"]["tools"] = [self._tool_to_dict(t) for t in tools]
+            data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools]
         if (generation_config := data["init_parameters"].get("generation_config")) is not None:
             data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config)
         return data
```
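A note on the serialization step: `serialize_callable` stores the callback as its dotted import path, a plain string, so the dictionary returned by `to_dict()` stays JSON/YAML-friendly. A minimal sketch of that behavior, with a hypothetical module-level callback:

```python
from haystack.utils import serialize_callable


def print_chunk(chunk) -> None:
    print(chunk.content, end="", flush=True)


# Resolves the function to a dotted path string,
# e.g. "__main__.print_chunk" when run as a script.
callback_name = serialize_callable(print_chunk)
print(callback_name)
```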
```diff
@@ -150,7 +147,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiChatGenerator":
             data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools]
         if (generation_config := data["init_parameters"].get("generation_config")) is not None:
             data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config)
-
+        if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None:
+            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
         return default_from_dict(cls, data)
 
     def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
```
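`deserialize_callable` is the inverse: it imports the dotted path and returns the callable, so a component round-trips its callback through `to_dict()`/`from_dict()`. A sketch using Haystack's bundled `print_streaming_chunk` helper, assumed available at this path in Haystack 2.x:

```python
from haystack.components.generators.utils import print_streaming_chunk
from haystack.utils import deserialize_callable, serialize_callable

# Round trip: path string out, the same function object back in.
path = serialize_callable(print_streaming_chunk)
func = deserialize_callable(path)
assert func is print_streaming_chunk
```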
```diff
@@ -195,13 +193,21 @@ def _message_to_content(self, message: ChatMessage) -> Content:
         return Content(parts=[part], role=role)
 
     @component.output_types(replies=List[ChatMessage])
-    def run(self, messages: List[ChatMessage]):
+    def run(
+        self,
+        messages: List[ChatMessage],
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+    ):
         """Prompts Google Vertex AI Gemini model to generate a response to a list of messages.
 
         :param messages: The last message is the prompt, the rest are the history.
+        :param streaming_callback: A callback function that is called when a new token is received from the stream.
         :returns: A dictionary with the following keys:
             - `replies`: A list of ChatMessage objects representing the model's replies.
         """
+        # check if streaming_callback is passed
+        streaming_callback = streaming_callback or self._streaming_callback
+
         history = [self._message_to_content(m) for m in messages[:-1]]
         session = self._model.start_chat(history=history)
 
```
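Because of the `streaming_callback = streaming_callback or self._streaming_callback` line, the per-call argument takes precedence, so one component instance can stream in one call and stay non-streaming in the next. A usage sketch, reusing the hypothetical `generator` from the earlier example:

```python
from haystack.dataclasses import ChatMessage

# `generator` is the VertexAIGeminiChatGenerator instance sketched above.

# Streams: the run-time callback enables streaming for this call only.
result = generator.run(
    messages=[ChatMessage.from_user("Summarize the Gemini model family.")],
    streaming_callback=lambda chunk: print(chunk.content, end="", flush=True),
)

# Does not stream, provided no callback was set in __init__ either.
result = generator.run(messages=[ChatMessage.from_user("Same question, no streaming.")])
```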
```diff
@@ -211,10 +217,22 @@ def run(self, messages: List[ChatMessage]):
             generation_config=self._generation_config,
             safety_settings=self._safety_settings,
             tools=self._tools,
+            stream=streaming_callback is not None,
         )
 
+        replies = self.get_stream_response(res, streaming_callback) if streaming_callback else self.get_response(res)
+
+        return {"replies": replies}
+
+    def get_response(self, response_body) -> List[ChatMessage]:
+        """
+        Extracts the responses from the Vertex AI response.
+
+        :param response_body: The response from Vertex AI request.
+        :returns: The extracted responses.
+        """
         replies = []
-        for candidate in res.candidates:
+        for candidate in response_body.candidates:
             for part in candidate.content.parts:
                 if part._raw_part.text != "":
                     replies.append(ChatMessage.from_system(part.text))
```
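The `stream=streaming_callback is not None` flag is what flips the SDK's return type: with `stream=True`, `send_message` yields an iterable of partial responses instead of a single one, hence the branch into `get_stream_response` versus `get_response`. A standalone sketch of that underlying SDK behavior, with a placeholder model name:

```python
from vertexai.preview.generative_models import GenerativeModel

model = GenerativeModel("gemini-1.5-flash")  # placeholder model name
session = model.start_chat()

# stream=True yields partial responses as the model produces them.
for partial in session.send_message("Hello!", stream=True):
    print(partial.text, end="", flush=True)
```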
```diff
@@ -226,5 +244,21 @@ def run(self, messages: List[ChatMessage]):
                             name=part.function_call.name,
                         )
                     )
+        return replies
 
-        return {"replies": replies}
+    def get_stream_response(self, stream, streaming_callback: Callable[[StreamingChunk], None]) -> List[ChatMessage]:
+        """
+        Extracts the responses from the Vertex AI streaming response.
+
+        :param stream: The streaming response from the Vertex AI request.
+        :param streaming_callback: The handler for the streaming response.
+        :returns: The extracted response with the content of all streaming chunks.
+        """
+        responses = []
+        for chunk in stream:
+            streaming_chunk = StreamingChunk(content=chunk.text, meta=chunk.to_dict())
+            streaming_callback(streaming_chunk)
+            responses.append(streaming_chunk.content)
+
+        combined_response = "".join(responses).lstrip()
+        return [ChatMessage.from_system(content=combined_response)]
```
Contributor Author: This will be changed to assistant role in another PR. I am tracking it here.
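Finally, an end-to-end sketch of the streaming path added here, again reusing the hypothetical `generator`; the callback sees every chunk, and the returned reply is their concatenation, mirroring `"".join(responses).lstrip()` above:

```python
from haystack.dataclasses import ChatMessage, StreamingChunk

# `generator` is the VertexAIGeminiChatGenerator instance sketched earlier.
chunks: list = []


def collect(chunk: StreamingChunk) -> None:
    chunks.append(chunk.content)  # keep every partial token
    print(chunk.content, end="", flush=True)


result = generator.run(
    messages=[ChatMessage.from_user("Tell me a one-line joke.")],
    streaming_callback=collect,
)

# The single reply holds the full concatenated text.
assert result["replies"][0].content == "".join(chunks).lstrip()
```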