Enable gemini context caching #3207

Draft: wants to merge 8 commits into 0.2
138 changes: 105 additions & 33 deletions autogen/oai/gemini.py
@@ -32,6 +32,7 @@
from __future__ import annotations

import base64
import datetime
import logging
import os
import random
@@ -47,6 +48,7 @@
from google.ai.generativelanguage import Content, Part
from google.api_core.exceptions import InternalServerError
from google.auth.credentials import Credentials
from google.generativeai import protos
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
@@ -57,6 +59,7 @@
from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
from vertexai.generative_models import Part as VertexAIPart
from vertexai.generative_models import SafetySetting as VertexAISafetySetting
from vertexai.preview import caching

logger = logging.getLogger(__name__)

@@ -129,6 +132,8 @@ def __init__(self, **kwargs):
assert ("project_id" not in kwargs) and (
"location" not in kwargs
), "Google Cloud project and compute location cannot be set when using an API Key!"
genai.configure(api_key=self.api_key)
self.context_cache = None  # set per create() call when a context cache is passed in params

def message_retrieval(self, response) -> List:
"""
@@ -140,6 +145,7 @@ def message_retrieval(self, response) -> List:
return [choice.message for choice in response.choices]

def cost(self, response) -> float:
# TODO(yeounoh): should use the cost calculation function.
return response.cost

@staticmethod
@@ -175,6 +181,8 @@ def create(self, params: Dict) -> ChatCompletion:
n_response = params.get("n", 1)
system_instruction = params.get("system_instruction", None)
response_validation = params.get("response_validation", True)
context_cache = params.get("context_cache", None)
self.context_cache = context_cache  # keep a reference to the cache used at creation time

generation_config = {
gemini_term: params[autogen_term]
@@ -195,26 +203,27 @@
if n_response > 1:
warnings.warn("Gemini only supports `n=1` for now. We only generate one response.", UserWarning)

gen_model_cls = GenerativeModel if self.use_vertexai else genai.GenerativeModel
if context_cache:
# Context prefix caching can help reduce the cost.
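# Per the SDK, the cached content already pins the model name, system
# instruction, and cached contents, so they are not passed again below.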
model = gen_model_cls.from_cached_content(cached_content=context_cache)
else:
model = gen_model_cls(
model_name,
generation_config=generation_config,
safety_settings=safety_settings,
system_instruction=system_instruction,
)

if "vision" not in model_name:
# A. create and call the chat model.
gemini_messages = self._oai_messages_to_gemini_messages(messages)
if self.use_vertexai:
model = GenerativeModel(
model_name,
generation_config=generation_config,
safety_settings=safety_settings,
system_instruction=system_instruction,
)
# `response_validation=True` (default) sanitizes the chat history by keeping
# only valid and complete messages. Blocked messages should be excluded to keep
# the chat session state usable. This is only available in the Vertex AI SDK.
chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation)
else:
# we use chat model by default
model = genai.GenerativeModel(
model_name,
generation_config=generation_config,
safety_settings=safety_settings,
system_instruction=system_instruction,
)
genai.configure(api_key=self.api_key)
chat = model.start_chat(history=gemini_messages[:-1])
max_retries = 5
for attempt in range(max_retries):
@@ -243,22 +252,7 @@ def create(self, params: Dict) -> ChatCompletion:
prompt_tokens = model.count_tokens(chat.history[:-1]).total_tokens
completion_tokens = model.count_tokens(ans).total_tokens
elif model_name == "gemini-pro-vision":
# B. handle the vision model
if self.use_vertexai:
model = GenerativeModel(
model_name,
generation_config=generation_config,
safety_settings=safety_settings,
system_instruction=system_instruction,
)
else:
model = genai.GenerativeModel(
model_name,
generation_config=generation_config,
safety_settings=safety_settings,
system_instruction=system_instruction,
)
genai.configure(api_key=self.api_key)
# B. handle the vision model.
# Gemini's vision model does not support chat history yet
# chat = model.start_chat(history=gemini_messages[:-1])
# response = chat.send_message(gemini_messages[-1].parts)
@@ -283,6 +277,7 @@ def create(self, params: Dict) -> ChatCompletion:
# 3. convert output
message = ChatCompletionMessage(role="assistant", content=ans, function_call=None, tool_calls=None)
choices = [Choice(finish_reason="stop", index=0, message=message)]
context_cache_tokens = int(self.context_cache.usage_metadata.total_token_count if self.context_cache else 0)
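# Cached prefix tokens are billed at the lower cached-content rate, so they
# are subtracted from the regular prompt tokens in the cost calculation below.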

response_oai = ChatCompletion(
id=str(random.randint(0, 1000)),
@@ -295,7 +290,9 @@
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
cost=calculate_gemini_cost(prompt_tokens, completion_tokens, model_name),
cost=calculate_gemini_cost(
prompt_tokens - context_cache_tokens, completion_tokens, context_cache_tokens, model_name
),
)

return response_oai
@@ -438,6 +435,80 @@ def _to_vertexai_safety_settings(safety_settings):
return safety_settings


class GeminiContextCache:
"""
Context cache for Gemini models. The semantics of this cache differ from the
generic autogen.cache, which stores input prompts and agent outputs. Here, the
context cache stores a common prefix of input tokens for Gemini models.

Context caching helps reduce cost by reusing input tokens that are sent repeatedly
across requests. A cache instance is created against a publisher model, and the model name
is immutable once the cache is created. The cache has a TTL (1 hour by default) that can be updated after creation.
The cost of caching depends on the input token size and how long you want the tokens to persist.
Context caching is available for Gemini 1.5 models.
"""

def __init__(
self,
model: str,
display_name: str,
system_instruction: str,
contents: list[str],
ttl: datetime.timedelta,
use_vertexai=True,
):
self.use_vertexai = use_vertexai
_caching = caching if use_vertexai else genai.caching
self.cache = _caching.CachedContent.create(
model=model, display_name=display_name, system_instruction=system_instruction, contents=contents, ttl=ttl
)

def is_compatible(self, model: Union[GenerativeModel, genai.GenerativeModel]) -> bool:
"""
Verify if this cache is compatible with a given model.
"""
# Context cache is available in gemini 1.5 stable versions.
if re.match(r"^gemini-1\.5-(pro|flash)-\d{3}$", model._model_name):
if (self.use_vertexai and isinstance(model, GenerativeModel)) or (
not self.use_vertexai and isinstance(model, genai.GenerativeModel)
):
return True
warnings.warn(
"Cache was created using a different SDK than the model: "
f"use_vertexai={self.use_vertexai}, type(model)={type(model)}"
)
return False

def update_ttl(self, ttl: datetime.timedelta):
self.cache.update(ttl=ttl)

def delete(self):
self.cache.delete()

@property
def model(self) -> str:
return self.cache.model

@property
def name(self) -> str:
return self.cache.name

@property
def display_name(self) -> str:
return self.cache.display_name

@property
def usage_metadata(self) -> protos.CachedContent.UsageMetadata:
return self.cache.usage_metadata

@property
def expire_time(self) -> datetime.datetime:
return self.cache.expire_time

def __str__(self):
return str(self.cache)


def _to_pil(data: str) -> Image.Image:
"""
Converts a base64 encoded image data string to a PIL Image object.
Expand Down Expand Up @@ -472,11 +543,12 @@ def get_image_data(image_file: str, use_b64=True) -> bytes:
return content


def calculate_gemini_cost(input_tokens: int, output_tokens: int, model_name: str) -> float:
def calculate_gemini_cost(input_tokens: int, output_tokens: int, context_cache_tokens: int, model_name: str) -> float:
# TODO(yeounoh) - update the pricing model to reflect the prompt size
if "1.5" in model_name or "gemini-experimental" in model_name:
# "gemini-1.5-pro-preview-0409"
# Cost is $7 per million input tokens, $21 per million output tokens,
# and $1.75 per million cached context tokens
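# Illustrative example (editor's numbers): 4,000 billable input tokens,
# 1,000 output tokens, and 6,000 cached context tokens cost
#   7.0 * 4000 / 1e6 + 21.0 * 1000 / 1e6 + 1.75 * 6000 / 1e6 = $0.0595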
return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6
return 7.0 * input_tokens / 1e6 + 21.0 * output_tokens / 1e6 + 1.75 * context_cache_tokens / 1e6

if "gemini-pro" not in model_name and "gemini-1.0-pro" not in model_name:
warnings.warn(f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", UserWarning)
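For reference, a minimal usage sketch of the API added in this diff. The project, location, contents, and TTL values are illustrative placeholders; and since the diff's from_cached_content call takes an SDK CachedContent, the sketch passes the wrapper's underlying cache.cache object:

import datetime

from autogen.oai.gemini import GeminiClient, GeminiContextCache

# Build a cache holding the shared prompt prefix (values are illustrative).
cache = GeminiContextCache(
    model="gemini-1.5-pro-001",
    display_name="shared-context",
    system_instruction="You are a helpful assistant.",
    contents=["<large shared document text>"],
    ttl=datetime.timedelta(hours=1),
    use_vertexai=True,
)

# Vertex AI default credentials; project/location values are placeholders.
client = GeminiClient(project_id="my-project", location="us-central1")
response = client.create(
    {
        "model": "gemini-1.5-pro-001",
        "context_cache": cache.cache,  # underlying SDK CachedContent object
        "messages": [{"content": "Summarize the document.", "role": "user"}],
        "stream": False,
    }
)

cache.update_ttl(datetime.timedelta(hours=2))  # extend the lifetime if needed
cache.delete()  # release the cache when done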
72 changes: 70 additions & 2 deletions test/oai/test_gemini.py
@@ -3,6 +3,8 @@

import pytest

from autogen.oai.gemini import calculate_gemini_cost

try:
import google.auth
from google.api_core.exceptions import InternalServerError
@@ -12,7 +14,7 @@
from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
from vertexai.generative_models import SafetySetting as VertexAISafetySetting

from autogen.oai.gemini import GeminiClient
from autogen.oai.gemini import GeminiClient, GeminiContextCache

skip = False
except ImportError:
@@ -268,15 +270,35 @@ def test_internal_server_error_retry(mock_genai, gemini_client):
# Test cost calculation
@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
def test_cost_calculation(gemini_client, mock_response):
# TODO(yeounoh) - update the test case so that it is more meaningful.
response = mock_response(
text="Example response",
choices=[{"message": "Test message 1"}],
usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
cost=0.01,
cost=0.000175,
model="gemini-pro",
)
assert gemini_client.cost(response) > 0, "Cost should be greater than zero"

response_with_cache = mock_response(
text="Example response",
choices=[{"message": "Test message 1"}],
usage={
# OpenAI usage stats do not reflect Gemini context caching.
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15,
# context_cache_tokens should offset prompt_tokens and reduce the
# total cost during the cost calculation.
"context_cache_tokens": 3,
},
cost=0.00015925,
model="gemini-pro",
)
assert gemini_client.cost(response) > gemini_client.cost(
response_with_cache
), "Context caching should reduce the cost."


@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
@patch("autogen.oai.gemini.genai.GenerativeModel")
@@ -362,6 +384,52 @@ def test_vertexai_default_auth_create_response(mock_init, mock_generative_model,
assert response.choices[0].message.content == "Example response", "Response content should match expected output"


@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
@patch("autogen.oai.gemini.GenerativeModel")
@patch("autogen.oai.gemini.vertexai.init")
def test_vertexai_default_auth_create_response_with_context_cache(
mock_init, mock_generative_model, gemini_google_auth_default_client
):
# Mock the genai model configuration and creation process
mock_chat = MagicMock()
mock_model = MagicMock()
mock_init.return_value = None
mock_generative_model.return_value = mock_model
mock_model.start_chat.return_value = mock_chat

# Set up a mock for the chat history item access and the text attribute return
mock_history_part = MagicMock()
mock_history_part.text = "Example response"
mock_chat.history.__getitem__.return_value.parts.__getitem__.return_value = mock_history_part

# Setup the mock to return a mocked chat response
mock_chat.send_message.return_value = MagicMock(history=[MagicMock(parts=[MagicMock(text="Example response")])])

# Setup the mock to return a mocked cache usage
mock_context_cache = MagicMock(usage_metadata=MagicMock(total_token_count=10))

# Call the create method
response = gemini_google_auth_default_client.create(
{"model": "gemini-pro", "messages": [{"content": "Hello", "role": "user"}], "stream": False}
)
response_with_cache = gemini_google_auth_default_client.create(
{
"model": "gemini-1.5-pro-001",
"context_cache": mock_context_cache,
"messages": [{"content": "Hello", "role": "user"}],
"stream": False,
}
)

# Assertions to check if response is structured as expected
assert (
response_with_cache.choices[0].message.content == "Example response"
), "Response content should match expected output"
assert gemini_google_auth_default_client.cost(response) > gemini_google_auth_default_client.cost(
response_with_cache
), "Context caching should result in reduced cost."


@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
@patch("autogen.oai.gemini.genai.GenerativeModel")
@patch("autogen.oai.gemini.genai.configure")