Skip to content

Commit

Permalink
Added top_p for Google client and BaseLLMAPIClient
Browse files Browse the repository at this point in the history
  • Loading branch information
EyalPazz committed Jul 1, 2023
1 parent 432b155 commit 6bbaf86
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion llm_client/llm_api_client/base_llm_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, config: LLMAPIClientConfig):

@abstractmethod
async def text_completion(self, prompt: str, model: Optional[str] = None, max_tokens: Optional[int] = None,
temperature: Optional[float] = None, **kwargs) -> list[str]:
temperature: Optional[float] = None,top_p : Optional[float] = None, **kwargs) -> list[str]:
raise NotImplementedError()

async def embedding(self, text: str, model: Optional[str] = None, **kwargs) -> list[float]:
Expand Down
3 changes: 2 additions & 1 deletion llm_client/llm_api_client/google_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ def __init__(self, config: LLMAPIClientConfig):
self._params = {AUTH_PARAM: self._api_key}

async def text_completion(self, prompt: str, model: Optional[str] = None, max_tokens: Optional[int] = 64,
temperature: Optional[float] = None, **kwargs) -> list[str]:
temperature: Optional[float] = None,top_p: float = 0.95, **kwargs) -> list[str]:
model = model or self._default_model
kwargs[PROMPT_KEY] = {TEXT_KEY: prompt}
kwargs[MAX_TOKENS_KEY] = kwargs.pop(MAX_TOKENS_KEY, max_tokens)
kwargs["topP"] = kwargs.pop("topP", top_p)
kwargs["temperature"] = kwargs.pop("temperature", temperature)
response = await self._session.post(self._base_url + model + ":" + COMPLETE_PATH,
params=self._params,
Expand Down
6 changes: 3 additions & 3 deletions tests/llm_api_client/google_client/test_google_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def test_text_completion__sanity(mock_aioresponse, llm_client, params):
'Once upon a time, there was a young boy named Billy...']
mock_aioresponse.assert_called_once_with(url, method='POST', params={AUTH_PARAM: llm_client._api_key},
json={PROMPT_KEY: {TEXT_KEY: 'These are a few of my favorite'},
MAX_TOKENS_KEY: 64,
MAX_TOKENS_KEY: 64,"topP" : 0.95,
'temperature': None},
headers=llm_client._headers,
raise_for_status=True,
Expand All @@ -53,7 +53,7 @@ async def test_text_completion__override_model(mock_aioresponse, llm_client, par
'Once upon a time, there was a young boy named Billy...']
mock_aioresponse.assert_called_once_with(url, method='POST', params={AUTH_PARAM: llm_client._api_key},
json={PROMPT_KEY: {TEXT_KEY: 'These are a few of my favorite'},
MAX_TOKENS_KEY: 64,
MAX_TOKENS_KEY: 64,"topP" : 0.95,
'temperature': None},
headers=llm_client._headers,
raise_for_status=True,
Expand All @@ -76,7 +76,7 @@ async def test_text_completion__with_kwargs(mock_aioresponse, llm_client, params
json={PROMPT_KEY: {TEXT_KEY: 'These are a few of my favorite'},
MAX_TOKENS_KEY: 10,
'temperature': None,
'blabla': 'aaa'},
'blabla': 'aaa',"topP" : 0.95},
headers=llm_client._headers,
raise_for_status=True,
)
Expand Down

0 comments on commit 6bbaf86

Please sign in to comment.