From 6b8e74774b41c0ce8f1ee7d0b045c878665e8784 Mon Sep 17 00:00:00 2001 From: Eyal Paz Date: Fri, 14 Jul 2023 09:49:29 +0300 Subject: [PATCH] changes after cr --- llm_client/llm_api_client/ai21_client.py | 2 +- llm_client/llm_api_client/aleph_alpha_client.py | 2 +- llm_client/llm_api_client/anthropic_client.py | 5 +++-- tests/llm_api_client/ai21_client/test_ai21.py | 8 ++++---- tests/resources/openai/chat_completion.json | 4 +--- tests/resources/openai/text_completion.json | 4 +--- 6 files changed, 11 insertions(+), 14 deletions(-) diff --git a/llm_client/llm_api_client/ai21_client.py b/llm_client/llm_api_client/ai21_client.py index ef14861..46617a3 100644 --- a/llm_client/llm_api_client/ai21_client.py +++ b/llm_client/llm_api_client/ai21_client.py @@ -22,7 +22,7 @@ def __init__(self, config: LLMAPIClientConfig): self._headers[AUTH_HEADER] = BEARER_TOKEN + self._api_key async def text_completion(self, prompt: str, model: Optional[str] = None, max_tokens: int = 16, - temperature: float = 0.7, top_p: Optional[float] = None ,**kwargs) -> list[str]: + temperature: float = 0.7, top_p: float = 1 ,**kwargs) -> list[str]: model = model or self._default_model kwargs[PROMPT_KEY] = prompt kwargs["topP"] = kwargs.pop("topP", top_p) diff --git a/llm_client/llm_api_client/aleph_alpha_client.py b/llm_client/llm_api_client/aleph_alpha_client.py index 01138b2..1aaff3e 100644 --- a/llm_client/llm_api_client/aleph_alpha_client.py +++ b/llm_client/llm_api_client/aleph_alpha_client.py @@ -27,7 +27,7 @@ def __init__(self, config: LLMAPIClientConfig): self._headers[AUTH_HEADER] = BEARER_TOKEN + self._api_key async def text_completion(self, prompt: str, model: Optional[str] = None, max_tokens: Optional[int] = None, - temperature: float = 0,top_p: Optional[float] = None, **kwargs) -> \ + temperature: float = 0,top_p: float = 0, **kwargs) -> \ list[str]: self._set_model_in_kwargs(kwargs, model) if max_tokens is None: diff --git a/llm_client/llm_api_client/anthropic_client.py b/llm_client/llm_api_client/anthropic_client.py index 3f41c7d..644a866 100644 --- a/llm_client/llm_api_client/anthropic_client.py +++ b/llm_client/llm_api_client/anthropic_client.py @@ -27,14 +27,15 @@ def __init__(self, config: LLMAPIClientConfig): self._headers[AUTH_HEADER] = self._api_key async def text_completion(self, prompt: str, model: Optional[str] = None, max_tokens: Optional[int] = None, - temperature: float = 1, top_p: float = -1, + temperature: float = 1, top_p: Optional[float] = None, **kwargs) -> \ list[str]: if max_tokens is None and kwargs.get(MAX_TOKENS_KEY) is None: raise ValueError(f"max_tokens or {MAX_TOKENS_KEY} must be specified") + if top_p: + kwargs["top_p"] = top_p self._set_model_in_kwargs(kwargs, model) kwargs[PROMPT_KEY] = prompt - kwargs["top_p"] = top_p kwargs[MAX_TOKENS_KEY] = kwargs.pop(MAX_TOKENS_KEY, max_tokens) kwargs["temperature"] = temperature response = await self._session.post(self._base_url + COMPLETE_PATH, diff --git a/tests/llm_api_client/ai21_client/test_ai21.py b/tests/llm_api_client/ai21_client/test_ai21.py index bf96278..beff1a2 100644 --- a/tests/llm_api_client/ai21_client/test_ai21.py +++ b/tests/llm_api_client/ai21_client/test_ai21.py @@ -30,7 +30,7 @@ async def test_text_completion__sanity(mock_aioresponse, llm_client, url): 'friends, entertaining family...you get the point! One of my favorite things to do is plan parties'] mock_aioresponse.assert_called_once_with(url, method='POST', headers={AUTH_HEADER: BEARER_TOKEN + llm_client._api_key }, - json={'prompt': 'These are a few of my favorite', "maxTokens" : 16, "temperature" : 0.7, "topP" : None }, + json={'prompt': 'These are a few of my favorite', "maxTokens" : 16, "temperature" : 0.7, "topP" : 0 }, raise_for_status=True) @@ -49,7 +49,7 @@ async def test_text_completion__return_multiple_completions(mock_aioresponse, ll ] mock_aioresponse.assert_called_once_with(url, method='POST', headers={AUTH_HEADER: BEARER_TOKEN + llm_client._api_key}, - json={'prompt': 'These are a few of my favorite', "maxTokens" : 16, "temperature" : 0.7, "topP" : None }, + json={'prompt': 'These are a few of my favorite', "maxTokens" : 16, "temperature" : 0.7, "topP" : 0 }, raise_for_status=True) @@ -69,7 +69,7 @@ async def test_text_completion__override_model(mock_aioresponse, llm_client): 'friends, entertaining family...you get the point! One of my favorite things to do is plan parties'] mock_aioresponse.assert_called_once_with(url, method='POST', headers={AUTH_HEADER: BEARER_TOKEN + llm_client._api_key}, - json={'prompt': 'These are a few of my favorite', "maxTokens" : 16, "temperature" : 0.7, "topP" : None }, + json={'prompt': 'These are a few of my favorite', "maxTokens" : 16, "temperature" : 0.7, "topP" : 0 }, raise_for_status=True) @@ -87,7 +87,7 @@ async def test_text_completion__with_kwargs(mock_aioresponse, llm_client, url): 'friends, entertaining family...you get the point! One of my favorite things to do is plan parties'] mock_aioresponse.assert_called_once_with(url, method='POST', headers={AUTH_HEADER: BEARER_TOKEN + llm_client._api_key}, - json={'prompt': 'These are a few of my favorite', "maxTokens" : 10, "temperature" : 0.7 ,"topP" : None}, + json={'prompt': 'These are a few of my favorite', "maxTokens" : 10, "temperature" : 0.7 ,"topP" : 0}, raise_for_status=True) diff --git a/tests/resources/openai/chat_completion.json b/tests/resources/openai/chat_completion.json index 3bbe453..8929ea6 100644 --- a/tests/resources/openai/chat_completion.json +++ b/tests/resources/openai/chat_completion.json @@ -13,7 +13,5 @@ "usage": { "prompt_tokens": 9, "completion_tokens": 12, - "total_tokens": 21, - "top_p" : 1 - } + "total_tokens": 21,} } diff --git a/tests/resources/openai/text_completion.json b/tests/resources/openai/text_completion.json index 38382c2..18a1b3e 100644 --- a/tests/resources/openai/text_completion.json +++ b/tests/resources/openai/text_completion.json @@ -14,7 +14,5 @@ "usage": { "prompt_tokens": 5, "completion_tokens": 7, - "total_tokens": 12, - "top_p" : 1 - } + "total_tokens": 12} }