Skip to content

Commit

Permalink
Add compatibility switch to disable CAE (#18148)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored May 4, 2021
1 parent aa659c1 commit 701a6f5
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import os
import time

from msal.application import PublicClientApplication
Expand Down Expand Up @@ -119,12 +120,16 @@ def _initialize(self):

self._load_cache()
if self._cache:
if "AZURE_IDENTITY_DISABLE_CP1" in os.environ:
capabilities = None
else:
capabilities = ["CP1"] # able to handle CAE claims challenges
self._app = PublicClientApplication(
client_id=self._auth_record.client_id,
authority="https://{}/{}".format(self._auth_record.authority, self._tenant_id),
token_cache=self._cache,
http_client=MsalClient(**self._client_kwargs),
client_capabilities=["CP1"]
client_capabilities=capabilities
)

self._initialized = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import base64
import json
import logging
import os
import time
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -206,7 +207,11 @@ def _acquire_token_silent(self, *scopes, **kwargs):
def _get_app(self):
# type: () -> msal.PublicClientApplication
if not self._msal_app:
self._msal_app = self._create_app(msal.PublicClientApplication, client_capabilities=["CP1"])
if "AZURE_IDENTITY_DISABLE_CP1" in os.environ:
capabilities = None
else:
capabilities = ["CP1"] # able to handle CAE claims challenges
self._msal_app = self._create_app(msal.PublicClientApplication, client_capabilities=capabilities)
return self._msal_app

@abc.abstractmethod
Expand Down
13 changes: 10 additions & 3 deletions sdk/identity/azure-identity/tests/test_device_code_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,18 +248,25 @@ def test_timeout():


def test_client_capabilities():
"""the credential should configure MSAL for capability CP1 (ability to handle claims challenges)"""
"""the credential should configure MSAL for capability CP1 unless AZURE_IDENTITY_DISABLE_CP1 is set"""

transport = Mock(send=Mock(side_effect=Exception("this test mocks MSAL, so no request should be sent")))
credential = DeviceCodeCredential(transport=transport)

with patch("msal.PublicClientApplication") as PublicClientApplication:
credential._get_app()
DeviceCodeCredential(transport=transport)._get_app()

assert PublicClientApplication.call_count == 1
_, kwargs = PublicClientApplication.call_args
assert kwargs["client_capabilities"] == ["CP1"]

with patch.dict("os.environ", {"AZURE_IDENTITY_DISABLE_CP1": "true"}):
with patch("msal.PublicClientApplication") as PublicClientApplication:
DeviceCodeCredential(transport=transport)._get_app()

assert PublicClientApplication.call_count == 1
_, kwargs = PublicClientApplication.call_args
assert kwargs["client_capabilities"] is None


def test_claims_challenge():
"""get_token and authenticate should pass any claims challenge to MSAL token acquisition APIs"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,7 @@ def mock_send(request, **_):


def test_client_capabilities():
"""the credential should configure MSAL for capability CP1 (ability to handle claims challenges)"""
"""the credential should configure MSAL for capability CP1 unless AZURE_IDENTITY_DISABLE_CP1 is set"""

record = AuthenticationRecord("tenant-id", "client_id", "authority", "home_account_id", "username")
transport = Mock(send=Mock(side_effect=Exception("this test mocks MSAL, so no request should be sent")))
Expand All @@ -751,6 +751,17 @@ def test_client_capabilities():
_, kwargs = PublicClientApplication.call_args
assert kwargs["client_capabilities"] == ["CP1"]

credential = SharedTokenCacheCredential(
transport=transport, authentication_record=record, _cache=TokenCache()
)
with patch(SharedTokenCacheCredential.__module__ + ".PublicClientApplication") as PublicClientApplication:
with patch.dict("os.environ", {"AZURE_IDENTITY_DISABLE_CP1": "true"}):
credential._initialize()

assert PublicClientApplication.call_count == 1
_, kwargs = PublicClientApplication.call_args
assert kwargs["client_capabilities"] is None


def test_claims_challenge():
"""get_token should pass any claims challenge to MSAL token acquisition APIs"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,18 +149,27 @@ def test_authenticate():


def test_client_capabilities():
"""the credential should configure MSAL for capability CP1 (ability to handle claims challenges)"""
"""the credential should configure MSAL for capability CP1 unless AZURE_IDENTITY_DISABLE_CP1 is set"""

transport = Mock(send=Mock(side_effect=Exception("this test mocks MSAL, so no request should be sent")))
credential = UsernamePasswordCredential("client-id", "username", "password", transport=transport)

credential = UsernamePasswordCredential("client-id", "username", "password", transport=transport)
with patch("msal.PublicClientApplication") as PublicClientApplication:
credential._get_app()

assert PublicClientApplication.call_count == 1
_, kwargs = PublicClientApplication.call_args
assert kwargs["client_capabilities"] == ["CP1"]

credential = UsernamePasswordCredential("client-id", "username", "password", transport=transport)
with patch.dict("os.environ", {"AZURE_IDENTITY_DISABLE_CP1": "true"}):
with patch("msal.PublicClientApplication") as PublicClientApplication:
credential._get_app()

assert PublicClientApplication.call_count == 1
_, kwargs = PublicClientApplication.call_args
assert kwargs["client_capabilities"] is None


def test_claims_challenge():
"""get_token should and authenticate pass any claims challenge to MSAL token acquisition APIs"""
Expand Down

0 comments on commit 701a6f5

Please sign in to comment.