diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py index edddc8a8974c..4ebb45a226c5 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # ------------------------------------ +import os import time from msal.application import PublicClientApplication @@ -119,12 +120,16 @@ def _initialize(self): self._load_cache() if self._cache: + if "AZURE_IDENTITY_DISABLE_CP1" in os.environ: + capabilities = None + else: + capabilities = ["CP1"] # able to handle CAE claims challenges self._app = PublicClientApplication( client_id=self._auth_record.client_id, authority="https://{}/{}".format(self._auth_record.authority, self._tenant_id), token_cache=self._cache, http_client=MsalClient(**self._client_kwargs), - client_capabilities=["CP1"] + client_capabilities=capabilities ) self._initialized = True diff --git a/sdk/identity/azure-identity/azure/identity/_internal/interactive.py b/sdk/identity/azure-identity/azure/identity/_internal/interactive.py index 3b678d64f8da..2eb9653900e9 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/interactive.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/interactive.py @@ -8,6 +8,7 @@ import base64 import json import logging +import os import time from typing import TYPE_CHECKING @@ -206,7 +207,11 @@ def _acquire_token_silent(self, *scopes, **kwargs): def _get_app(self): # type: () -> msal.PublicClientApplication if not self._msal_app: - self._msal_app = self._create_app(msal.PublicClientApplication, client_capabilities=["CP1"]) + if "AZURE_IDENTITY_DISABLE_CP1" in os.environ: + capabilities = None + else: + capabilities = ["CP1"] # able to handle CAE claims challenges + self._msal_app = self._create_app(msal.PublicClientApplication, client_capabilities=capabilities) return self._msal_app @abc.abstractmethod diff --git a/sdk/identity/azure-identity/tests/test_device_code_credential.py b/sdk/identity/azure-identity/tests/test_device_code_credential.py index 98c2647679f3..0df5feb02629 100644 --- a/sdk/identity/azure-identity/tests/test_device_code_credential.py +++ b/sdk/identity/azure-identity/tests/test_device_code_credential.py @@ -248,18 +248,25 @@ def test_timeout(): def test_client_capabilities(): - """the credential should configure MSAL for capability CP1 (ability to handle claims challenges)""" + """the credential should configure MSAL for capability CP1 unless AZURE_IDENTITY_DISABLE_CP1 is set""" transport = Mock(send=Mock(side_effect=Exception("this test mocks MSAL, so no request should be sent"))) - credential = DeviceCodeCredential(transport=transport) with patch("msal.PublicClientApplication") as PublicClientApplication: - credential._get_app() + DeviceCodeCredential(transport=transport)._get_app() assert PublicClientApplication.call_count == 1 _, kwargs = PublicClientApplication.call_args assert kwargs["client_capabilities"] == ["CP1"] + with patch.dict("os.environ", {"AZURE_IDENTITY_DISABLE_CP1": "true"}): + with patch("msal.PublicClientApplication") as PublicClientApplication: + DeviceCodeCredential(transport=transport)._get_app() + + assert PublicClientApplication.call_count == 1 + _, kwargs = PublicClientApplication.call_args + assert kwargs["client_capabilities"] is None + def test_claims_challenge(): """get_token and authenticate should pass any claims challenge to MSAL token acquisition APIs""" diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py index 9ee2dcf1327f..12dec789f76f 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py @@ -736,7 +736,7 @@ def mock_send(request, **_): def test_client_capabilities(): - """the credential should configure MSAL for capability CP1 (ability to handle claims challenges)""" + """the credential should configure MSAL for capability CP1 unless AZURE_IDENTITY_DISABLE_CP1 is set""" record = AuthenticationRecord("tenant-id", "client_id", "authority", "home_account_id", "username") transport = Mock(send=Mock(side_effect=Exception("this test mocks MSAL, so no request should be sent"))) @@ -751,6 +751,17 @@ def test_client_capabilities(): _, kwargs = PublicClientApplication.call_args assert kwargs["client_capabilities"] == ["CP1"] + credential = SharedTokenCacheCredential( + transport=transport, authentication_record=record, _cache=TokenCache() + ) + with patch(SharedTokenCacheCredential.__module__ + ".PublicClientApplication") as PublicClientApplication: + with patch.dict("os.environ", {"AZURE_IDENTITY_DISABLE_CP1": "true"}): + credential._initialize() + + assert PublicClientApplication.call_count == 1 + _, kwargs = PublicClientApplication.call_args + assert kwargs["client_capabilities"] is None + def test_claims_challenge(): """get_token should pass any claims challenge to MSAL token acquisition APIs""" diff --git a/sdk/identity/azure-identity/tests/test_username_password_credential.py b/sdk/identity/azure-identity/tests/test_username_password_credential.py index 3ac70a1fb544..d07f03dfa1d1 100644 --- a/sdk/identity/azure-identity/tests/test_username_password_credential.py +++ b/sdk/identity/azure-identity/tests/test_username_password_credential.py @@ -149,11 +149,11 @@ def test_authenticate(): def test_client_capabilities(): - """the credential should configure MSAL for capability CP1 (ability to handle claims challenges)""" + """the credential should configure MSAL for capability CP1 unless AZURE_IDENTITY_DISABLE_CP1 is set""" transport = Mock(send=Mock(side_effect=Exception("this test mocks MSAL, so no request should be sent"))) - credential = UsernamePasswordCredential("client-id", "username", "password", transport=transport) + credential = UsernamePasswordCredential("client-id", "username", "password", transport=transport) with patch("msal.PublicClientApplication") as PublicClientApplication: credential._get_app() @@ -161,6 +161,15 @@ def test_client_capabilities(): _, kwargs = PublicClientApplication.call_args assert kwargs["client_capabilities"] == ["CP1"] + credential = UsernamePasswordCredential("client-id", "username", "password", transport=transport) + with patch.dict("os.environ", {"AZURE_IDENTITY_DISABLE_CP1": "true"}): + with patch("msal.PublicClientApplication") as PublicClientApplication: + credential._get_app() + + assert PublicClientApplication.call_count == 1 + _, kwargs = PublicClientApplication.call_args + assert kwargs["client_capabilities"] is None + def test_claims_challenge(): """get_token should and authenticate pass any claims challenge to MSAL token acquisition APIs"""