diff --git a/autogen/oai/gemini.py b/autogen/oai/gemini.py index 33790c9851c..14bfd931715 100644 --- a/autogen/oai/gemini.py +++ b/autogen/oai/gemini.py @@ -32,6 +32,7 @@ from __future__ import annotations import base64 +import datetime import logging import os import random @@ -47,6 +48,7 @@ from google.ai.generativelanguage import Content, Part from google.api_core.exceptions import InternalServerError from google.auth.credentials import Credentials +from google.generativeai import protos from openai.types.chat import ChatCompletion from openai.types.chat.chat_completion import ChatCompletionMessage, Choice from openai.types.completion_usage import CompletionUsage @@ -57,6 +59,7 @@ from vertexai.generative_models import HarmCategory as VertexAIHarmCategory from vertexai.generative_models import Part as VertexAIPart from vertexai.generative_models import SafetySetting as VertexAISafetySetting +from vertexai.preview import caching logger = logging.getLogger(__name__) @@ -129,6 +132,8 @@ def __init__(self, **kwargs): assert ("project_id" not in kwargs) and ( "location" not in kwargs ), "Google Cloud project and compute location cannot be set when using an API Key!" + genai.configure(api_key=self.api_key) + self.context_cache = None def message_retrieval(self, response) -> List: """ @@ -140,6 +145,7 @@ def message_retrieval(self, response) -> List: return [choice.message for choice in response.choices] def cost(self, response) -> float: + # TODO(yeounoh) should use cost calculation function. return response.cost @staticmethod @@ -175,6 +181,8 @@ def create(self, params: Dict) -> ChatCompletion: n_response = params.get("n", 1) system_instruction = params.get("system_instruction", None) response_validation = params.get("response_validation", True) + context_cache = params.get("context_cache", None) + self.context_cache = context_cache # Keep the cache reference used at the creation time generation_config = { gemini_term: params[autogen_term] @@ -195,26 +203,27 @@ def create(self, params: Dict) -> ChatCompletion: if n_response > 1: warnings.warn("Gemini only supports `n=1` for now. We only generate one response.", UserWarning) + gen_model_cls = GenerativeModel if self.use_vertexai else genai.GenerativeModel + if context_cache: + # Context prefix caching can help reduce the cost. + model = gen_model_cls.from_cached_content(cached_content=context_cache) + else: + model = gen_model_cls( + model_name, + generation_config=generation_config, + safety_settings=safety_settings, + system_instruction=system_instruction, + ) + if "vision" not in model_name: # A. create and call the chat model. gemini_messages = self._oai_messages_to_gemini_messages(messages) if self.use_vertexai: - model = GenerativeModel( - model_name, - generation_config=generation_config, - safety_settings=safety_settings, - system_instruction=system_instruction, - ) + # `response_validation=True` (default) sanitizes the chat history by logging + # only valid and complete messages. Blocked messages should be excluded to keep + # the chat session state usable. This is only available in Vertex AI SDK. 
                 chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation)
             else:
-                # we use chat model by default
-                model = genai.GenerativeModel(
-                    model_name,
-                    generation_config=generation_config,
-                    safety_settings=safety_settings,
-                    system_instruction=system_instruction,
-                )
-                genai.configure(api_key=self.api_key)
                 chat = model.start_chat(history=gemini_messages[:-1])
             max_retries = 5
             for attempt in range(max_retries):
@@ -243,22 +252,7 @@ def create(self, params: Dict) -> ChatCompletion:
                 prompt_tokens = model.count_tokens(chat.history[:-1]).total_tokens
                 completion_tokens = model.count_tokens(ans).total_tokens
         elif model_name == "gemini-pro-vision":
-            # B. handle the vision model
-            if self.use_vertexai:
-                model = GenerativeModel(
-                    model_name,
-                    generation_config=generation_config,
-                    safety_settings=safety_settings,
-                    system_instruction=system_instruction,
-                )
-            else:
-                model = genai.GenerativeModel(
-                    model_name,
-                    generation_config=generation_config,
-                    safety_settings=safety_settings,
-                    system_instruction=system_instruction,
-                )
-                genai.configure(api_key=self.api_key)
+            # B. handle the vision model.
             # Gemini's vision model does not support chat history yet
             # chat = model.start_chat(history=gemini_messages[:-1])
             # response = chat.send_message(gemini_messages[-1].parts)
@@ -283,6 +277,7 @@ def create(self, params: Dict) -> ChatCompletion:
         # 3. convert output
         message = ChatCompletionMessage(role="assistant", content=ans, function_call=None, tool_calls=None)
         choices = [Choice(finish_reason="stop", index=0, message=message)]
+        context_cache_tokens = int(self.context_cache.usage_metadata.total_token_count if self.context_cache else 0)

         response_oai = ChatCompletion(
             id=str(random.randint(0, 1000)),
@@ -295,7 +290,9 @@ def create(self, params: Dict) -> ChatCompletion:
                 completion_tokens=completion_tokens,
                 total_tokens=prompt_tokens + completion_tokens,
             ),
-            cost=calculate_gemini_cost(prompt_tokens, completion_tokens, model_name),
+            cost=calculate_gemini_cost(
+                prompt_tokens - context_cache_tokens, completion_tokens, context_cache_tokens, model_name
+            ),
         )

         return response_oai

@@ -438,6 +435,80 @@ def _to_vertexai_safety_settings(safety_settings):
         return safety_settings


+class GeminiContextCache:
+    """
+    Context cache for Gemini models. The semantics of this cache are different from the
+    generic autogen.cache, which caches input prompts and agent outputs. A context cache
+    instead stores common prefix tokens that are sent to Gemini models.
+
+    Context caching helps reduce cost by reusing input tokens that are sent repeatedly.
+    A cache instance is created for a specific publisher model, and the model name is
+    immutable once the cache is created. The created cache has a TTL (1 hour by default),
+    which can be updated after creation. The cost of caching depends on the input token
+    size and how long you want the tokens to persist. Context caching is available for
+    Gemini 1.5 models.
+    """
+
+    def __init__(
+        self,
+        model: str,
+        display_name: str,
+        system_instruction: str,
+        contents: list[str],
+        ttl: datetime.timedelta,
+        use_vertexai=True,
+    ):
+        self.use_vertexai = use_vertexai
+        _caching = caching if use_vertexai else genai.caching
+        # The Vertex AI SDK takes the model via `model_name`; google.generativeai takes `model`.
+        _model_kwarg = {"model_name": model} if use_vertexai else {"model": model}
+        self.cache = _caching.CachedContent.create(
+            display_name=display_name, system_instruction=system_instruction, contents=contents, ttl=ttl, **_model_kwarg
+        )
+
+    def is_compatible(self, model: Union[GenerativeModel, genai.GenerativeModel]) -> bool:
+        """
+        Verify if this cache is compatible with a given model.
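+
+        Returns True only if `model` is a Gemini 1.5 stable version that was created
+        with the same SDK (Vertex AI or google.generativeai) as this cache.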
+        """
+        # Context caching is available for Gemini 1.5 stable versions.
+        # The SDKs may prefix the model name (e.g. `models/...`), so match on the suffix.
+        if not re.search(r"gemini-1\.5-(pro|flash)-\d{3}$", model._model_name):
+            return False
+        if (self.use_vertexai and isinstance(model, GenerativeModel)) or (
+            not self.use_vertexai and isinstance(model, genai.GenerativeModel)
+        ):
+            return True
+        warnings.warn(
+            "Cache was created using a different SDK than the model: "
+            f"use_vertexai={self.use_vertexai}, type(model)={type(model)}"
+        )
+        return False
+
+    def update_ttl(self, ttl: datetime.timedelta):
+        self.cache.update(ttl=ttl)
+
+    def delete(self):
+        self.cache.delete()
+
+    @property
+    def model(self) -> str:
+        # The underlying `CachedContent` exposes these as attributes, not methods.
+        return self.cache.model_name if self.use_vertexai else self.cache.model
+
+    @property
+    def name(self) -> str:
+        return self.cache.name
+
+    @property
+    def display_name(self) -> str:
+        return self.cache.display_name
+
+    @property
+    def usage_metadata(self) -> protos.CachedContent.UsageMetadata:
+        return self.cache.usage_metadata
+
+    @property
+    def expire_time(self) -> datetime.datetime:
+        return self.cache.expire_time
+
+    def __str__(self):
+        return str(self.cache)
+
+
 def _to_pil(data: str) -> Image.Image:
     """
     Converts a base64 encoded image data string to a PIL Image object.
@@ -472,11 +543,12 @@ def get_image_data(image_file: str, use_b64=True) -> bytes:
         return content


-def calculate_gemini_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
+def calculate_gemini_cost(input_tokens: int, output_tokens: int, context_cache_tokens: int, model_name: str) -> float:
+    # TODO(yeounoh) - update the pricing model to reflect the prompt size
     if "1.5" in model_name or "gemini-experimental" in model_name:
         # "gemini-1.5-pro-preview-0409"
-        # Cost is $7 per million input tokens and $21 per million output tokens
-        return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6
+        # Cost is $7 per million input tokens, $21 per million output tokens,
+        # and $1.75 per million cached context tokens.
+        return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6 + 1.75 * context_cache_tokens / 1e6

     if "gemini-pro" not in model_name and "gemini-1.0-pro" not in model_name:
         warnings.warn(f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", UserWarning)
diff --git a/test/oai/test_gemini.py b/test/oai/test_gemini.py
index 61fdbe6d735..3c362b4a599 100644
--- a/test/oai/test_gemini.py
+++ b/test/oai/test_gemini.py
@@ -3,6 +3,8 @@

 import pytest

+from autogen.oai.gemini import calculate_gemini_cost
+
 try:
     import google.auth
     from google.api_core.exceptions import InternalServerError
@@ -12,7 +14,7 @@
     from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
     from vertexai.generative_models import SafetySetting as VertexAISafetySetting

-    from autogen.oai.gemini import GeminiClient
+    from autogen.oai.gemini import GeminiClient, GeminiContextCache

     skip = False
 except ImportError:
@@ -268,15 +270,35 @@ def test_internal_server_error_retry(mock_genai, gemini_client):
 # Test cost calculation
 @pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
 def test_cost_calculation(gemini_client, mock_response):
+    # TODO(yeounoh) - update the test case so that it is more meaningful.
     response = mock_response(
         text="Example response",
         choices=[{"message": "Test message 1"}],
         usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
-        cost=0.01,
+        cost=0.000175,
         model="gemini-pro",
     )
     assert gemini_client.cost(response) > 0, "Cost should be correctly calculated as zero"

+    response_with_cache = mock_response(
+        text="Example response",
+        choices=[{"message": "Test message 1"}],
+        usage={
+            # OpenAI usage stats do not reflect Gemini context caching.
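+            # These usage values are informational only: GeminiClient.cost() returns
+            # the precomputed `cost` set on the mock response.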
+            "prompt_tokens": 10,
+            "completion_tokens": 5,
+            "total_tokens": 15,
+            # context_cache_tokens should offset prompt_tokens and reduce the
+            # total cost during the cost calculation.
+            "context_cache_tokens": 3,
+        },
+        cost=0.00015925,
+        model="gemini-pro",
+    )
+    assert gemini_client.cost(response) > gemini_client.cost(
+        response_with_cache
+    ), "Context caching should reduce the cost."
+

 @pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
 @patch("autogen.oai.gemini.genai.GenerativeModel")
@@ -362,6 +384,52 @@ def test_vertexai_default_auth_create_response(mock_init, mock_generative_model,
     assert response.choices[0].message.content == "Example response", "Response content should match expected output"


+@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
+@patch("autogen.oai.gemini.GenerativeModel")
+@patch("autogen.oai.gemini.vertexai.init")
+def test_vertexai_default_auth_create_response_with_context_cache(
+    mock_init, mock_generative_model, gemini_google_auth_default_client
+):
+    # Mock the Vertex AI model initialization and creation process
+    mock_chat = MagicMock()
+    mock_model = MagicMock()
+    mock_init.return_value = None
+    mock_generative_model.return_value = mock_model
+    mock_model.start_chat.return_value = mock_chat
+
+    # Set up a mock for the chat history item access and the text attribute return
+    mock_history_part = MagicMock()
+    mock_history_part.text = "Example response"
+    mock_chat.history.__getitem__.return_value.parts.__getitem__.return_value = mock_history_part
+
+    # Set up the mock to return a mocked chat response
+    mock_chat.send_message.return_value = MagicMock(history=[MagicMock(parts=[MagicMock(text="Example response")])])
+
+    # Set up the mock to return mocked cache usage metadata
+    mock_context_cache = MagicMock(usage_metadata=MagicMock(total_token_count=10))
+
+    # Call the create method with and without a context cache
+    response = gemini_google_auth_default_client.create(
+        {"model": "gemini-pro", "messages": [{"content": "Hello", "role": "user"}], "stream": False}
+    )
+    response_with_cache = gemini_google_auth_default_client.create(
+        {
+            "model": "gemini-1.5-pro-001",
+            "context_cache": mock_context_cache,
+            "messages": [{"content": "Hello", "role": "user"}],
+            "stream": False,
+        }
+    )
+
+    # Assertions to check if the response is structured as expected
+    assert (
+        response_with_cache.choices[0].message.content == "Example response"
+    ), "Response content should match expected output"
+    assert gemini_google_auth_default_client.cost(response) > gemini_google_auth_default_client.cost(
+        response_with_cache
+    ), "Context caching should result in reduced cost."
+
+
 @pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
 @patch("autogen.oai.gemini.genai.GenerativeModel")
 @patch("autogen.oai.gemini.genai.configure")