From 352d2b877a998c2ea7fb01ee3d863f8d5e678914 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 20 Aug 2024 16:23:19 +0200 Subject: [PATCH 01/17] Add test_gemini.py and enable streaming --- .../generators/google_vertex/gemini.py | 50 ++++- .../google_vertex/tests/test_gemini.py | 189 ++++++++++++++++++ 2 files changed, 233 insertions(+), 6 deletions(-) create mode 100644 integrations/google_vertex/tests/test_gemini.py diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py index 8a288a315..a7441b991 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py @@ -1,11 +1,11 @@ import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import vertexai from haystack.core.component import component from haystack.core.component.types import Variadic from haystack.core.serialization import default_from_dict, default_to_dict -from haystack.dataclasses.byte_stream import ByteStream +from haystack.dataclasses import ByteStream, StreamingChunk from vertexai.preview.generative_models import ( Content, FunctionDeclaration, @@ -60,6 +60,7 @@ def __init__( generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, tools: Optional[List[Tool]] = None, + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ Multi-modal generator using Gemini model via Google Vertex AI. @@ -87,6 +88,8 @@ def __init__( :param tools: List of tools to use when generating content. See the documentation for [Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.preview.generative_models.Tool) the list of supported arguments. + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. """ # Login to GCP. This will fail if user has not set up their gcloud SDK @@ -100,6 +103,7 @@ def __init__( self._generation_config = generation_config self._safety_settings = safety_settings self._tools = tools + self._streaming_callback = streaming_callback def _function_to_dict(self, function: FunctionDeclaration) -> Dict[str, Any]: return { @@ -140,9 +144,10 @@ def to_dict(self) -> Dict[str, Any]: generation_config=self._generation_config, safety_settings=self._safety_settings, tools=self._tools, + streaming_callback=self._streaming_callback, ) if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [self._tool_to_dict(t) for t in tools] + data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) return data @@ -161,7 +166,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiGenerator": data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config) - return default_from_dict(cls, data) def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: @@ -192,10 +196,27 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]): generation_config=self._generation_config, safety_settings=self._safety_settings, tools=self._tools, + stream=self._streaming_callback is not None, ) self._model.start_chat() + replies = ( + self.get_stream_responses(res, self._streaming_callback) + if self._streaming_callback + else self.get_response(res) + ) + + return {"replies": replies} + + def get_response(self, response_body) -> List[str]: + """ + Extracts the responses from the Vertex AI response. + + :param response_body: The response body from the Amazon Bedrock request. + + :returns: A list of string responses. + """ replies = [] - for candidate in res.candidates: + for candidate in response_body.candidates: for part in candidate.content.parts: if part._raw_part.text != "": replies.append(part.text) @@ -205,5 +226,22 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]): "args": dict(part.function_call.args.items()), } replies.append(function_call) + return replies - return {"replies": replies} + def get_stream_responses(self, stream, streaming_callback: Callable[[StreamingChunk], None]) -> List[str]: + """ + Extracts the responses from the Vertex AI streaming response. + + :param stream: The streaming response from the Vertex AI request. + :param streaming_callback: The handler for the streaming response. + :returns: A list of string responses. + """ + streaming_chunks: List[StreamingChunk] = [] + + for chunk in stream: + streaming_chunk = StreamingChunk(content=chunk.text, meta=chunk.usage_metadata) + streaming_chunks.append(streaming_chunk) + streaming_callback(streaming_chunk) + + responses = ["".join(streaming_chunk.content for streaming_chunk in streaming_chunks).lstrip()] + return responses diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py new file mode 100644 index 000000000..71e589a11 --- /dev/null +++ b/integrations/google_vertex/tests/test_gemini.py @@ -0,0 +1,189 @@ +from unittest.mock import MagicMock, Mock, patch + +from vertexai.preview.generative_models import ( + FunctionDeclaration, + GenerationConfig, + GenerativeModel, + HarmBlockThreshold, + HarmCategory, + Tool, +) + +from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator + +GET_CURRENT_WEATHER_FUNC = FunctionDeclaration( + name="get_current_weather", + description="Get the current weather in a given location", + parameters={ + "type_": "OBJECT", + "properties": { + "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, + "unit": { + "type_": "STRING", + "enum": [ + "celsius", + "fahrenheit", + ], + }, + }, + "required": ["location"], + }, +) + + +def test_init(): + + generation_config = GenerationConfig( + candidate_count=1, + stop_sequences=["stop"], + max_output_tokens=10, + temperature=0.5, + top_p=0.5, + top_k=0.5, + ) + safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + + tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) + with patch( + "haystack_integrations.components.generators.google_vertex.gemini.vertexai.init" + ) as mock_genai_configure: + gemini = VertexAIGeminiGenerator( + project_id="TestID123", + location="TestLocation", + generation_config=generation_config, + safety_settings=safety_settings, + tools=[tool], + ) + mock_genai_configure.assert_called() + assert gemini._model_name == "gemini-pro-vision" + assert gemini._generation_config == generation_config + assert gemini._safety_settings == safety_settings + assert gemini._tools == [tool] + assert isinstance(gemini._model, GenerativeModel) + + +def test_to_dict(): + generation_config = GenerationConfig( + candidate_count=1, + stop_sequences=["stop"], + max_output_tokens=10, + temperature=0.5, + top_p=0.5, + top_k=2, + ) + safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + + tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) + + with patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai.init"): + gemini = VertexAIGeminiGenerator( + project_id="TestID123", + generation_config=generation_config, + safety_settings=safety_settings, + tools=[tool], + ) + assert gemini.to_dict() == { + "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", + "init_parameters": { + "model": "gemini-pro-vision", + "project_id": "TestID123", + "location": None, + "generation_config": { + "temperature": 0.5, + "top_p": 0.5, + "top_k": 2.0, + "candidate_count": 1, + "max_output_tokens": 10, + "stop_sequences": ["stop"], + }, + "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, + "streaming_callback": None, + "tools": [ + { + "function_declarations": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type_": "OBJECT", + "properties": { + "location": { + "type_": "STRING", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + ] + } + ], + }, + } + + +def test_from_dict(): + + with patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai.init"): + gemini = VertexAIGeminiGenerator.from_dict( + { + "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", + "init_parameters": { + "project_id": "TestID123", + "model": "gemini-pro-vision", + "generation_config": { + "temperature": 0.5, + "top_p": 0.5, + "top_k": 0.5, + "candidate_count": 1, + "max_output_tokens": 10, + "stop_sequences": ["stop"], + }, + "safety_settings": { + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH + }, + "tools": [ + { + "function_declarations": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type_": "OBJECT", + "properties": { + "location": { + "type_": "STRING", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + ] + } + ], + "streaming_callback": None, + }, + } + ) + + assert gemini._model_name == "gemini-pro-vision" + assert gemini._project_id == "TestID123" + assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} + + # assert gemini._tools == [Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])] + assert isinstance(gemini._generation_config, GenerationConfig) + assert isinstance(gemini._model, GenerativeModel) + + +@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai.init") +@patch("haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator") +def test_run(mock_model_class, _mock_vertexai_init): + mock_model = Mock() + mock_model.predict.return_value = MagicMock() + mock_model_class.from_pretrained.return_value = mock_model + VertexAIGeminiGenerator(model="gemini-pro", project_id="TestID123") + + _mock_vertexai_init.assert_called_once_with(project="TestID123", location=None) From f4831e7a66453145162140d65ba8bbdfb98090e3 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 21 Aug 2024 10:51:50 +0200 Subject: [PATCH 02/17] Fix mock in test file --- .../google_vertex/tests/test_gemini.py | 152 +++++++++--------- 1 file changed, 78 insertions(+), 74 deletions(-) diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index 71e589a11..2f7fa5db9 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -3,7 +3,6 @@ from vertexai.preview.generative_models import ( FunctionDeclaration, GenerationConfig, - GenerativeModel, HarmBlockThreshold, HarmCategory, Tool, @@ -30,8 +29,9 @@ }, ) - -def test_init(): +@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai.init") +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") +def test_init(_mock_vertexai_init, mock_generative_model): generation_config = GenerationConfig( candidate_count=1, @@ -44,25 +44,24 @@ def test_init(): safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) - with patch( - "haystack_integrations.components.generators.google_vertex.gemini.vertexai.init" - ) as mock_genai_configure: - gemini = VertexAIGeminiGenerator( - project_id="TestID123", - location="TestLocation", - generation_config=generation_config, - safety_settings=safety_settings, - tools=[tool], - ) - mock_genai_configure.assert_called() + + gemini = VertexAIGeminiGenerator( + project_id="TestID123", + location="TestLocation", + generation_config=generation_config, + safety_settings=safety_settings, + tools=[tool], + ) + _mock_vertexai_init.assert_called() assert gemini._model_name == "gemini-pro-vision" assert gemini._generation_config == generation_config assert gemini._safety_settings == safety_settings assert gemini._tools == [tool] - assert isinstance(gemini._model, GenerativeModel) +@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai") +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") -def test_to_dict(): +def test_to_dict(_mock_vertexai, _mock_generative_model): generation_config = GenerationConfig( candidate_count=1, stop_sequences=["stop"], @@ -75,13 +74,12 @@ def test_to_dict(): tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) - with patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai.init"): - gemini = VertexAIGeminiGenerator( - project_id="TestID123", - generation_config=generation_config, - safety_settings=safety_settings, - tools=[tool], - ) + gemini = VertexAIGeminiGenerator( + project_id="TestID123", + generation_config=generation_config, + safety_settings=safety_settings, + tools=[tool], + ) assert gemini.to_dict() == { "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", "init_parameters": { @@ -123,51 +121,52 @@ def test_to_dict(): } -def test_from_dict(): - - with patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai.init"): - gemini = VertexAIGeminiGenerator.from_dict( - { - "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", - "init_parameters": { - "project_id": "TestID123", - "model": "gemini-pro-vision", - "generation_config": { - "temperature": 0.5, - "top_p": 0.5, - "top_k": 0.5, - "candidate_count": 1, - "max_output_tokens": 10, - "stop_sequences": ["stop"], - }, - "safety_settings": { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH - }, - "tools": [ - { - "function_declarations": [ - { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type_": "OBJECT", - "properties": { - "location": { - "type_": "STRING", - "description": "The city and state, e.g. San Francisco, CA", - }, - "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, +@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai") +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") + +def test_from_dict(_mock_vertexai, _mock_generative_model): + gemini = VertexAIGeminiGenerator.from_dict( + { + "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", + "init_parameters": { + "project_id": "TestID123", + "model": "gemini-pro-vision", + "generation_config": { + "temperature": 0.5, + "top_p": 0.5, + "top_k": 0.5, + "candidate_count": 1, + "max_output_tokens": 10, + "stop_sequences": ["stop"], + }, + "safety_settings": { + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH + }, + "tools": [ + { + "function_declarations": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type_": "OBJECT", + "properties": { + "location": { + "type_": "STRING", + "description": "The city and state, e.g. San Francisco, CA", }, - "required": ["location"], + "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, }, - } - ] - } - ], - "streaming_callback": None, - }, - } - ) + "required": ["location"], + }, + } + ] + } + ], + "streaming_callback": None, + }, + } + ) assert gemini._model_name == "gemini-pro-vision" assert gemini._project_id == "TestID123" @@ -175,15 +174,20 @@ def test_from_dict(): # assert gemini._tools == [Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])] assert isinstance(gemini._generation_config, GenerationConfig) - assert isinstance(gemini._model, GenerativeModel) -@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai.init") -@patch("haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator") -def test_run(mock_model_class, _mock_vertexai_init): +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") +def test_run(mock_generative_model): mock_model = Mock() - mock_model.predict.return_value = MagicMock() - mock_model_class.from_pretrained.return_value = mock_model - VertexAIGeminiGenerator(model="gemini-pro", project_id="TestID123") + mock_model.generate_content.return_value = MagicMock() + mock_generative_model.return_value = mock_model + + gemini = VertexAIGeminiGenerator(project_id="TestID123", location=None) + + response = gemini.run(["What's the weather like today?"]) + + mock_model.generate_content.assert_called_once() + assert "replies" in response + assert isinstance(response["replies"], list) + - _mock_vertexai_init.assert_called_once_with(project="TestID123", location=None) From b56470798d2a88794459e68d484e0ffc1d87da45 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 21 Aug 2024 12:14:28 +0200 Subject: [PATCH 03/17] Add test for streaming_callback --- .../generators/google_vertex/gemini.py | 4 ++-- .../google_vertex/tests/test_gemini.py | 22 +++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py index a7441b991..6a415b585 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py @@ -200,7 +200,7 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]): ) self._model.start_chat() replies = ( - self.get_stream_responses(res, self._streaming_callback) + self.get_stream_response(res, self._streaming_callback) if self._streaming_callback else self.get_response(res) ) @@ -228,7 +228,7 @@ def get_response(self, response_body) -> List[str]: replies.append(function_call) return replies - def get_stream_responses(self, stream, streaming_callback: Callable[[StreamingChunk], None]) -> List[str]: + def get_stream_response(self, stream, streaming_callback: Callable[[StreamingChunk], None]) -> List[str]: """ Extracts the responses from the Vertex AI streaming response. diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index 2f7fa5db9..354b9d08a 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -1,5 +1,6 @@ from unittest.mock import MagicMock, Mock, patch +from haystack.dataclasses import StreamingChunk from vertexai.preview.generative_models import ( FunctionDeclaration, GenerationConfig, @@ -191,3 +192,24 @@ def test_run(mock_generative_model): assert isinstance(response["replies"], list) +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") +def test_run_with_streaming_callback(mock_generative_model): + mock_model = Mock() + mock_stream = [ + MagicMock(text="First part", usage_metadata={}), + MagicMock(text="Second part", usage_metadata={}), + ] + + mock_model.generate_content.return_value = mock_stream + mock_generative_model.return_value = mock_model + + streaming_callback_called = False + + def streaming_callback(chunk: StreamingChunk) -> None: + nonlocal streaming_callback_called + streaming_callback_called = True + + gemini = VertexAIGeminiGenerator(model="gemini-pro", project_id="TestID123", streaming_callback=streaming_callback) + gemini.run(["Come on, stream!"]) + assert streaming_callback_called + From f8368da47452750865b75e870a82d185625782ed Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 21 Aug 2024 12:28:35 +0200 Subject: [PATCH 04/17] Fix linting --- integrations/google_vertex/tests/test_gemini.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index 354b9d08a..658ea9f8b 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -30,9 +30,10 @@ }, ) + @patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai.init") @patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") -def test_init(_mock_vertexai_init, mock_generative_model): +def test_init(mock_vertexai_init, _mock_generative_model): generation_config = GenerationConfig( candidate_count=1, @@ -53,15 +54,15 @@ def test_init(_mock_vertexai_init, mock_generative_model): safety_settings=safety_settings, tools=[tool], ) - _mock_vertexai_init.assert_called() + mock_vertexai_init.assert_called() assert gemini._model_name == "gemini-pro-vision" assert gemini._generation_config == generation_config assert gemini._safety_settings == safety_settings assert gemini._tools == [tool] + @patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai") @patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") - def test_to_dict(_mock_vertexai, _mock_generative_model): generation_config = GenerationConfig( candidate_count=1, @@ -124,7 +125,6 @@ def test_to_dict(_mock_vertexai, _mock_generative_model): @patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai") @patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") - def test_from_dict(_mock_vertexai, _mock_generative_model): gemini = VertexAIGeminiGenerator.from_dict( { @@ -140,9 +140,7 @@ def test_from_dict(_mock_vertexai, _mock_generative_model): "max_output_tokens": 10, "stop_sequences": ["stop"], }, - "safety_settings": { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH - }, + "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, "tools": [ { "function_declarations": [ @@ -205,11 +203,10 @@ def test_run_with_streaming_callback(mock_generative_model): streaming_callback_called = False - def streaming_callback(chunk: StreamingChunk) -> None: + def streaming_callback(_chunk: StreamingChunk) -> None: nonlocal streaming_callback_called streaming_callback_called = True gemini = VertexAIGeminiGenerator(model="gemini-pro", project_id="TestID123", streaming_callback=streaming_callback) gemini.run(["Come on, stream!"]) assert streaming_callback_called - From e09fa09565d54f8e150f9c1bbba3b5068d9d21a0 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 21 Aug 2024 13:33:36 +0200 Subject: [PATCH 05/17] Remove extra functions --- .../generators/google_vertex/gemini.py | 17 ++--------------- integrations/google_vertex/tests/test_gemini.py | 2 +- 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py index 6a415b585..281e33c67 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py @@ -8,7 +8,6 @@ from haystack.dataclasses import ByteStream, StreamingChunk from vertexai.preview.generative_models import ( Content, - FunctionDeclaration, GenerationConfig, GenerativeModel, HarmBlockThreshold, @@ -105,18 +104,6 @@ def __init__( self._tools = tools self._streaming_callback = streaming_callback - def _function_to_dict(self, function: FunctionDeclaration) -> Dict[str, Any]: - return { - "name": function._raw_function_declaration.name, - "parameters": function._raw_function_declaration.parameters, - "description": function._raw_function_declaration.description, - } - - def _tool_to_dict(self, tool: Tool) -> Dict[str, Any]: - return { - "function_declarations": [self._function_to_dict(f) for f in tool._raw_tool.function_declarations], - } - def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: if isinstance(config, dict): return config @@ -207,7 +194,7 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]): return {"replies": replies} - def get_response(self, response_body) -> List[str]: + def get_response(self, response_body: List[str]) -> List[str]: """ Extracts the responses from the Vertex AI response. @@ -228,7 +215,7 @@ def get_response(self, response_body) -> List[str]: replies.append(function_call) return replies - def get_stream_response(self, stream, streaming_callback: Callable[[StreamingChunk], None]) -> List[str]: + def get_stream_response(self, stream: List[str], streaming_callback: Callable[[StreamingChunk], None]) -> List[str]: """ Extracts the responses from the Vertex AI streaming response. diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index 658ea9f8b..aa4629733 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -171,7 +171,7 @@ def test_from_dict(_mock_vertexai, _mock_generative_model): assert gemini._project_id == "TestID123" assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - # assert gemini._tools == [Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])] + #assert gemini._tools == [Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])] assert isinstance(gemini._generation_config, GenerationConfig) From 9ec32a5441f660926e490db0b95ef20a55ede9b6 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 21 Aug 2024 13:39:01 +0200 Subject: [PATCH 06/17] Update test_gemini.py --- integrations/google_vertex/tests/test_gemini.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index aa4629733..658ea9f8b 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -171,7 +171,7 @@ def test_from_dict(_mock_vertexai, _mock_generative_model): assert gemini._project_id == "TestID123" assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - #assert gemini._tools == [Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])] + # assert gemini._tools == [Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])] assert isinstance(gemini._generation_config, GenerationConfig) From 7acaf5a5b078d5a2b771038675ee0d662e2002f7 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 21 Aug 2024 13:45:04 +0200 Subject: [PATCH 07/17] Type fixing --- .../components/generators/google_vertex/gemini.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py index 281e33c67..7e012e315 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py @@ -194,7 +194,7 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]): return {"replies": replies} - def get_response(self, response_body: List[str]) -> List[str]: + def get_response(self, response_body) -> List[str]: """ Extracts the responses from the Vertex AI response. @@ -215,7 +215,7 @@ def get_response(self, response_body: List[str]) -> List[str]: replies.append(function_call) return replies - def get_stream_response(self, stream: List[str], streaming_callback: Callable[[StreamingChunk], None]) -> List[str]: + def get_stream_response(self, stream, streaming_callback: Callable[[StreamingChunk], None]) -> List[str]: """ Extracts the responses from the Vertex AI streaming response. From 529e99baf2986e0b8e759eff79b18757486670fa Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 21 Aug 2024 17:14:01 +0200 Subject: [PATCH 08/17] Fix small error --- .../components/generators/google_vertex/gemini.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py index 7e012e315..29b136f0f 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py @@ -198,7 +198,7 @@ def get_response(self, response_body) -> List[str]: """ Extracts the responses from the Vertex AI response. - :param response_body: The response body from the Amazon Bedrock request. + :param response_body: The response body from the Vertex AI request. :returns: A list of string responses. """ From be39977d3153742a31e56c00e8360bfe07be0550 Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Wed, 21 Aug 2024 12:00:51 +0000 Subject: [PATCH 09/17] Update the changelog --- integrations/amazon_bedrock/CHANGELOG.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/integrations/amazon_bedrock/CHANGELOG.md b/integrations/amazon_bedrock/CHANGELOG.md index 631978058..417c661fe 100644 --- a/integrations/amazon_bedrock/CHANGELOG.md +++ b/integrations/amazon_bedrock/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## [unreleased] + +### ๐Ÿ› Bug Fixes + +- *(Bedrock)* Allow tools kwargs for AWS Bedrock Claude model (#976) + +### ๐Ÿšœ Refactor + +- Remove usage of deprecated `ChatMessage.to_openai_format` (#1007) + ## [integrations/amazon_bedrock-v1.0.1] - 2024-08-19 ### ๐Ÿš€ Features From a1fbe3a82388069ad253f68a2b3aee337b9816ab Mon Sep 17 00:00:00 2001 From: HaystackBot Date: Wed, 21 Aug 2024 15:46:51 +0000 Subject: [PATCH 10/17] Update the changelog --- integrations/pinecone/CHANGELOG.md | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/integrations/pinecone/CHANGELOG.md b/integrations/pinecone/CHANGELOG.md index 317753192..d9d12505e 100644 --- a/integrations/pinecone/CHANGELOG.md +++ b/integrations/pinecone/CHANGELOG.md @@ -1,5 +1,26 @@ # Changelog +## [unreleased] + +### ๐Ÿš€ Features + +- Add filter_policy to pinecone integration (#821) + +### ๐Ÿ› Bug Fixes + +- `pinecone` - Fallback to default filter policy when deserializing retrievers without the init parameter (#901) +- Skip unsupported meta fields in PineconeDB (#1009) + +### ๐Ÿงช Testing + +- Pinecone - fix `test_serverless_index_creation_from_scratch` (#806) +- Do not retry tests in `hatch run test` command (#954) + +### โš™๏ธ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) + ## [integrations/pinecone-v1.1.0] - 2024-06-11 ### ๐Ÿš€ Features From 42e76bf929de74c8b2952bab7d16ced00ec445dc Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 22 Aug 2024 13:05:57 +0200 Subject: [PATCH 11/17] Add streaming_callback to run --- .../generators/google_vertex/gemini.py | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py index 29b136f0f..5df622e19 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py @@ -6,6 +6,7 @@ from haystack.core.component.types import Variadic from haystack.core.serialization import default_from_dict, default_to_dict from haystack.dataclasses import ByteStream, StreamingChunk +from haystack.utils import deserialize_callable, serialize_callable from vertexai.preview.generative_models import ( Content, GenerationConfig, @@ -123,6 +124,8 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. """ + + callback_name = serialize_callable(self._streaming_callback) if self._streaming_callback else None data = default_to_dict( self, model=self._model_name, @@ -131,7 +134,7 @@ def to_dict(self) -> Dict[str, Any]: generation_config=self._generation_config, safety_settings=self._safety_settings, tools=self._tools, - streaming_callback=self._streaming_callback, + streaming_callback=callback_name, ) if (tools := data["init_parameters"].get("tools")) is not None: data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools] @@ -153,6 +156,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiGenerator": data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config) + if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: @@ -167,14 +172,21 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: raise ValueError(msg) @component.output_types(replies=List[Union[str, Dict[str, str]]]) - def run(self, parts: Variadic[Union[str, ByteStream, Part]]): + def run( + self, + parts: Variadic[Union[str, ByteStream, Part]], + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + ): """ Generates content using the Gemini model. :param parts: Prompt for the model. + :param streaming_callback: A callback function that is called when a new token is received from the stream. :returns: A dictionary with the following keys: - `replies`: A list of generated content. """ + # check if streaming_callback is passed + streaming_callback = streaming_callback or self._streaming_callback converted_parts = [self._convert_part(p) for p in parts] contents = [Content(parts=converted_parts, role="user")] @@ -183,14 +195,10 @@ def run(self, parts: Variadic[Union[str, ByteStream, Part]]): generation_config=self._generation_config, safety_settings=self._safety_settings, tools=self._tools, - stream=self._streaming_callback is not None, + stream=streaming_callback is not None, ) self._model.start_chat() - replies = ( - self.get_stream_response(res, self._streaming_callback) - if self._streaming_callback - else self.get_response(res) - ) + replies = self.get_stream_response(res, streaming_callback) if streaming_callback else self.get_response(res) return {"replies": replies} From d81b9507e2e73d44e302f197ded7ab847b7ce2ce Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Fri, 23 Aug 2024 12:06:10 +0200 Subject: [PATCH 12/17] Revert "Update the changelog" This reverts commit be39977d3153742a31e56c00e8360bfe07be0550. --- integrations/amazon_bedrock/CHANGELOG.md | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/integrations/amazon_bedrock/CHANGELOG.md b/integrations/amazon_bedrock/CHANGELOG.md index 417c661fe..631978058 100644 --- a/integrations/amazon_bedrock/CHANGELOG.md +++ b/integrations/amazon_bedrock/CHANGELOG.md @@ -1,15 +1,5 @@ # Changelog -## [unreleased] - -### ๐Ÿ› Bug Fixes - -- *(Bedrock)* Allow tools kwargs for AWS Bedrock Claude model (#976) - -### ๐Ÿšœ Refactor - -- Remove usage of deprecated `ChatMessage.to_openai_format` (#1007) - ## [integrations/amazon_bedrock-v1.0.1] - 2024-08-19 ### ๐Ÿš€ Features From e5a4696daa3aabf0218b89115011cb04d68d1765 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Fri, 23 Aug 2024 12:08:49 +0200 Subject: [PATCH 13/17] Revert "Update the changelog" This reverts commit a1fbe3a82388069ad253f68a2b3aee337b9816ab. --- integrations/pinecone/CHANGELOG.md | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/integrations/pinecone/CHANGELOG.md b/integrations/pinecone/CHANGELOG.md index d9d12505e..317753192 100644 --- a/integrations/pinecone/CHANGELOG.md +++ b/integrations/pinecone/CHANGELOG.md @@ -1,26 +1,5 @@ # Changelog -## [unreleased] - -### ๐Ÿš€ Features - -- Add filter_policy to pinecone integration (#821) - -### ๐Ÿ› Bug Fixes - -- `pinecone` - Fallback to default filter policy when deserializing retrievers without the init parameter (#901) -- Skip unsupported meta fields in PineconeDB (#1009) - -### ๐Ÿงช Testing - -- Pinecone - fix `test_serverless_index_creation_from_scratch` (#806) -- Do not retry tests in `hatch run test` command (#954) - -### โš™๏ธ Miscellaneous Tasks - -- Retry tests to reduce flakyness (#836) -- Update ruff invocation to include check parameter (#853) - ## [integrations/pinecone-v1.1.0] - 2024-06-11 ### ๐Ÿš€ Features From 39f75bdc70aed4d2058f1b54bb7097ca417d9f61 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Fri, 23 Aug 2024 15:16:01 +0200 Subject: [PATCH 14/17] Updates based on PR review --- .../generators/google_vertex/gemini.py | 8 +-- .../google_vertex/tests/test_gemini.py | 58 ++++++++++++++++--- 2 files changed, 53 insertions(+), 13 deletions(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py index 5df622e19..36346254a 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py @@ -1,13 +1,13 @@ import logging from typing import Any, Callable, Dict, List, Optional, Union -import vertexai from haystack.core.component import component from haystack.core.component.types import Variadic from haystack.core.serialization import default_from_dict, default_to_dict from haystack.dataclasses import ByteStream, StreamingChunk from haystack.utils import deserialize_callable, serialize_callable -from vertexai.preview.generative_models import ( +from vertexai import init as vertexai_init +from vertexai.generative_models import ( Content, GenerationConfig, GenerativeModel, @@ -93,7 +93,7 @@ def __init__( """ # Login to GCP. This will fail if user has not set up their gcloud SDK - vertexai.init(project=project_id, location=location) + vertexai_init(project=project_id, location=location) self._model_name = model self._project_id = project_id @@ -234,7 +234,7 @@ def get_stream_response(self, stream, streaming_callback: Callable[[StreamingChu streaming_chunks: List[StreamingChunk] = [] for chunk in stream: - streaming_chunk = StreamingChunk(content=chunk.text, meta=chunk.usage_metadata) + streaming_chunk = StreamingChunk(content=chunk.text, meta=chunk.to_dict()) streaming_chunks.append(streaming_chunk) streaming_callback(streaming_chunk) diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index 658ea9f8b..7b8bb50e4 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -31,9 +31,8 @@ ) -@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai.init") -@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") -def test_init(mock_vertexai_init, _mock_generative_model): +@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") +def test_init(mock_vertexai_init): generation_config = GenerationConfig( candidate_count=1, @@ -61,9 +60,28 @@ def test_init(mock_vertexai_init, _mock_generative_model): assert gemini._tools == [tool] -@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai") -@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") -def test_to_dict(_mock_vertexai, _mock_generative_model): +@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") +def test_to_dict(_mock_vertexai_init): + + gemini = VertexAIGeminiGenerator( + project_id="TestID123", + ) + assert gemini.to_dict() == { + "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", + "init_parameters": { + "model": "gemini-pro-vision", + "project_id": "TestID123", + "location": None, + "generation_config": None, + "safety_settings": None, + "streaming_callback": None, + "tools": None, + }, + } + + +@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") +def test_to_dict_with_params(_mock_vertexai_init): generation_config = GenerationConfig( candidate_count=1, stop_sequences=["stop"], @@ -123,9 +141,31 @@ def test_to_dict(_mock_vertexai, _mock_generative_model): } -@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai") -@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") -def test_from_dict(_mock_vertexai, _mock_generative_model): +@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") +def test_from_dict(_mock_vertexai_init): + gemini = VertexAIGeminiGenerator.from_dict( + { + "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", + "init_parameters": { + "project_id": "TestID123", + "model": "gemini-pro-vision", + "generation_config": None, + "safety_settings": None, + "tools": None, + "streaming_callback": None, + }, + } + ) + + assert gemini._model_name == "gemini-pro-vision" + assert gemini._project_id == "TestID123" + assert gemini._safety_settings is None + assert gemini._tools is None + assert gemini._generation_config is None + + +@patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") +def test_from_dict_with_param(_mock_vertexai_init): gemini = VertexAIGeminiGenerator.from_dict( { "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", From 3a951f3944ba217cf1b55f940207b312ca1dc75e Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Fri, 23 Aug 2024 15:59:53 +0200 Subject: [PATCH 15/17] Fix for tests --- integrations/google_vertex/tests/test_gemini.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index 7b8bb50e4..affed1a8c 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -32,7 +32,8 @@ @patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") -def test_init(mock_vertexai_init): +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") +def test_init(mock_vertexai_init, _mock_generative_model): generation_config = GenerationConfig( candidate_count=1, @@ -61,7 +62,8 @@ def test_init(mock_vertexai_init): @patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") -def test_to_dict(_mock_vertexai_init): +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") +def test_to_dict(_mock_vertexai_init, _mock_generative_model): gemini = VertexAIGeminiGenerator( project_id="TestID123", @@ -81,7 +83,8 @@ def test_to_dict(_mock_vertexai_init): @patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") -def test_to_dict_with_params(_mock_vertexai_init): +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") +def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): generation_config = GenerationConfig( candidate_count=1, stop_sequences=["stop"], @@ -142,7 +145,8 @@ def test_to_dict_with_params(_mock_vertexai_init): @patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") -def test_from_dict(_mock_vertexai_init): +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") +def test_from_dict(_mock_vertexai_init, _mock_generative_model): gemini = VertexAIGeminiGenerator.from_dict( { "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", @@ -165,7 +169,8 @@ def test_from_dict(_mock_vertexai_init): @patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") -def test_from_dict_with_param(_mock_vertexai_init): +@patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") +def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): gemini = VertexAIGeminiGenerator.from_dict( { "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", From 19f7db43066998fd919f50f3e521d675f351577f Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Mon, 26 Aug 2024 01:00:30 +0200 Subject: [PATCH 16/17] Fix assertion --- integrations/google_vertex/tests/test_gemini.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index affed1a8c..0704664c5 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -215,8 +215,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): assert gemini._model_name == "gemini-pro-vision" assert gemini._project_id == "TestID123" assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - - # assert gemini._tools == [Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])] + assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]) assert isinstance(gemini._generation_config, GenerationConfig) From 8359461fe7696bbb0101abfe96ac191ce6b11957 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 27 Aug 2024 12:48:16 +0200 Subject: [PATCH 17/17] Annotate param types --- .../components/generators/google_vertex/gemini.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py index 36346254a..ca03a83b1 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Union from haystack.core.component import component from haystack.core.component.types import Variadic @@ -10,6 +10,7 @@ from vertexai.generative_models import ( Content, GenerationConfig, + GenerationResponse, GenerativeModel, HarmBlockThreshold, HarmCategory, @@ -198,11 +199,11 @@ def run( stream=streaming_callback is not None, ) self._model.start_chat() - replies = self.get_stream_response(res, streaming_callback) if streaming_callback else self.get_response(res) + replies = self._get_stream_response(res, streaming_callback) if streaming_callback else self._get_response(res) return {"replies": replies} - def get_response(self, response_body) -> List[str]: + def _get_response(self, response_body: GenerationResponse) -> List[str]: """ Extracts the responses from the Vertex AI response. @@ -223,7 +224,9 @@ def get_response(self, response_body) -> List[str]: replies.append(function_call) return replies - def get_stream_response(self, stream, streaming_callback: Callable[[StreamingChunk], None]) -> List[str]: + def _get_stream_response( + self, stream: Iterable[GenerationResponse], streaming_callback: Callable[[StreamingChunk], None] + ) -> List[str]: """ Extracts the responses from the Vertex AI streaming response.