diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py
index f08a69b5f..21fa1f52f 100644
--- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py
+++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py
@@ -1,15 +1,17 @@
 import logging
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union
 
-import vertexai
 from haystack.core.component import component
 from haystack.core.serialization import default_from_dict, default_to_dict
+from haystack.dataclasses import StreamingChunk
 from haystack.dataclasses.byte_stream import ByteStream
 from haystack.dataclasses.chat_message import ChatMessage, ChatRole
+from haystack.utils import deserialize_callable, serialize_callable
+from vertexai import init as vertexai_init
 from vertexai.preview.generative_models import (
     Content,
-    FunctionDeclaration,
     GenerationConfig,
+    GenerationResponse,
     GenerativeModel,
     HarmBlockThreshold,
     HarmCategory,
@@ -55,6 +57,7 @@ def __init__(
         generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
         safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
         tools: Optional[List[Tool]] = None,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
     ):
         """
         `VertexAIGeminiChatGenerator` enables chat completion using Google Gemini models.
@@ -76,10 +79,13 @@ def __init__(
         :param tools: List of tools to use when generating content. See the documentation for
             [Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.Tool)
             the list of supported arguments.
+        :param streaming_callback: A callback function that is called when a new token is received from
+            the stream. The callback function accepts StreamingChunk as an argument.
+
         """
 
         # Login to GCP. This will fail if user has not set up their gcloud SDK
-        vertexai.init(project=project_id, location=location)
+        vertexai_init(project=project_id, location=location)
 
         self._model_name = model
         self._project_id = project_id
@@ -89,18 +95,7 @@ def __init__(
         self._generation_config = generation_config
         self._safety_settings = safety_settings
         self._tools = tools
-
-    def _function_to_dict(self, function: FunctionDeclaration) -> Dict[str, Any]:
-        return {
-            "name": function._raw_function_declaration.name,
-            "parameters": function._raw_function_declaration.parameters,
-            "description": function._raw_function_declaration.description,
-        }
-
-    def _tool_to_dict(self, tool: Tool) -> Dict[str, Any]:
-        return {
-            "function_declarations": [self._function_to_dict(f) for f in tool._raw_tool.function_declarations],
-        }
+        self._streaming_callback = streaming_callback
 
     def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]:
         if isinstance(config, dict):
@@ -121,6 +116,8 @@ def to_dict(self) -> Dict[str, Any]:
 
         :returns: Dictionary with serialized data.
         """
+        callback_name = serialize_callable(self._streaming_callback) if self._streaming_callback else None
+
         data = default_to_dict(
             self,
             model=self._model_name,
@@ -129,9 +126,10 @@ def to_dict(self) -> Dict[str, Any]:
             generation_config=self._generation_config,
             safety_settings=self._safety_settings,
             tools=self._tools,
+            streaming_callback=callback_name,
         )
         if (tools := data["init_parameters"].get("tools")) is not None:
-            data["init_parameters"]["tools"] = [self._tool_to_dict(t) for t in tools]
+            data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools]
         if (generation_config := data["init_parameters"].get("generation_config")) is not None:
             data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config)
         return data
@@ -150,7 +148,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiChatGenerator":
             data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools]
         if (generation_config := data["init_parameters"].get("generation_config")) is not None:
             data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config)
-
+        if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None:
+            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
         return default_from_dict(cls, data)
 
     def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
@@ -195,13 +194,21 @@ def _message_to_content(self, message: ChatMessage) -> Content:
         return Content(parts=[part], role=role)
 
     @component.output_types(replies=List[ChatMessage])
-    def run(self, messages: List[ChatMessage]):
+    def run(
+        self,
+        messages: List[ChatMessage],
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+    ):
         """Prompts Google Vertex AI Gemini model to generate a response to a list of messages.
 
         :param messages: The last message is the prompt, the rest are the history.
+        :param streaming_callback: A callback function that is called when a new token is received from the stream.
         :returns: A dictionary with the following keys:
             - `replies`: A list of ChatMessage objects representing the model's replies.
         """
+        # check if streaming_callback is passed
+        streaming_callback = streaming_callback or self._streaming_callback
+
         history = [self._message_to_content(m) for m in messages[:-1]]
         session = self._model.start_chat(history=history)
 
@@ -211,10 +218,22 @@ def run(self, messages: List[ChatMessage]):
             generation_config=self._generation_config,
             safety_settings=self._safety_settings,
             tools=self._tools,
+            stream=streaming_callback is not None,
         )
 
+        replies = self._get_stream_response(res, streaming_callback) if streaming_callback else self._get_response(res)
+
+        return {"replies": replies}
+
+    def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]:
+        """
+        Extracts the responses from the Vertex AI response.
+
+        :param response_body: The response from Vertex AI request.
+        :returns: The extracted responses.
+        """
         replies = []
-        for candidate in res.candidates:
+        for candidate in response_body.candidates:
             for part in candidate.content.parts:
                 if part._raw_part.text != "":
                     replies.append(ChatMessage.from_system(part.text))
@@ -226,5 +245,23 @@ def run(self, messages: List[ChatMessage]):
                             name=part.function_call.name,
                         )
                     )
+        return replies
 
-        return {"replies": replies}
+    def _get_stream_response(
+        self, stream: Iterable[GenerationResponse], streaming_callback: Callable[[StreamingChunk], None]
+    ) -> List[ChatMessage]:
+        """
+        Extracts the responses from the Vertex AI streaming response.
+
+        :param stream: The streaming response from the Vertex AI request.
+        :param streaming_callback: The handler for the streaming response.
+        :returns: The extracted response with the content of all streaming chunks.
+        """
+        responses = []
+        for chunk in stream:
+            streaming_chunk = StreamingChunk(content=chunk.text, meta=chunk.to_dict())
+            streaming_callback(streaming_chunk)
+            responses.append(streaming_chunk.content)
+
+        combined_response = "".join(responses).lstrip()
+        return [ChatMessage.from_system(content=combined_response)]