Skip to content

Commit

Permalink
update anthropic version (uripeled2#25)
Browse files Browse the repository at this point in the history
* update anthropic version

* fix import in test_anthropic_client.py

* fix import in test_anthropic_client.py

* Update antrofic version
  • Loading branch information
uripeled2 authored Jul 1, 2023
1 parent f417c5b commit 50bc2c8
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 21 deletions.
2 changes: 1 addition & 1 deletion llm_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.6.1"
__version__ = "0.6.2"

from llm_client.base_llm_client import BaseLLMClient

Expand Down
8 changes: 6 additions & 2 deletions llm_client/llm_api_client/anthropic_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional

from anthropic import count_tokens
from anthropic import AsyncAnthropic

from llm_client.llm_api_client.base_llm_api_client import BaseLLMAPIClient, LLMAPIClientConfig
from llm_client.consts import PROMPT_KEY
Expand All @@ -10,6 +10,7 @@
COMPLETIONS_KEY = "completion"
AUTH_HEADER = "x-api-key"
ACCEPT_HEADER = "Accept"
VERSION_HEADER = "anthropic-version"
ACCEPT_VALUE = "application/json"
MAX_TOKENS_KEY = "max_tokens_to_sample"

Expand All @@ -19,6 +20,9 @@ def __init__(self, config: LLMAPIClientConfig):
super().__init__(config)
if self._base_url is None:
self._base_url = BASE_URL
self._anthropic = AsyncAnthropic()
if self._headers.get(VERSION_HEADER) is None:
self._headers[VERSION_HEADER] = self._anthropic.default_headers[VERSION_HEADER]
self._headers[ACCEPT_HEADER] = ACCEPT_VALUE
self._headers[AUTH_HEADER] = self._api_key

Expand All @@ -40,4 +44,4 @@ async def text_completion(self, prompt: str, model: Optional[str] = None, max_to
return [response_json[COMPLETIONS_KEY]]

async def get_tokens_count(self, text: str, **kwargs) -> int:
return count_tokens(text)
return await self._anthropic.count_tokens(text)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ huggingface = [
"transformers >= 4.0.0"
]
anthropic = [
"anthropic >= 0.2.0"
"anthropic >= 0.3.2"
]
google = [
"google-generativeai >= 0.1.0"
Expand Down
23 changes: 21 additions & 2 deletions tests/llm_api_client/anthropic_client/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from unittest.mock import patch, AsyncMock

import pytest

from llm_client import AnthropicClient
from llm_client.llm_api_client.anthropic_client import BASE_URL, COMPLETE_PATH
from llm_client.llm_api_client.anthropic_client import BASE_URL, COMPLETE_PATH, VERSION_HEADER, AnthropicClient
from llm_client.llm_api_client.base_llm_api_client import LLMAPIClientConfig


Expand All @@ -23,3 +24,21 @@ def llm_client(config):
@pytest.fixture
def complete_url():
return BASE_URL + COMPLETE_PATH


@pytest.fixture
def number_of_tokens():
return 10


@pytest.fixture
def anthropic_version():
return "2023-06-01"


@pytest.fixture(autouse=True)
def mock_anthropic(number_of_tokens, anthropic_version):
with patch("llm_client.llm_api_client.anthropic_client.AsyncAnthropic") as mock_anthropic:
mock_anthropic.return_value.count_tokens = AsyncMock(return_value=number_of_tokens)
mock_anthropic.return_value.default_headers = {VERSION_HEADER: anthropic_version}
yield mock_anthropic
52 changes: 37 additions & 15 deletions tests/llm_api_client/anthropic_client/test_anthropic_client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from unittest.mock import patch

import pytest

from llm_client import LLMAPIClientFactory, LLMAPIClientType, AnthropicClient
from llm_client import LLMAPIClientFactory, LLMAPIClientType
from llm_client.consts import PROMPT_KEY, MODEL_KEY
from llm_client.llm_api_client.anthropic_client import AUTH_HEADER, COMPLETIONS_KEY, MAX_TOKENS_KEY, ACCEPT_HEADER, \
ACCEPT_VALUE
ACCEPT_VALUE, VERSION_HEADER, AnthropicClient


@pytest.mark.asyncio
Expand All @@ -18,18 +16,41 @@ async def test_get_llm_api_client__with_anthropic(config):


@pytest.mark.asyncio
async def test_text_completion__sanity(mock_aioresponse, llm_client, complete_url):
async def test_text_completion__sanity(mock_aioresponse, llm_client, complete_url, anthropic_version):
mock_aioresponse.post(
complete_url,
payload={COMPLETIONS_KEY: "completion text"}
)

actual = await llm_client.text_completion(prompt="These are a few of my favorite", max_tokens=10)

assert actual == ["completion text"]
mock_aioresponse.assert_called_once_with(complete_url, method='POST',
headers={AUTH_HEADER: llm_client._api_key,
ACCEPT_HEADER: ACCEPT_VALUE,
VERSION_HEADER: anthropic_version},
json={PROMPT_KEY: 'These are a few of my favorite',
MAX_TOKENS_KEY: 10, "temperature": 1,
MODEL_KEY: llm_client._default_model},
raise_for_status=True)


@pytest.mark.asyncio
async def test_text_completion__with_version_header(mock_aioresponse, config, complete_url):
mock_aioresponse.post(
complete_url,
payload={COMPLETIONS_KEY: "completion text"}
)
config.headers[VERSION_HEADER] = "1.0.0"
llm_client = AnthropicClient(config)

actual = await llm_client.text_completion(prompt="These are a few of my favorite", max_tokens=10)

assert actual == ["completion text"]
mock_aioresponse.assert_called_once_with(complete_url, method='POST',
headers={AUTH_HEADER: llm_client._api_key,
ACCEPT_HEADER: ACCEPT_VALUE},
ACCEPT_HEADER: ACCEPT_VALUE,
VERSION_HEADER: "1.0.0"},
json={PROMPT_KEY: 'These are a few of my favorite',
MAX_TOKENS_KEY: 10, "temperature": 1,
MODEL_KEY: llm_client._default_model},
Expand All @@ -43,7 +64,7 @@ async def test_text_completion__without_max_tokens_raise_value_error(mock_aiores


@pytest.mark.asyncio
async def test_text_completion__override_model(mock_aioresponse, llm_client, complete_url):
async def test_text_completion__override_model(mock_aioresponse, llm_client, complete_url, anthropic_version):
new_model_name = "claude-instant"
mock_aioresponse.post(
complete_url,
Expand All @@ -56,15 +77,16 @@ async def test_text_completion__override_model(mock_aioresponse, llm_client, com
assert actual == ["completion text"]
mock_aioresponse.assert_called_once_with(complete_url, method='POST',
headers={AUTH_HEADER: llm_client._api_key,
ACCEPT_HEADER: ACCEPT_VALUE},
ACCEPT_HEADER: ACCEPT_VALUE,
VERSION_HEADER: anthropic_version},
json={PROMPT_KEY: 'These are a few of my favorite',
MAX_TOKENS_KEY: 10, "temperature": 1,
MODEL_KEY: new_model_name},
raise_for_status=True)


@pytest.mark.asyncio
async def test_text_completion__with_kwargs(mock_aioresponse, llm_client, complete_url):
async def test_text_completion__with_kwargs(mock_aioresponse, llm_client, complete_url, anthropic_version):
mock_aioresponse.post(
complete_url,
payload={COMPLETIONS_KEY: "completion text"}
Expand All @@ -75,7 +97,8 @@ async def test_text_completion__with_kwargs(mock_aioresponse, llm_client, comple
assert actual == ["completion text"]
mock_aioresponse.assert_called_once_with(complete_url, method='POST',
headers={AUTH_HEADER: llm_client._api_key,
ACCEPT_HEADER: ACCEPT_VALUE},
ACCEPT_HEADER: ACCEPT_VALUE,
VERSION_HEADER: anthropic_version},
json={PROMPT_KEY: 'These are a few of my favorite',
MAX_TOKENS_KEY: 10,
MODEL_KEY: llm_client._default_model,
Expand All @@ -84,9 +107,8 @@ async def test_text_completion__with_kwargs(mock_aioresponse, llm_client, comple


@pytest.mark.asyncio
async def test_get_tokens_count__sanity(llm_client):
with patch("llm_client.llm_api_client.anthropic_client.count_tokens") as mock_count_tokens:
actual = await llm_client.get_tokens_count(text="These are a few of my favorite things!")
async def test_get_tokens_count__sanity(llm_client, number_of_tokens, mock_anthropic):
actual = await llm_client.get_tokens_count(text="These are a few of my favorite things!")

assert actual == mock_count_tokens.return_value
mock_count_tokens.assert_called_once_with("These are a few of my favorite things!")
assert actual == 10
mock_anthropic.return_value.count_tokens.assert_awaited_once_with("These are a few of my favorite things!")

0 comments on commit 50bc2c8

Please sign in to comment.