Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support CAE in azure-identity #16323

Merged
merged 4 commits into from
Feb 3, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _request_token(self, *scopes, **kwargs):
request_state = str(uuid.uuid4())
app = self._get_app()
auth_url = app.get_authorization_request_url(
scopes, redirect_uri=redirect_uri, state=request_state, prompt="select_account", **kwargs
scopes, redirect_uri=redirect_uri, state=request_state, prompt="select_account"
)

# open browser to that url
Expand All @@ -113,7 +113,9 @@ def _request_token(self, *scopes, **kwargs):

# redeem the authorization code for a token
code = self._parse_response(request_state, response)
return app.acquire_token_by_authorization_code(code, scopes=scopes, redirect_uri=redirect_uri, **kwargs)
return app.acquire_token_by_authorization_code(
code, scopes=scopes, redirect_uri=redirect_uri, claims_challenge=kwargs.get("claims_challenge")
)

@staticmethod
def _parse_response(request_state, response):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,16 @@ def _request_token(self, *scopes, **kwargs):
else:
print(flow["message"])

claims_challenge = kwargs.get("claims_challenge")
if self._timeout is not None and self._timeout < flow["expires_in"]:
# user specified an effective timeout we will observe
deadline = int(time.time()) + self._timeout
result = app.acquire_token_by_device_flow(flow, exit_condition=lambda flow: time.time() > deadline)
result = app.acquire_token_by_device_flow(
flow, exit_condition=lambda flow: time.time() > deadline, claims_challenge=claims_challenge
)
else:
# MSAL will stop polling when the device code expires
result = app.acquire_token_by_device_flow(flow)
result = app.acquire_token_by_device_flow(flow, claims_challenge=claims_challenge)

if "access_token" not in result:
if result.get("error") == "authorization_pending":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
raise CredentialUnavailableError(message="Shared token cache unavailable")

if self._auth_record:
return self._acquire_token_silent(*scopes)
return self._acquire_token_silent(*scopes, **kwargs)

account = self._get_account(self._username, self._tenant_id)

Expand Down Expand Up @@ -121,6 +121,7 @@ def _initialize(self):
authority="https://{}/{}".format(self._auth_record.authority, self._tenant_id),
token_cache=self._cache,
http_client=MsalClient(**self._client_kwargs),
client_capabilities=["CP1"]
)

self._initialized = True
Expand All @@ -146,7 +147,9 @@ def _acquire_token_silent(self, *scopes, **kwargs):
continue

now = int(time.time())
result = self._app.acquire_token_silent_with_error(list(scopes), account=account, **kwargs)
result = self._app.acquire_token_silent_with_error(
list(scopes), account=account, claims_challenge=kwargs.get("claims_challenge")
)
if result and "access_token" in result and "expires_in" in result:
return AccessToken(result["access_token"], now + int(result["expires_in"]))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,8 @@ def _request_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> dict
app = self._get_app()
return app.acquire_token_by_username_password(
username=self._username, password=self._password, scopes=list(scopes)
username=self._username,
password=self._password,
scopes=list(scopes),
claims_challenge=kwargs.get("claims_challenge"),
)
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def get_token(self, *scopes, **kwargs):
This method is called automatically by Azure SDK clients.

:param str scopes: desired scopes for the access token. This method requires at least one scope.
:keyword str claims_challenge: a claims challenge returned by a resource provider following an authorization
failure
:rtype: :class:`azure.core.credentials.AccessToken`
:raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks
required data, state, or platform support
Expand Down Expand Up @@ -187,7 +189,9 @@ def _acquire_token_silent(self, *scopes, **kwargs):
continue

now = int(time.time())
result = app.acquire_token_silent_with_error(list(scopes), account=account, **kwargs)
result = app.acquire_token_silent_with_error(
list(scopes), account=account, claims_challenge=kwargs.get("claims_challenge")
)
if result and "access_token" in result and "expires_in" in result:
return AccessToken(result["access_token"], now + int(result["expires_in"]))

Expand All @@ -200,7 +204,7 @@ def _acquire_token_silent(self, *scopes, **kwargs):
def _get_app(self):
# type: () -> msal.PublicClientApplication
if not self._msal_app:
self._msal_app = self._create_app(msal.PublicClientApplication)
self._msal_app = self._create_app(msal.PublicClientApplication, client_capabilities=["CP1"])
return self._msal_app

@abc.abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,15 @@ def _get_app(self):
# type: () -> msal.ClientApplication
pass

def _create_app(self, cls):
# type: (Type[msal.ClientApplication]) -> msal.ClientApplication
def _create_app(self, cls, **kwargs):
# type: (Type[msal.ClientApplication], **Any) -> msal.ClientApplication
app = cls(
client_id=self._client_id,
client_credential=self._client_credential,
authority="{}/{}".format(self._authority, self._tenant_id),
token_cache=self._cache,
http_client=self._client,
**kwargs
)

return app
77 changes: 77 additions & 0 deletions sdk/identity/azure-identity/tests/recording_processors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import base64
import binascii
import hashlib
import json
import re
import time

from azure_devtools.scenario_tests import RecordingProcessor
import six


SECRETS = frozenset({
"access_token",
"client_secret",
"code",
"device_code",
"message",
"password",
"refresh_token",
"user_code",
})


class RecordingRedactor(RecordingProcessor):
"""Removes authentication secrets from recordings"""

def process_request(self, request):
# don't record the body because it probably contains secrets and is formed by msal anyway,
# i.e. it isn't this library's responsibility
request.body = None
return request

def process_response(self, response):
try:
body = json.loads(response["body"]["string"])
except (KeyError, ValueError):
return response

for field in body:
if field in SECRETS:
# record a hash of the secret instead of a simple replacement like "redacted"
# because some tests (e.g. for CAE) require unique, consistent values
digest = hashlib.sha256(six.ensure_binary(body[field])).digest()
body[field] = six.ensure_str(binascii.hexlify(digest))

response["body"]["string"] = json.dumps(body)
return response


class IdTokenProcessor(RecordingProcessor):
def process_response(self, response):
"""Changes the "exp" claim of recorded id tokens to be in the future during playback

This is necessary because msal always validates id tokens, raising an exception when they've expired.
"""
try:
# decode the recorded token
body = json.loads(six.ensure_str(response["body"]["string"]))
header, encoded_payload, signed = body["id_token"].split(".")
decoded_payload = base64.b64decode(encoded_payload + "=" * (4 - len(encoded_payload) % 4))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity, what are the "=" added onto the payload for?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A base64 encoded string should be padded with "=" to make its length divisible by 4. CPython's base64 strictly requires padding. However, because a decoder can infer the padding, encoders commonly omit it.


# set the token's expiry time to one hour from now
payload = json.loads(six.ensure_str(decoded_payload))
payload["exp"] = int(time.time()) + 3600

# write the modified token to the response body
new_payload = six.ensure_binary(json.dumps(payload))
body["id_token"] = ".".join((header, base64.b64encode(new_payload).decode("utf-8"), signed))
response["body"]["string"] = six.ensure_binary(json.dumps(body))
except KeyError:
pass

return response
Loading