diff --git a/llm_client/llm_api_client/base_llm_api_client.py b/llm_client/llm_api_client/base_llm_api_client.py
index dc092d8..3881d94 100644
--- a/llm_client/llm_api_client/base_llm_api_client.py
+++ b/llm_client/llm_api_client/base_llm_api_client.py
@@ -30,7 +30,7 @@ def __init__(self, config: LLMAPIClientConfig):
 
     @abstractmethod
     async def text_completion(self, prompt: str, model: Optional[str] = None, max_tokens: Optional[int] = None,
-                              temperature: Optional[float] = None, **kwargs) -> list[str]:
+                              temperature: Optional[float] = None, top_p: Optional[float] = None, **kwargs) -> list[str]:
         raise NotImplementedError()
 
     async def embedding(self, text: str, model: Optional[str] = None, **kwargs) -> list[float]:
diff --git a/llm_client/llm_api_client/google_client.py b/llm_client/llm_api_client/google_client.py
index 8a93f79..0fda2d6 100644
--- a/llm_client/llm_api_client/google_client.py
+++ b/llm_client/llm_api_client/google_client.py
@@ -33,10 +33,11 @@ def __init__(self, config: LLMAPIClientConfig):
         self._params = {AUTH_PARAM: self._api_key}
 
     async def text_completion(self, prompt: str, model: Optional[str] = None, max_tokens: Optional[int] = 64,
-                              temperature: Optional[float] = None, **kwargs) -> list[str]:
+                              temperature: Optional[float] = None, top_p: float = 0.95, **kwargs) -> list[str]:
         model = model or self._default_model
         kwargs[PROMPT_KEY] = {TEXT_KEY: prompt}
         kwargs[MAX_TOKENS_KEY] = kwargs.pop(MAX_TOKENS_KEY, max_tokens)
+        kwargs["topP"] = kwargs.pop("topP", top_p)
         kwargs["temperature"] = kwargs.pop("temperature", temperature)
         response = await self._session.post(self._base_url + model + ":" + COMPLETE_PATH,
                                             params=self._params,
diff --git a/tests/llm_api_client/google_client/test_google_client.py b/tests/llm_api_client/google_client/test_google_client.py
index d404a2c..4081740 100644
--- a/tests/llm_api_client/google_client/test_google_client.py
+++ b/tests/llm_api_client/google_client/test_google_client.py
@@ -31,7 +31,7 @@ async def test_text_completion__sanity(mock_aioresponse, llm_client, params):
                       'Once upon a time, there was a young boy named Billy...']
     mock_aioresponse.assert_called_once_with(url, method='POST', params={AUTH_PARAM: llm_client._api_key},
                                              json={PROMPT_KEY: {TEXT_KEY: 'These are a few of my favorite'},
-                                                   MAX_TOKENS_KEY: 64,
+                                                   MAX_TOKENS_KEY: 64, "topP": 0.95,
                                                    'temperature': None},
                                              headers=llm_client._headers,
                                              raise_for_status=True,
@@ -53,7 +53,7 @@ async def test_text_completion__override_model(mock_aioresponse, llm_client, par
                       'Once upon a time, there was a young boy named Billy...']
     mock_aioresponse.assert_called_once_with(url, method='POST', params={AUTH_PARAM: llm_client._api_key},
                                              json={PROMPT_KEY: {TEXT_KEY: 'These are a few of my favorite'},
-                                                   MAX_TOKENS_KEY: 64,
+                                                   MAX_TOKENS_KEY: 64, "topP": 0.95,
                                                    'temperature': None},
                                              headers=llm_client._headers,
                                              raise_for_status=True,
@@ -76,7 +76,7 @@ async def test_text_completion__with_kwargs(mock_aioresponse, llm_client, params
                                              json={PROMPT_KEY: {TEXT_KEY: 'These are a few of my favorite'},
                                                    MAX_TOKENS_KEY: 10,
                                                    'temperature': None,
-                                                   'blabla': 'aaa'},
+                                                   'blabla': 'aaa', "topP": 0.95},
                                              headers=llm_client._headers,
                                              raise_for_status=True,
                                              )
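
A quick end-to-end sketch for reviewers of how the new parameter would be exercised. The `GoogleClient` class name and the `LLMAPIClientConfig` constructor fields below are assumptions inferred from the file paths and type hints in the diff, not something the diff itself shows, and the model name is a placeholder. What the diff does guarantee is that `text_completion` now accepts `top_p` (defaulting to 0.95 in the Google client) and forwards it in the request body under `topP`.

import asyncio

from aiohttp import ClientSession

# Class and config names are assumed from the file paths in the diff above;
# adjust the imports to the actual package layout.
from llm_client.llm_api_client.google_client import GoogleClient
from llm_client.llm_api_client.base_llm_api_client import LLMAPIClientConfig


async def main() -> None:
    async with ClientSession() as session:
        client = GoogleClient(LLMAPIClientConfig(api_key="my-api-key", session=session,
                                                 default_model="text-bison-001"))  # placeholder model
        # An explicit top_p overrides the 0.95 default and is sent as "topP"
        # in the request body, alongside the max-tokens and temperature keys.
        completions = await client.text_completion("These are a few of my favorite",
                                                   max_tokens=64, top_p=0.8)
        print(completions)


asyncio.run(main())

Note the precedence implied by `kwargs["topP"] = kwargs.pop("topP", top_p)`: a caller who passes a raw `topP` kwarg wins over the `top_p` argument, mirroring how `max_tokens` and `temperature` are already resolved.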