diff --git a/src/openai/_client.py b/src/openai/_client.py index 6664dc4233..aa00073281 100644 --- a/src/openai/_client.py +++ b/src/openai/_client.py @@ -4,8 +4,8 @@ import os import asyncio -from typing import Union, Mapping -from typing_extensions import override +from typing import Any, Union, Mapping +from typing_extensions import Self, override import httpx @@ -164,12 +164,10 @@ def copy( set_default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, set_default_query: Mapping[str, object] | None = None, - ) -> OpenAI: + _extra_kwargs: Mapping[str, Any] = {}, + ) -> Self: """ Create a new client instance re-using the same options given to the current client with optional overriding. - - It should be noted that this does not share the underlying httpx client class which may lead - to performance issues. """ if default_headers is not None and set_default_headers is not None: raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive") @@ -199,6 +197,7 @@ def copy( max_retries=max_retries if is_given(max_retries) else self.max_retries, default_headers=headers, default_query=params, + **_extra_kwargs, ) # Alias for `copy` for nicer inline usage, e.g. @@ -374,12 +373,10 @@ def copy( set_default_headers: Mapping[str, str] | None = None, default_query: Mapping[str, object] | None = None, set_default_query: Mapping[str, object] | None = None, - ) -> AsyncOpenAI: + _extra_kwargs: Mapping[str, Any] = {}, + ) -> Self: """ Create a new client instance re-using the same options given to the current client with optional overriding. - - It should be noted that this does not share the underlying httpx client class which may lead - to performance issues. """ if default_headers is not None and set_default_headers is not None: raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive") @@ -409,6 +406,7 @@ def copy( max_retries=max_retries if is_given(max_retries) else self.max_retries, default_headers=headers, default_query=params, + **_extra_kwargs, ) # Alias for `copy` for nicer inline usage, e.g. diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index d31313e95a..27bebd8cab 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -3,7 +3,7 @@ import os import inspect from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, overload -from typing_extensions import override +from typing_extensions import Self, override import httpx @@ -178,7 +178,7 @@ def __init__( if default_query is None: default_query = {"api-version": api_version} else: - default_query = {"api-version": api_version, **default_query} + default_query = {**default_query, "api-version": api_version} if base_url is None: if azure_endpoint is None: @@ -212,9 +212,53 @@ def __init__( http_client=http_client, _strict_response_validation=_strict_response_validation, ) + self._api_version = api_version self._azure_ad_token = azure_ad_token self._azure_ad_token_provider = azure_ad_token_provider + @override + def copy( + self, + *, + api_key: str | None = None, + organization: str | None = None, + api_version: str | None = None, + azure_ad_token: str | None = None, + azure_ad_token_provider: AzureADTokenProvider | None = None, + base_url: str | httpx.URL | None = None, + timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + http_client: httpx.Client | None = None, + max_retries: int | NotGiven = NOT_GIVEN, + default_headers: Mapping[str, str] | None = None, + set_default_headers: Mapping[str, str] | None = None, + default_query: Mapping[str, object] | None = None, + set_default_query: Mapping[str, object] | None = None, + _extra_kwargs: Mapping[str, Any] = {}, + ) -> Self: + """ + Create a new client instance re-using the same options given to the current client with optional overriding. + """ + return super().copy( + api_key=api_key, + organization=organization, + base_url=base_url, + timeout=timeout, + http_client=http_client, + max_retries=max_retries, + default_headers=default_headers, + set_default_headers=set_default_headers, + default_query=default_query, + set_default_query=set_default_query, + _extra_kwargs={ + "api_version": api_version or self._api_version, + "azure_ad_token": azure_ad_token or self._azure_ad_token, + "azure_ad_token_provider": azure_ad_token_provider or self._azure_ad_token_provider, + **_extra_kwargs, + }, + ) + + with_options = copy + def _get_azure_ad_token(self) -> str | None: if self._azure_ad_token is not None: return self._azure_ad_token @@ -367,7 +411,7 @@ def __init__( if default_query is None: default_query = {"api-version": api_version} else: - default_query = {"api-version": api_version, **default_query} + default_query = {**default_query, "api-version": api_version} if base_url is None: if azure_endpoint is None: @@ -401,9 +445,53 @@ def __init__( http_client=http_client, _strict_response_validation=_strict_response_validation, ) + self._api_version = api_version self._azure_ad_token = azure_ad_token self._azure_ad_token_provider = azure_ad_token_provider + @override + def copy( + self, + *, + api_key: str | None = None, + organization: str | None = None, + api_version: str | None = None, + azure_ad_token: str | None = None, + azure_ad_token_provider: AsyncAzureADTokenProvider | None = None, + base_url: str | httpx.URL | None = None, + timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + http_client: httpx.AsyncClient | None = None, + max_retries: int | NotGiven = NOT_GIVEN, + default_headers: Mapping[str, str] | None = None, + set_default_headers: Mapping[str, str] | None = None, + default_query: Mapping[str, object] | None = None, + set_default_query: Mapping[str, object] | None = None, + _extra_kwargs: Mapping[str, Any] = {}, + ) -> Self: + """ + Create a new client instance re-using the same options given to the current client with optional overriding. + """ + return super().copy( + api_key=api_key, + organization=organization, + base_url=base_url, + timeout=timeout, + http_client=http_client, + max_retries=max_retries, + default_headers=default_headers, + set_default_headers=set_default_headers, + default_query=default_query, + set_default_query=set_default_query, + _extra_kwargs={ + "api_version": api_version or self._api_version, + "azure_ad_token": azure_ad_token or self._azure_ad_token, + "azure_ad_token_provider": azure_ad_token_provider or self._azure_ad_token_provider, + **_extra_kwargs, + }, + ) + + with_options = copy + async def _get_azure_ad_token(self) -> str | None: if self._azure_ad_token is not None: return self._azure_ad_token diff --git a/tests/lib/test_azure.py b/tests/lib/test_azure.py index b0bd87571b..9360b2925a 100644 --- a/tests/lib/test_azure.py +++ b/tests/lib/test_azure.py @@ -1,4 +1,5 @@ from typing import Union +from typing_extensions import Literal import pytest @@ -34,3 +35,32 @@ def test_implicit_deployment_path(client: Client) -> None: req.url == "https://example-resource.azure.openai.com/openai/deployments/my-deployment-model/chat/completions?api-version=2023-07-01" ) + + +@pytest.mark.parametrize( + "client,method", + [ + (sync_client, "copy"), + (sync_client, "with_options"), + (async_client, "copy"), + (async_client, "with_options"), + ], +) +def test_client_copying(client: Client, method: Literal["copy", "with_options"]) -> None: + if method == "copy": + copied = client.copy() + else: + copied = client.with_options() + + assert copied._custom_query == {"api-version": "2023-07-01"} + + +@pytest.mark.parametrize( + "client", + [sync_client, async_client], +) +def test_client_copying_override_options(client: Client) -> None: + copied = client.copy( + api_version="2022-05-01", + ) + assert copied._custom_query == {"api-version": "2022-05-01"}