From 5a0f93906fb0ca774d98c8ae63b97833d4bfc3fb Mon Sep 17 00:00:00 2001 From: Xiang Yan Date: Mon, 2 Aug 2021 16:54:40 -0700 Subject: [PATCH] Close session of ChallengePolicyClient in context manager --- sdk/containerregistry/azure-containerregistry/CHANGELOG.md | 2 ++ .../azure/containerregistry/_authentication_policy.py | 7 +++++++ .../azure/containerregistry/_base_client.py | 6 ++++-- .../aio/_async_anonymous_exchange_client.py | 4 ++-- .../containerregistry/aio/_async_authentication_policy.py | 7 +++++++ .../azure/containerregistry/aio/_async_base_client.py | 6 ++++-- .../azure/containerregistry/aio/_async_exchange_client.py | 4 ++-- 7 files changed, 28 insertions(+), 8 deletions(-) diff --git a/sdk/containerregistry/azure-containerregistry/CHANGELOG.md b/sdk/containerregistry/azure-containerregistry/CHANGELOG.md index 71e58a01938e..a07d6409a79e 100644 --- a/sdk/containerregistry/azure-containerregistry/CHANGELOG.md +++ b/sdk/containerregistry/azure-containerregistry/CHANGELOG.md @@ -8,6 +8,8 @@ ### Bugs Fixed +- Close session of `ChallengePolicyClient` in context manager #20000 + ### Other Changes - Bumped dependency on `msrest` to `>=0.6.21` diff --git a/sdk/containerregistry/azure-containerregistry/azure/containerregistry/_authentication_policy.py b/sdk/containerregistry/azure-containerregistry/azure/containerregistry/_authentication_policy.py index 202a14dddc79..4f97428d6b61 100644 --- a/sdk/containerregistry/azure-containerregistry/azure/containerregistry/_authentication_policy.py +++ b/sdk/containerregistry/azure-containerregistry/azure/containerregistry/_authentication_policy.py @@ -70,3 +70,10 @@ def on_challenge(self, request, response, challenge): access_token = self._exchange_client.get_acr_access_token(challenge) request.http_request.headers["Authorization"] = "Bearer " + access_token return access_token is not None + + def __enter__(self): + self._exchange_client.__enter__() + return self + + def __exit__(self, *args): + self._exchange_client.__exit__(*args) diff --git a/sdk/containerregistry/azure-containerregistry/azure/containerregistry/_base_client.py b/sdk/containerregistry/azure-containerregistry/azure/containerregistry/_base_client.py index 2b79ade179cf..dc48722d3f83 100644 --- a/sdk/containerregistry/azure-containerregistry/azure/containerregistry/_base_client.py +++ b/sdk/containerregistry/azure-containerregistry/azure/containerregistry/_base_client.py @@ -34,20 +34,22 @@ class ContainerRegistryBaseClient(object): def __init__(self, endpoint, credential, **kwargs): # type: (str, Optional[TokenCredential], Dict[str, Any]) -> None - auth_policy = ContainerRegistryChallengePolicy(credential, endpoint, **kwargs) + self._auth_policy = ContainerRegistryChallengePolicy(credential, endpoint, **kwargs) self._client = ContainerRegistry( credential=credential, url=endpoint, sdk_moniker=USER_AGENT, - authentication_policy=auth_policy, + authentication_policy=self._auth_policy, **kwargs ) def __enter__(self): + self._auth_policy.__enter__() self._client.__enter__() return self def __exit__(self, *args): + self._auth_policy.__exit__(*args) self._client.__exit__(*args) def close(self): diff --git a/sdk/containerregistry/azure-containerregistry/azure/containerregistry/aio/_async_anonymous_exchange_client.py b/sdk/containerregistry/azure-containerregistry/azure/containerregistry/aio/_async_anonymous_exchange_client.py index 78575b90e9ac..a42c0723c17f 100644 --- a/sdk/containerregistry/azure-containerregistry/azure/containerregistry/aio/_async_anonymous_exchange_client.py +++ b/sdk/containerregistry/azure-containerregistry/azure/containerregistry/aio/_async_anonymous_exchange_client.py @@ -61,11 +61,11 @@ async def exchange_refresh_token_for_access_token( return access_token.access_token async def __aenter__(self): - self._client.__aenter__() + await self._client.__aenter__() return self async def __aexit__(self, *args): - self._client.__aexit__(*args) + await self._client.__aexit__(*args) async def close(self) -> None: """Close sockets opened by the client. diff --git a/sdk/containerregistry/azure-containerregistry/azure/containerregistry/aio/_async_authentication_policy.py b/sdk/containerregistry/azure-containerregistry/azure/containerregistry/aio/_async_authentication_policy.py index 8f9858d65da9..ff7fc0c7f6b8 100644 --- a/sdk/containerregistry/azure-containerregistry/azure/containerregistry/aio/_async_authentication_policy.py +++ b/sdk/containerregistry/azure-containerregistry/azure/containerregistry/aio/_async_authentication_policy.py @@ -67,3 +67,10 @@ async def on_challenge(self, request, response, challenge): access_token = await self._exchange_client.get_acr_access_token(challenge) request.http_request.headers["Authorization"] = "Bearer " + access_token return access_token is not None + + async def __aenter__(self): + await self._exchange_client.__aenter__() + return self + + async def __aexit__(self, *args): + await self._exchange_client.__aexit__() diff --git a/sdk/containerregistry/azure-containerregistry/azure/containerregistry/aio/_async_base_client.py b/sdk/containerregistry/azure-containerregistry/azure/containerregistry/aio/_async_base_client.py index 4b5fcbfe18b8..29e1e13f0832 100644 --- a/sdk/containerregistry/azure-containerregistry/azure/containerregistry/aio/_async_base_client.py +++ b/sdk/containerregistry/azure-containerregistry/azure/containerregistry/aio/_async_base_client.py @@ -33,20 +33,22 @@ class ContainerRegistryBaseClient(object): """ def __init__(self, endpoint: str, credential: Optional["AsyncTokenCredential"] = None, **kwargs) -> None: - auth_policy = ContainerRegistryChallengePolicy(credential, endpoint, **kwargs) + self._auth_policy = ContainerRegistryChallengePolicy(credential, endpoint, **kwargs) self._client = ContainerRegistry( credential=credential, url=endpoint, sdk_moniker=USER_AGENT, - authentication_policy=auth_policy, + authentication_policy=self._auth_policy, **kwargs ) async def __aenter__(self): + await self._auth_policy.__aenter__() await self._client.__aenter__() return self async def __aexit__(self, *args): + await self._auth_policy.__aexit__(*args) await self._client.__aexit__(*args) async def close(self) -> None: diff --git a/sdk/containerregistry/azure-containerregistry/azure/containerregistry/aio/_async_exchange_client.py b/sdk/containerregistry/azure-containerregistry/azure/containerregistry/aio/_async_exchange_client.py index 5ce2c48d7354..944cf74f1a72 100644 --- a/sdk/containerregistry/azure-containerregistry/azure/containerregistry/aio/_async_exchange_client.py +++ b/sdk/containerregistry/azure-containerregistry/azure/containerregistry/aio/_async_exchange_client.py @@ -81,11 +81,11 @@ async def exchange_refresh_token_for_access_token( return access_token.access_token async def __aenter__(self): - self._client.__aenter__() + await self._client.__aenter__() return self async def __aexit__(self, *args): - self._client.__aexit__(*args) + await self._client.__aexit__(*args) async def close(self) -> None: """Close sockets opened by the client.