Commit

changes after cr
EyalPazz committed Jul 14, 2023
1 parent 6bbaf86 commit 6b8e747
Showing 6 changed files with 11 additions and 14 deletions.
2 changes: 1 addition & 1 deletion llm_client/llm_api_client/ai21_client.py
@@ -22,7 +22,7 @@ def __init__(self, config: LLMAPIClientConfig):
         self._headers[AUTH_HEADER] = BEARER_TOKEN + self._api_key
 
     async def text_completion(self, prompt: str, model: Optional[str] = None, max_tokens: int = 16,
-                              temperature: float = 0.7, top_p: Optional[float] = None, **kwargs) -> list[str]:
+                              temperature: float = 0.7, top_p: float = 1, **kwargs) -> list[str]:
         model = model or self._default_model
         kwargs[PROMPT_KEY] = prompt
         kwargs["topP"] = kwargs.pop("topP", top_p)
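With this change the AI21 client always sends an explicit topP, using the new signature default of 1 unless the caller overrides it. A minimal standalone sketch of that resolution (resolve_top_p is a hypothetical helper name; the real client does this inline):

# Sketch of the "topP" resolution in AI21Client.text_completion, assuming
# the new default of top_p=1; resolve_top_p is a hypothetical helper name.
def resolve_top_p(kwargs: dict, top_p: float = 1) -> dict:
    # An explicit kwargs["topP"] wins; otherwise the signature default is sent.
    kwargs["topP"] = kwargs.pop("topP", top_p)
    return kwargs

print(resolve_top_p({}))             # {'topP': 1}
print(resolve_top_p({"topP": 0.9}))  # {'topP': 0.9}

Either way the key always ends up in the request body, which is why every expected json payload in the AI21 tests below carries an explicit topP value.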
2 changes: 1 addition & 1 deletion llm_client/llm_api_client/aleph_alpha_client.py
@@ -27,7 +27,7 @@ def __init__(self, config: LLMAPIClientConfig):
         self._headers[AUTH_HEADER] = BEARER_TOKEN + self._api_key
 
     async def text_completion(self, prompt: str, model: Optional[str] = None, max_tokens: Optional[int] = None,
-                              temperature: float = 0, top_p: Optional[float] = None, **kwargs) -> \
+                              temperature: float = 0, top_p: float = 0, **kwargs) -> \
             list[str]:
         self._set_model_in_kwargs(kwargs, model)
         if max_tokens is None:
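This hunk also calls _set_model_in_kwargs, whose body is not part of the diff. A plausible sketch of the shared helper, inferred from the model = model or self._default_model pattern in the AI21 hunk above (the standalone function form and the default model name are assumptions):

from typing import Optional

# Assumed behavior of the shared helper; its real body is not shown in this commit.
def _set_model_in_kwargs(kwargs: dict, model: Optional[str],
                         default_model: str = "luminous-base") -> None:  # model name is hypothetical
    # Fall back to the client's configured default model when none is given.
    kwargs["model"] = model if model is not None else default_model

kwargs = {}
_set_model_in_kwargs(kwargs, None)
print(kwargs)  # {'model': 'luminous-base'}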
5 changes: 3 additions & 2 deletions llm_client/llm_api_client/anthropic_client.py
@@ -27,14 +27,15 @@ def __init__(self, config: LLMAPIClientConfig):
         self._headers[AUTH_HEADER] = self._api_key
 
     async def text_completion(self, prompt: str, model: Optional[str] = None, max_tokens: Optional[int] = None,
-                              temperature: float = 1, top_p: float = -1,
+                              temperature: float = 1, top_p: Optional[float] = None,
                               **kwargs) -> \
             list[str]:
         if max_tokens is None and kwargs.get(MAX_TOKENS_KEY) is None:
             raise ValueError(f"max_tokens or {MAX_TOKENS_KEY} must be specified")
+        if top_p:
+            kwargs["top_p"] = top_p
         self._set_model_in_kwargs(kwargs, model)
         kwargs[PROMPT_KEY] = prompt
-        kwargs["top_p"] = top_p
         kwargs[MAX_TOKENS_KEY] = kwargs.pop(MAX_TOKENS_KEY, max_tokens)
         kwargs["temperature"] = temperature
         response = await self._session.post(self._base_url + COMPLETE_PATH,
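The net effect of this hunk: top_p now defaults to None and is only included in the request body when the caller sets it, instead of always being sent as the -1 sentinel. A standalone sketch of the new payload assembly (build_payload is a hypothetical name, and the value of MAX_TOKENS_KEY is assumed to be "max_tokens_to_sample"; neither is confirmed by this diff):

from typing import Optional

MAX_TOKENS_KEY = "max_tokens_to_sample"  # assumed value of the constant

def build_payload(prompt: str, top_p: Optional[float] = None,
                  temperature: float = 1, max_tokens: int = 100, **kwargs) -> dict:
    if top_p:  # truthiness check, as in the diff: None is skipped
        kwargs["top_p"] = top_p
    kwargs["prompt"] = prompt
    kwargs[MAX_TOKENS_KEY] = kwargs.pop(MAX_TOKENS_KEY, max_tokens)
    kwargs["temperature"] = temperature
    return kwargs

print(build_payload("Hi"))             # no "top_p" key in the body
print(build_payload("Hi", top_p=0.7))  # body includes "top_p": 0.7

Note that because the guard tests truthiness rather than "is not None", an explicit top_p=0 is also dropped from the payload.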
8 changes: 4 additions & 4 deletions tests/llm_api_client/ai21_client/test_ai21.py
@@ -30,7 +30,7 @@ async def test_text_completion__sanity(mock_aioresponse, llm_client, url):
                              'friends, entertaining family...you get the point! One of my favorite things to do is plan parties']
     mock_aioresponse.assert_called_once_with(url, method='POST',
                                              headers={AUTH_HEADER: BEARER_TOKEN + llm_client._api_key},
-                                             json={'prompt': 'These are a few of my favorite', "maxTokens": 16, "temperature": 0.7, "topP": None},
+                                             json={'prompt': 'These are a few of my favorite', "maxTokens": 16, "temperature": 0.7, "topP": 0},
                                              raise_for_status=True)


@@ -49,7 +49,7 @@ async def test_text_completion__return_multiple_completions(mock_aioresponse, ll
                              ]
     mock_aioresponse.assert_called_once_with(url, method='POST',
                                              headers={AUTH_HEADER: BEARER_TOKEN + llm_client._api_key},
-                                             json={'prompt': 'These are a few of my favorite', "maxTokens": 16, "temperature": 0.7, "topP": None},
+                                             json={'prompt': 'These are a few of my favorite', "maxTokens": 16, "temperature": 0.7, "topP": 0},
                                              raise_for_status=True)


@@ -69,7 +69,7 @@ async def test_text_completion__override_model(mock_aioresponse, llm_client):
                              'friends, entertaining family...you get the point! One of my favorite things to do is plan parties']
     mock_aioresponse.assert_called_once_with(url, method='POST',
                                              headers={AUTH_HEADER: BEARER_TOKEN + llm_client._api_key},
-                                             json={'prompt': 'These are a few of my favorite', "maxTokens": 16, "temperature": 0.7, "topP": None},
+                                             json={'prompt': 'These are a few of my favorite', "maxTokens": 16, "temperature": 0.7, "topP": 0},
                                              raise_for_status=True)


@@ -87,7 +87,7 @@ async def test_text_completion__with_kwargs(mock_aioresponse, llm_client, url):
                              'friends, entertaining family...you get the point! One of my favorite things to do is plan parties']
     mock_aioresponse.assert_called_once_with(url, method='POST',
                                              headers={AUTH_HEADER: BEARER_TOKEN + llm_client._api_key},
-                                             json={'prompt': 'These are a few of my favorite', "maxTokens": 10, "temperature": 0.7, "topP": None},
+                                             json={'prompt': 'These are a few of my favorite', "maxTokens": 10, "temperature": 0.7, "topP": 0},
                                              raise_for_status=True)


4 changes: 1 addition & 3 deletions tests/resources/openai/chat_completion.json
@@ -13,7 +13,5 @@
     "usage": {
         "prompt_tokens": 9,
         "completion_tokens": 12,
-        "total_tokens": 21,
-        "top_p": 1
-    }
+        "total_tokens": 21}
 }
4 changes: 1 addition & 3 deletions tests/resources/openai/text_completion.json
@@ -14,7 +14,5 @@
     "usage": {
         "prompt_tokens": 5,
         "completion_tokens": 7,
-        "total_tokens": 12,
-        "top_p": 1
-    }
+        "total_tokens": 12}
 }
