Commit

Add another unit test
yeounoh committed Jul 25, 2024
1 parent 491b67e commit d38df67
Showing 2 changed files with 53 additions and 3 deletions.
7 changes: 4 additions & 3 deletions autogen/oai/gemini.py
@@ -209,15 +209,16 @@ def create(self, params: Dict) -> ChatCompletion:
 
         gen_model_cls = GenerativeModel if self.use_vertexai else genai.GenerativeModel
         if context_cache:
+            # Context prefix caching can help reduce the cost.
+            model = gen_model_cls.from_cached_content(cached_content=context_cache)
+        else:
             model = gen_model_cls(
                 model_name,
                 generation_config=generation_config,
                 safety_settings=safety_settings,
                 system_instruction=system_instruction,
             )
-        else:
-            # Context prefix caching can help reduce the cost.
-            model = gen_model_cls.from_cached_content(cached_content=context_cache)
+
 
         if "vision" not in model_name:
             # A. create and call the chat model.
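For reference, the calling convention this branch enables is the one the new unit test below exercises: if the params dict passed to create() carries a "context_cache" entry, the model is built from the cached content instead of a fresh GenerativeModel. A minimal sketch follows; the client construction and the cache stand-in are illustrative assumptions (the test itself injects a MagicMock), not part of this commit:

from unittest.mock import MagicMock

from autogen.oai.gemini import GeminiClient

# Hypothetical client setup; real credentials/config are not shown in this commit.
client = GeminiClient(api_key="...")

# Stand-in for a cached-content handle; the new unit test uses a MagicMock as well.
context_cache = MagicMock(usage_metadata=MagicMock(total_token_count=10))

# Passing "context_cache" makes create() build the model via from_cached_content().
response = client.create(
    {
        "model": "gemini-1.5-pro-001",
        "context_cache": context_cache,
        "messages": [{"content": "Hello", "role": "user"}],
        "stream": False,
    }
)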
49 changes: 49 additions & 0 deletions test/oai/test_gemini.py
@@ -1,6 +1,7 @@
import os
from unittest.mock import MagicMock, patch

from autogen.oai.gemini import calculate_gemini_cost
import pytest

try:
@@ -13,6 +14,7 @@
    from vertexai.generative_models import SafetySetting as VertexAISafetySetting

    from autogen.oai.gemini import GeminiClient
    from autogen.oai.gemini import GeminiContextCache

    skip = False
except ImportError:
@@ -382,6 +384,53 @@ def test_vertexai_default_auth_create_response(mock_init, mock_generative_model,
    # Assertions to check if response is structured as expected
    assert response.choices[0].message.content == "Example response", "Response content should match expected output"


@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
@patch("autogen.oai.gemini.GenerativeModel")
@patch("autogen.oai.gemini.vertexai.init")
def test_vertexai_default_auth_create_response_with_context_cache(
    mock_init, mock_generative_model, gemini_google_auth_default_client
):
    # Mock the genai model configuration and creation process
    mock_chat = MagicMock()
    mock_model = MagicMock()
    mock_init.return_value = None
    mock_generative_model.return_value = mock_model
    mock_model.start_chat.return_value = mock_chat

    # Set up a mock for the chat history item access and the text attribute return
    mock_history_part = MagicMock()
    mock_history_part.text = "Example response"
    mock_chat.history.__getitem__.return_value.parts.__getitem__.return_value = mock_history_part

    # Set up the mock to return a mocked chat response
    mock_chat.send_message.return_value = MagicMock(history=[MagicMock(parts=[MagicMock(text="Example response")])])

    # Set up the mock to return a mocked cache usage
    mock_context_cache = MagicMock(usage_metadata=MagicMock(total_token_count=10))

    # Call the create method, once without and once with the context cache
    response = gemini_google_auth_default_client.create(
        {"model": "gemini-pro", "messages": [{"content": "Hello", "role": "user"}], "stream": False}
    )
    response_with_cache = gemini_google_auth_default_client.create(
        {
            "model": "gemini-1.5-pro-001",
            "context_cache": mock_context_cache,
            "messages": [{"content": "Hello", "role": "user"}],
            "stream": False,
        }
    )

    # Assertions to check if response is structured as expected
    assert (
        response_with_cache.choices[0].message.content == "Example response"
    ), "Response content should match expected output"
    assert gemini_google_auth_default_client.cost(response) > gemini_google_auth_default_client.cost(
        response_with_cache
    ), "Context caching should result in reduced cost."


@pytest.mark.skipif(skip, reason="Google GenAI dependency is not installed")
@patch("autogen.oai.gemini.genai.GenerativeModel")