diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py
index 8a288a315..ca03a83b1 100644
--- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py
+++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py
@@ -1,15 +1,16 @@
 import logging
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union

-import vertexai
 from haystack.core.component import component
 from haystack.core.component.types import Variadic
 from haystack.core.serialization import default_from_dict, default_to_dict
-from haystack.dataclasses.byte_stream import ByteStream
-from vertexai.preview.generative_models import (
+from haystack.dataclasses import ByteStream, StreamingChunk
+from haystack.utils import deserialize_callable, serialize_callable
+from vertexai import init as vertexai_init
+from vertexai.generative_models import (
     Content,
-    FunctionDeclaration,
     GenerationConfig,
+    GenerationResponse,
     GenerativeModel,
     HarmBlockThreshold,
     HarmCategory,
@@ -60,6 +61,7 @@ def __init__(
         generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None,
         safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None,
         tools: Optional[List[Tool]] = None,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
     ):
         """
         Multi-modal generator using the Gemini model via Google Vertex AI.
@@ -87,10 +89,12 @@
         :param tools: List of tools to use when generating content. See the documentation for
             [Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.Tool)
             for the list of supported arguments.
+        :param streaming_callback: A callback function that is called when a new token is received from the stream.
+            The callback function accepts a StreamingChunk as its argument.
         """

         # Login to GCP. This will fail if the user has not set up their gcloud SDK
-        vertexai.init(project=project_id, location=location)
+        vertexai_init(project=project_id, location=location)

         self._model_name = model
         self._project_id = project_id
@@ -100,18 +104,7 @@
         self._generation_config = generation_config
         self._safety_settings = safety_settings
         self._tools = tools
-
-    def _function_to_dict(self, function: FunctionDeclaration) -> Dict[str, Any]:
-        return {
-            "name": function._raw_function_declaration.name,
-            "parameters": function._raw_function_declaration.parameters,
-            "description": function._raw_function_declaration.description,
-        }
-
-    def _tool_to_dict(self, tool: Tool) -> Dict[str, Any]:
-        return {
-            "function_declarations": [self._function_to_dict(f) for f in tool._raw_tool.function_declarations],
-        }
+        self._streaming_callback = streaming_callback

     def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]:
         if isinstance(config, dict):
@@ -132,6 +125,8 @@ def to_dict(self) -> Dict[str, Any]:
         :returns: Dictionary with serialized data.
""" + + callback_name = serialize_callable(self._streaming_callback) if self._streaming_callback else None data = default_to_dict( self, model=self._model_name, @@ -140,9 +135,10 @@ def to_dict(self) -> Dict[str, Any]: generation_config=self._generation_config, safety_settings=self._safety_settings, tools=self._tools, + streaming_callback=callback_name, ) if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [self._tool_to_dict(t) for t in tools] + data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) return data @@ -161,7 +157,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiGenerator": data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config) - + if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: @@ -176,14 +173,21 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: raise ValueError(msg) @component.output_types(replies=List[Union[str, Dict[str, str]]]) - def run(self, parts: Variadic[Union[str, ByteStream, Part]]): + def run( + self, + parts: Variadic[Union[str, ByteStream, Part]], + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + ): """ Generates content using the Gemini model. :param parts: Prompt for the model. + :param streaming_callback: A callback function that is called when a new token is received from the stream. :returns: A dictionary with the following keys: - `replies`: A list of generated content. """ + # check if streaming_callback is passed + streaming_callback = streaming_callback or self._streaming_callback converted_parts = [self._convert_part(p) for p in parts] contents = [Content(parts=converted_parts, role="user")] @@ -192,10 +196,23 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]): generation_config=self._generation_config, safety_settings=self._safety_settings, tools=self._tools, + stream=streaming_callback is not None, ) self._model.start_chat() + replies = self._get_stream_response(res, streaming_callback) if streaming_callback else self._get_response(res) + + return {"replies": replies} + + def _get_response(self, response_body: GenerationResponse) -> List[str]: + """ + Extracts the responses from the Vertex AI response. + + :param response_body: The response body from the Vertex AI request. + + :returns: A list of string responses. 
+ """ replies = [] - for candidate in res.candidates: + for candidate in response_body.candidates: for part in candidate.content.parts: if part._raw_part.text != "": replies.append(part.text) @@ -205,5 +222,24 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]): "args": dict(part.function_call.args.items()), } replies.append(function_call) + return replies - return {"replies": replies} + def _get_stream_response( + self, stream: Iterable[GenerationResponse], streaming_callback: Callable[[StreamingChunk], None] + ) -> List[str]: + """ + Extracts the responses from the Vertex AI streaming response. + + :param stream: The streaming response from the Vertex AI request. + :param streaming_callback: The handler for the streaming response. + :returns: A list of string responses. + """ + streaming_chunks: List[StreamingChunk] = [] + + for chunk in stream: + streaming_chunk = StreamingChunk(content=chunk.text, meta=chunk.to_dict()) + streaming_chunks.append(streaming_chunk) + streaming_callback(streaming_chunk) + + responses = ["".join(streaming_chunk.content for streaming_chunk in streaming_chunks).lstrip()] + return responses diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py new file mode 100644 index 000000000..0704664c5 --- /dev/null +++ b/integrations/google_vertex/tests/test_gemini.py @@ -0,0 +1,256 @@ +from unittest.mock import MagicMock, Mock, patch + +from haystack.dataclasses import StreamingChunk +from vertexai.preview.generative_models import ( + FunctionDeclaration, + GenerationConfig, + HarmBlockThreshold, + HarmCategory, + Tool, +) + +from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator + +GET_CURRENT_WEATHER_FUNC = FunctionDeclaration( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type_": "OBJECT", + "properties": { + "location": {"type_": "STRING", "description": "The city and state, e.g. 
San Francisco, CA"}, + "unit": { + "type_": "STRING", + "enum": [ + "celsius", + "fahrenheit", + ], + }, + }, + "required": ["location"], + }, +) + + +@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") +def test_init(mock_vertexai_init, _mock_generative_model): + + generation_config = GenerationConfig( + candidate_count=1, + stop_sequences=["stop"], + max_output_tokens=10, + temperature=0.5, + top_p=0.5, + top_k=0.5, + ) + safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + + tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) + + gemini = VertexAIGeminiGenerator( + project_id="TestID123", + location="TestLocation", + generation_config=generation_config, + safety_settings=safety_settings, + tools=[tool], + ) + mock_vertexai_init.assert_called() + assert gemini._model_name == "gemini-pro-vision" + assert gemini._generation_config == generation_config + assert gemini._safety_settings == safety_settings + assert gemini._tools == [tool] + + +@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") +def test_to_dict(_mock_vertexai_init, _mock_generative_model): + + gemini = VertexAIGeminiGenerator( + project_id="TestID123", + ) + assert gemini.to_dict() == { + "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", + "init_parameters": { + "model": "gemini-pro-vision", + "project_id": "TestID123", + "location": None, + "generation_config": None, + "safety_settings": None, + "streaming_callback": None, + "tools": None, + }, + } + + +@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") +def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): + generation_config = GenerationConfig( + candidate_count=1, + stop_sequences=["stop"], + max_output_tokens=10, + temperature=0.5, + top_p=0.5, + top_k=2, + ) + safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + + tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) + + gemini = VertexAIGeminiGenerator( + project_id="TestID123", + generation_config=generation_config, + safety_settings=safety_settings, + tools=[tool], + ) + assert gemini.to_dict() == { + "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", + "init_parameters": { + "model": "gemini-pro-vision", + "project_id": "TestID123", + "location": None, + "generation_config": { + "temperature": 0.5, + "top_p": 0.5, + "top_k": 2.0, + "candidate_count": 1, + "max_output_tokens": 10, + "stop_sequences": ["stop"], + }, + "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, + "streaming_callback": None, + "tools": [ + { + "function_declarations": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type_": "OBJECT", + "properties": { + "location": { + "type_": "STRING", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + ] + } + ], + }, + } + + +@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") +def test_from_dict(_mock_vertexai_init, _mock_generative_model): + gemini = VertexAIGeminiGenerator.from_dict( + { + "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", + "init_parameters": { + "project_id": "TestID123", + "model": "gemini-pro-vision", + "generation_config": None, + "safety_settings": None, + "tools": None, + "streaming_callback": None, + }, + } + ) + + assert gemini._model_name == "gemini-pro-vision" + assert gemini._project_id == "TestID123" + assert gemini._safety_settings is None + assert gemini._tools is None + assert gemini._generation_config is None + + +@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") +def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): + gemini = VertexAIGeminiGenerator.from_dict( + { + "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", + "init_parameters": { + "project_id": "TestID123", + "model": "gemini-pro-vision", + "generation_config": { + "temperature": 0.5, + "top_p": 0.5, + "top_k": 0.5, + "candidate_count": 1, + "max_output_tokens": 10, + "stop_sequences": ["stop"], + }, + "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, + "tools": [ + { + "function_declarations": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type_": "OBJECT", + "properties": { + "location": { + "type_": "STRING", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + ] + } + ], + "streaming_callback": None, + }, + } + ) + + assert gemini._model_name == "gemini-pro-vision" + assert gemini._project_id == "TestID123" + assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]) + assert isinstance(gemini._generation_config, GenerationConfig) + + +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") +def test_run(mock_generative_model): + mock_model = Mock() + mock_model.generate_content.return_value = MagicMock() + mock_generative_model.return_value = mock_model + + gemini = VertexAIGeminiGenerator(project_id="TestID123", location=None) + + response = gemini.run(["What's the weather like today?"]) + + mock_model.generate_content.assert_called_once() + assert "replies" in response + assert isinstance(response["replies"], list) + + +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") +def test_run_with_streaming_callback(mock_generative_model): + mock_model = Mock() + mock_stream = [ + MagicMock(text="First part", usage_metadata={}), + MagicMock(text="Second part", usage_metadata={}), + ] + + mock_model.generate_content.return_value = mock_stream + mock_generative_model.return_value = mock_model + + streaming_callback_called = False + + def streaming_callback(_chunk: StreamingChunk) -> None: + nonlocal streaming_callback_called + streaming_callback_called = True + + gemini = VertexAIGeminiGenerator(model="gemini-pro", project_id="TestID123", streaming_callback=streaming_callback) + gemini.run(["Come on, stream!"]) + assert streaming_callback_called