Skip to content

Commit

Permalink
small changes
Browse files Browse the repository at this point in the history
  • Loading branch information
EyalPazz committed Jul 14, 2023
1 parent 8ab999e commit 31c1f19
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
5 changes: 3 additions & 2 deletions llm_client/llm_api_client/google_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@ def __init__(self, config: LLMAPIClientConfig):
self._params = {AUTH_PARAM: self._api_key}

async def text_completion(self, prompt: str, model: Optional[str] = None, max_tokens: Optional[int] = 64,
temperature: Optional[float] = None,top_p: float = 0.95, **kwargs) -> list[str]:
temperature: Optional[float] = None,top_p: Optional[float] = None, **kwargs) -> list[str]:
model = model or self._default_model
kwargs[PROMPT_KEY] = {TEXT_KEY: prompt}
kwargs[MAX_TOKENS_KEY] = kwargs.pop(MAX_TOKENS_KEY, max_tokens)
kwargs["topP"] = kwargs.pop("topP", top_p)
if top_p or ("topP" in kwargs):
kwargs["topP"] = kwargs.pop("topP", top_p)
kwargs["temperature"] = kwargs.pop("temperature", temperature)
response = await self._session.post(self._base_url + model + ":" + COMPLETE_PATH,
params=self._params,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ async def test_text_completion__with_kwargs(mock_aioresponse, llm_client, comple
payload={COMPLETIONS_KEY: "completion text"}
)

actual = await llm_client.text_completion(prompt="These are a few of my favorite", max_tokens=10, temperature=0.5)
actual = await llm_client.text_completion(prompt="These are a few of my favorite", max_tokens=10, temperature=0.5,top_p=0.5)

assert actual == ["completion text"]
mock_aioresponse.assert_called_once_with(complete_url, method='POST',
Expand All @@ -101,7 +101,7 @@ async def test_text_completion__with_kwargs(mock_aioresponse, llm_client, comple
json={PROMPT_KEY: 'These are a few of my favorite',
MAX_TOKENS_KEY: 10,
MODEL_KEY: llm_client._default_model,
"temperature": 0.5},
"temperature": 0.5, "top_p" : 0.5},
raise_for_status=True)


Expand Down
6 changes: 3 additions & 3 deletions tests/llm_api_client/google_client/test_google_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def test_text_completion__sanity(mock_aioresponse, llm_client, params):
'Once upon a time, there was a young boy named Billy...']
mock_aioresponse.assert_called_once_with(url, method='POST', params={AUTH_PARAM: llm_client._api_key},
json={PROMPT_KEY: {TEXT_KEY: 'These are a few of my favorite'},
MAX_TOKENS_KEY: 64,"topP" : 0.95,
MAX_TOKENS_KEY: 64,
'temperature': None},
headers=llm_client._headers,
raise_for_status=True,
Expand All @@ -53,7 +53,7 @@ async def test_text_completion__override_model(mock_aioresponse, llm_client, par
'Once upon a time, there was a young boy named Billy...']
mock_aioresponse.assert_called_once_with(url, method='POST', params={AUTH_PARAM: llm_client._api_key},
json={PROMPT_KEY: {TEXT_KEY: 'These are a few of my favorite'},
MAX_TOKENS_KEY: 64,"topP" : 0.95,
MAX_TOKENS_KEY: 64,
'temperature': None},
headers=llm_client._headers,
raise_for_status=True,
Expand All @@ -68,7 +68,7 @@ async def test_text_completion__with_kwargs(mock_aioresponse, llm_client, params
payload=load_json_resource("google/text_completion.json")
)

actual = await llm_client.text_completion(prompt="These are a few of my favorite", max_tokens=10, blabla="aaa")
actual = await llm_client.text_completion(prompt="These are a few of my favorite", max_tokens=10, blabla="aaa", top_p= 0.95)

assert actual == ['Once upon a time, there was a young girl named Lily...',
'Once upon a time, there was a young boy named Billy...']
Expand Down

0 comments on commit 31c1f19

Please sign in to comment.