update anthropic version #25

Merged: 4 commits, Jul 1, 2023

2 changes: 1 addition & 1 deletion llm_client/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.6.1"
__version__ = "0.6.2"

from llm_client.base_llm_client import BaseLLMClient

8 changes: 6 additions & 2 deletions llm_client/llm_api_client/anthropic_client.py
@@ -1,6 +1,6 @@
from typing import Optional

-from anthropic import count_tokens
+from anthropic import AsyncAnthropic

from llm_client.llm_api_client.base_llm_api_client import BaseLLMAPIClient, LLMAPIClientConfig
from llm_client.consts import PROMPT_KEY
@@ -10,6 +10,7 @@
COMPLETIONS_KEY = "completion"
AUTH_HEADER = "x-api-key"
ACCEPT_HEADER = "Accept"
+VERSION_HEADER = "anthropic-version"
ACCEPT_VALUE = "application/json"
MAX_TOKENS_KEY = "max_tokens_to_sample"

@@ -19,6 +20,9 @@ def __init__(self, config: LLMAPIClientConfig):
        super().__init__(config)
        if self._base_url is None:
            self._base_url = BASE_URL
+        self._anthropic = AsyncAnthropic()
+        if self._headers.get(VERSION_HEADER) is None:
+            self._headers[VERSION_HEADER] = self._anthropic.default_headers[VERSION_HEADER]
        self._headers[ACCEPT_HEADER] = ACCEPT_VALUE
        self._headers[AUTH_HEADER] = self._api_key

@@ -40,4 +44,4 @@ async def text_completion(self, prompt: str, model: Optional[str] = None, max_to
        return [response_json[COMPLETIONS_KEY]]

    async def get_tokens_count(self, text: str, **kwargs) -> int:
-        return count_tokens(text)
+        return await self._anthropic.count_tokens(text)
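The heart of this PR is the last hunk: in anthropic >= 0.3 the token counter moved from a module-level function to a method on the client object, which the async client awaits. A minimal standalone sketch of the new call pattern, assuming ANTHROPIC_API_KEY is exported in the environment:

import asyncio

from anthropic import AsyncAnthropic


async def main() -> None:
    # The constructor reads ANTHROPIC_API_KEY from the environment.
    client = AsyncAnthropic()
    # count_tokens is now a client method (awaited here, matching
    # get_tokens_count above); before 0.3 it was a module-level import.
    print(await client.count_tokens("These are a few of my favorite things!"))


asyncio.run(main())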
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -43,7 +43,7 @@ huggingface = [
"transformers >= 4.0.0"
]
anthropic = [
"anthropic >= 0.2.0"
"anthropic >= 0.3.2"
]
google = [
"google-generativeai >= 0.1.0"
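For downstream users this hunk only raises the floor of the optional dependency: reinstalling the extra, for example with pip install --upgrade "llm-client[anthropic]" (assuming the distribution is published under the name llm-client), pulls in an anthropic release that actually ships AsyncAnthropic.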
23 changes: 21 additions & 2 deletions tests/llm_api_client/anthropic_client/conftest.py
@@ -1,7 +1,8 @@
+from unittest.mock import patch, AsyncMock

import pytest

-from llm_client import AnthropicClient
-from llm_client.llm_api_client.anthropic_client import BASE_URL, COMPLETE_PATH
+from llm_client.llm_api_client.anthropic_client import BASE_URL, COMPLETE_PATH, VERSION_HEADER, AnthropicClient
from llm_client.llm_api_client.base_llm_api_client import LLMAPIClientConfig


@@ -23,3 +24,21 @@ def llm_client(config):
@pytest.fixture
def complete_url():
    return BASE_URL + COMPLETE_PATH
+
+
+@pytest.fixture
+def number_of_tokens():
+    return 10
+
+
+@pytest.fixture
+def anthropic_version():
+    return "2023-06-01"
+
+
+@pytest.fixture(autouse=True)
+def mock_anthropic(number_of_tokens, anthropic_version):
+    with patch("llm_client.llm_api_client.anthropic_client.AsyncAnthropic") as mock_anthropic:
+        mock_anthropic.return_value.count_tokens = AsyncMock(return_value=number_of_tokens)
+        mock_anthropic.return_value.default_headers = {VERSION_HEADER: anthropic_version}
+        yield mock_anthropic
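Because the fixture is autouse=True, every test in this directory runs with AsyncAnthropic patched out, and the instance the client constructs in __init__ is the mock's return_value, which is why count_tokens and default_headers are configured on return_value rather than on the mock itself. A standalone sketch of that pattern, with illustrative names rather than repo code:

import asyncio
from unittest.mock import AsyncMock, MagicMock

# Patching a class intercepts calls to it, so the object the code under
# test gets back from AsyncAnthropic() is the mock's return_value.
mock_cls = MagicMock()
mock_cls.return_value.count_tokens = AsyncMock(return_value=10)

instance = mock_cls()  # stands in for AsyncAnthropic() inside the client
assert instance is mock_cls.return_value
assert asyncio.run(instance.count_tokens("any text")) == 10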
52 changes: 37 additions & 15 deletions tests/llm_api_client/anthropic_client/test_anthropic_client.py
@@ -1,11 +1,9 @@
-from unittest.mock import patch
-
import pytest

-from llm_client import LLMAPIClientFactory, LLMAPIClientType, AnthropicClient
+from llm_client import LLMAPIClientFactory, LLMAPIClientType
from llm_client.consts import PROMPT_KEY, MODEL_KEY
from llm_client.llm_api_client.anthropic_client import AUTH_HEADER, COMPLETIONS_KEY, MAX_TOKENS_KEY, ACCEPT_HEADER, \
-    ACCEPT_VALUE
+    ACCEPT_VALUE, VERSION_HEADER, AnthropicClient


@pytest.mark.asyncio
@@ -18,18 +16,41 @@ async def test_get_llm_api_client__with_anthropic(config):


@pytest.mark.asyncio
-async def test_text_completion__sanity(mock_aioresponse, llm_client, complete_url):
+async def test_text_completion__sanity(mock_aioresponse, llm_client, complete_url, anthropic_version):
    mock_aioresponse.post(
        complete_url,
        payload={COMPLETIONS_KEY: "completion text"}
    )

    actual = await llm_client.text_completion(prompt="These are a few of my favorite", max_tokens=10)

    assert actual == ["completion text"]
+    mock_aioresponse.assert_called_once_with(complete_url, method='POST',
+                                             headers={AUTH_HEADER: llm_client._api_key,
+                                                      ACCEPT_HEADER: ACCEPT_VALUE,
+                                                      VERSION_HEADER: anthropic_version},
+                                             json={PROMPT_KEY: 'These are a few of my favorite',
+                                                   MAX_TOKENS_KEY: 10, "temperature": 1,
+                                                   MODEL_KEY: llm_client._default_model},
+                                             raise_for_status=True)


+@pytest.mark.asyncio
+async def test_text_completion__with_version_header(mock_aioresponse, config, complete_url):
+    mock_aioresponse.post(
+        complete_url,
+        payload={COMPLETIONS_KEY: "completion text"}
+    )
+    config.headers[VERSION_HEADER] = "1.0.0"
+    llm_client = AnthropicClient(config)
+
+    actual = await llm_client.text_completion(prompt="These are a few of my favorite", max_tokens=10)
+
+    assert actual == ["completion text"]
    mock_aioresponse.assert_called_once_with(complete_url, method='POST',
                                             headers={AUTH_HEADER: llm_client._api_key,
-                                                     ACCEPT_HEADER: ACCEPT_VALUE},
+                                                     ACCEPT_HEADER: ACCEPT_VALUE,
+                                                     VERSION_HEADER: "1.0.0"},
                                             json={PROMPT_KEY: 'These are a few of my favorite',
                                                   MAX_TOKENS_KEY: 10, "temperature": 1,
                                                   MODEL_KEY: llm_client._default_model},
@@ -43,7 +64,7 @@ async def test_text_completion__without_max_tokens_raise_value_error(mock_aiores


@pytest.mark.asyncio
-async def test_text_completion__override_model(mock_aioresponse, llm_client, complete_url):
+async def test_text_completion__override_model(mock_aioresponse, llm_client, complete_url, anthropic_version):
    new_model_name = "claude-instant"
    mock_aioresponse.post(
        complete_url,
@@ -56,15 +77,16 @@ async def test_text_completion__override_model(mock_aioresponse, llm_client, com
    assert actual == ["completion text"]
    mock_aioresponse.assert_called_once_with(complete_url, method='POST',
                                             headers={AUTH_HEADER: llm_client._api_key,
-                                                     ACCEPT_HEADER: ACCEPT_VALUE},
+                                                     ACCEPT_HEADER: ACCEPT_VALUE,
+                                                     VERSION_HEADER: anthropic_version},
                                             json={PROMPT_KEY: 'These are a few of my favorite',
                                                   MAX_TOKENS_KEY: 10, "temperature": 1,
                                                   MODEL_KEY: new_model_name},
                                             raise_for_status=True)


@pytest.mark.asyncio
-async def test_text_completion__with_kwargs(mock_aioresponse, llm_client, complete_url):
+async def test_text_completion__with_kwargs(mock_aioresponse, llm_client, complete_url, anthropic_version):
    mock_aioresponse.post(
        complete_url,
        payload={COMPLETIONS_KEY: "completion text"}
@@ -75,7 +97,8 @@ async def test_text_completion__with_kwargs(mock_aioresponse, llm_client, comple
    assert actual == ["completion text"]
    mock_aioresponse.assert_called_once_with(complete_url, method='POST',
                                             headers={AUTH_HEADER: llm_client._api_key,
-                                                     ACCEPT_HEADER: ACCEPT_VALUE},
+                                                     ACCEPT_HEADER: ACCEPT_VALUE,
+                                                     VERSION_HEADER: anthropic_version},
                                             json={PROMPT_KEY: 'These are a few of my favorite',
                                                   MAX_TOKENS_KEY: 10,
                                                   MODEL_KEY: llm_client._default_model,
@@ -84,9 +107,8 @@ async def test_text_completion__with_kwargs(mock_aioresponse, llm_client, comple


@pytest.mark.asyncio
-async def test_get_tokens_count__sanity(llm_client):
-    with patch("llm_client.llm_api_client.anthropic_client.count_tokens") as mock_count_tokens:
-        actual = await llm_client.get_tokens_count(text="These are a few of my favorite things!")
+async def test_get_tokens_count__sanity(llm_client, number_of_tokens, mock_anthropic):
+    actual = await llm_client.get_tokens_count(text="These are a few of my favorite things!")

-        assert actual == mock_count_tokens.return_value
-        mock_count_tokens.assert_called_once_with("These are a few of my favorite things!")
+    assert actual == 10
+    mock_anthropic.return_value.count_tokens.assert_awaited_once_with("These are a few of my favorite things!")
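The with_version_header test above also documents the consumer-facing behavior: any anthropic-version value placed on the config's headers wins over the SDK default that __init__ would otherwise copy in. A hedged usage sketch; only the headers attribute of LLMAPIClientConfig appears in this diff, so the api_key and session constructor arguments below are assumptions:

import aiohttp

from llm_client.llm_api_client.anthropic_client import AnthropicClient, VERSION_HEADER
from llm_client.llm_api_client.base_llm_api_client import LLMAPIClientConfig


async def pinned_version_completion() -> list[str]:
    async with aiohttp.ClientSession() as session:
        # api_key and session are assumed constructor parameters.
        config = LLMAPIClientConfig(api_key="my-key", session=session)
        config.headers[VERSION_HEADER] = "2023-06-01"  # skip the SDK-default fallback
        client = AnthropicClient(config)
        return await client.text_completion(prompt="Hello", max_tokens=10)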