Skip to content

Commit

Permalink
Merge branch 'token-source' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
rayluo committed Oct 26, 2023
2 parents fde6129 + 0ca81d8 commit b3b2195
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 11 deletions.
44 changes: 33 additions & 11 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ class ClientApplication(object):
REMOVE_ACCOUNT_ID = "903"

ATTEMPT_REGION_DISCOVERY = True # "TryAutoDetect"
_TOKEN_SOURCE = "token_source"
_TOKEN_SOURCE_IDP = "identity_provider"
_TOKEN_SOURCE_CACHE = "cache"
_TOKEN_SOURCE_BROKER = "broker"

def __init__(
self, client_id,
Expand Down Expand Up @@ -998,6 +1002,8 @@ def authorize(): # A controller in a web app
self._client_capabilities,
auth_code_flow.pop("claims_challenge", None))),
**kwargs))
if "access_token" in response:
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
telemetry_context.update_telemetry(response)
return response

Expand Down Expand Up @@ -1070,6 +1076,8 @@ def acquire_token_by_authorization_code(
self._client_capabilities, claims_challenge)),
nonce=nonce,
**kwargs))
if "access_token" in response:
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
telemetry_context.update_telemetry(response)
return response

Expand Down Expand Up @@ -1218,6 +1226,8 @@ def _acquire_token_by_cloud_shell(self, scopes, data=None):
data=data or {},
authority_type=_AUTHORITY_TYPE_CLOUDSHELL,
))
if "access_token" in response:
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_BROKER
return response

def acquire_token_silent(
Expand Down Expand Up @@ -1395,6 +1405,7 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
"access_token": entry["secret"],
"token_type": entry.get("token_type", "Bearer"),
"expires_in": int(expires_in), # OAuth2 specs defines it as int
self._TOKEN_SOURCE: self._TOKEN_SOURCE_CACHE,
}
if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging
refresh_reason = msal.telemetry.AT_AGING
Expand Down Expand Up @@ -1437,6 +1448,8 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
result = self._acquire_token_for_client(
scopes, refresh_reason, claims_challenge=claims_challenge,
**kwargs)
if result and "access_token" in result:
result[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
if (result and "error" not in result) or (not access_token_from_cache):
return result
except http_exceptions:
Expand All @@ -1455,6 +1468,7 @@ def _process_broker_response(self, response, scopes, data):
data=data,
_account_id=response["_account_id"],
))
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_BROKER
return _clean_up(response)

def _acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family(
Expand Down Expand Up @@ -1611,6 +1625,8 @@ def acquire_token_by_refresh_token(self, refresh_token, scopes, **kwargs):
on_updating_rt=False,
on_removing_rt=lambda rt_item: None, # No OP
**kwargs))
if "access_token" in response:
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
telemetry_context.update_telemetry(response)
return response

Expand Down Expand Up @@ -1658,6 +1674,7 @@ def acquire_token_by_username_password(
self.ACQUIRE_TOKEN_BY_USERNAME_PASSWORD_ID)
headers = telemetry_context.generate_headers()
data = dict(kwargs.pop("data", {}), claims=claims)
response = None
if not self.authority.is_adfs:
user_realm_result = self.authority.user_realm_discovery(
username, correlation_id=headers[msal.telemetry.CLIENT_REQUEST_ID])
Expand All @@ -1666,13 +1683,14 @@ def acquire_token_by_username_password(
user_realm_result, username, password, scopes=scopes,
data=data,
headers=headers, **kwargs))
telemetry_context.update_telemetry(response)
return response
response = _clean_up(self.client.obtain_token_by_username_password(
if response is None: # Either ADFS or not federated
response = _clean_up(self.client.obtain_token_by_username_password(
username, password, scope=scopes,
headers=headers,
data=data,
**kwargs))
if "access_token" in response:
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
telemetry_context.update_telemetry(response)
return response

Expand Down Expand Up @@ -1859,7 +1877,7 @@ def acquire_token_interactive(
logger.warning(
"Ignoring parameter extra_scopes_to_consent, "
"which is not supported by broker")
return self._acquire_token_interactive_via_broker(
response = self._acquire_token_interactive_via_broker(
scopes,
parent_window_handle,
enable_msa_passthrough,
Expand All @@ -1870,6 +1888,7 @@ def acquire_token_interactive(
login_hint=login_hint,
max_age=max_age,
)
return self._process_broker_response(response, scopes, data)

on_before_launching_ui(ui="browser")
telemetry_context = self._build_telemetry_context(
Expand All @@ -1892,6 +1911,8 @@ def acquire_token_interactive(
headers=telemetry_context.generate_headers(),
browser_name=_preferred_browser(),
**kwargs))
if "access_token" in response:
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
telemetry_context.update_telemetry(response)
return response

Expand Down Expand Up @@ -1928,7 +1949,7 @@ def _acquire_token_interactive_via_broker(
claims=claims,
**data)
if response and "error" not in response:
return self._process_broker_response(response, scopes, data)
return response
# login_hint undecisive or not exists
if prompt == "none" or not prompt: # Must/Can attempt _signin_silently()
logger.debug("Calling broker._signin_silently()")
Expand All @@ -1949,9 +1970,7 @@ def _acquire_token_interactive_via_broker(
if is_wrong_account:
logger.debug(wrong_account_error_message)
if prompt == "none":
return self._process_broker_response( # It is either token or error
response, scopes, data
) if not is_wrong_account else {
return response if not is_wrong_account else {
"error": "broker_error",
"error_description": wrong_account_error_message,
}
Expand All @@ -1966,11 +1985,11 @@ def _acquire_token_interactive_via_broker(
"_broker_status") in recoverable_errors:
pass # It will fall back to the _signin_interactively()
else:
return self._process_broker_response(response, scopes, data)
return response

logger.debug("Falls back to broker._signin_interactively()")
on_before_launching_ui(ui="broker")
response = _signin_interactively(
return _signin_interactively(
authority, self.client_id, scopes,
None if parent_window_handle is self.CONSOLE_WINDOW_HANDLE
else parent_window_handle,
Expand All @@ -1981,7 +2000,6 @@ def _acquire_token_interactive_via_broker(
max_age=max_age,
enable_msa_pt=enable_msa_passthrough,
**data)
return self._process_broker_response(response, scopes, data)

def initiate_device_flow(self, scopes=None, **kwargs):
"""Initiate a Device Flow instance,
Expand Down Expand Up @@ -2036,6 +2054,8 @@ def acquire_token_by_device_flow(self, flow, claims_challenge=None, **kwargs):
),
headers=telemetry_context.generate_headers(),
**kwargs))
if "access_token" in response:
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
telemetry_context.update_telemetry(response)
return response

Expand Down Expand Up @@ -2145,5 +2165,7 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No
headers=telemetry_context.generate_headers(),
# TBD: Expose a login_hint (or ccs_routing_hint) param for web app
**kwargs))
if "access_token" in response:
response[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
telemetry_context.update_telemetry(response)
return response
1 change: 1 addition & 0 deletions sample/confidential_client_certificate_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def acquire_and_use_token():
result = global_app.acquire_token_for_client(scopes=config["scope"])

if "access_token" in result:
print("Token was obtained from:", result["token_source"]) # Since MSAL 1.25
# Calling graph using the access token
graph_data = requests.get( # Use token to call downstream service
config["endpoint"],
Expand Down
1 change: 1 addition & 0 deletions sample/confidential_client_secret_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def acquire_and_use_token():
result = global_app.acquire_token_for_client(scopes=config["scope"])

if "access_token" in result:
print("Token was obtained from:", result["token_source"]) # Since MSAL 1.25
# Calling graph using the access token
graph_data = requests.get( # Use token to call downstream service
config["endpoint"],
Expand Down
1 change: 1 addition & 0 deletions sample/device_flow_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def acquire_and_use_token():
# and then keep calling acquire_token_by_device_flow(flow) in your own customized loop.

if "access_token" in result:
print("Token was obtained from:", result["token_source"]) # Since MSAL 1.25
# Calling graph using the access token
graph_data = requests.get( # Use token to call downstream service
config["endpoint"],
Expand Down
1 change: 1 addition & 0 deletions sample/interactive_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def acquire_and_use_token():
)

if "access_token" in result:
print("Token was obtained from:", result["token_source"]) # Since MSAL 1.25
# Calling graph using the access token
graph_response = requests.get( # Use token to call downstream service
config["endpoint"],
Expand Down
1 change: 1 addition & 0 deletions sample/username_password_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def acquire_and_use_token():
config["username"], config["password"], scopes=config["scope"])

if "access_token" in result:
print("Token was obtained from:", result["token_source"]) # Since MSAL 1.25
# Calling graph using the access token
graph_data = requests.get( # Use token to call downstream service
config["endpoint"],
Expand Down
1 change: 1 addition & 0 deletions sample/vault_jwt_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def acquire_and_use_token():
result = global_app.acquire_token_for_client(scopes=config["scope"])

if "access_token" in result:
print("Token was obtained from:", result["token_source"]) # Since MSAL 1.25
# Calling graph using the access token
graph_data = requests.get( # Use token to call downstream service
config["endpoint"],
Expand Down
17 changes: 17 additions & 0 deletions tests/test_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def tester(url, **kwargs):
self.scopes, self.account, post=tester)
self.assertEqual("", result.get("classification"))


class TestClientApplicationAcquireTokenSilentFociBehaviors(unittest.TestCase):

def setUp(self):
Expand Down Expand Up @@ -263,6 +264,7 @@ def test_get_accounts_should_find_accounts_under_different_alias(self):
def test_acquire_token_silent_should_find_at_under_different_alias(self):
result = self.app.acquire_token_silent(self.scopes, self.account)
self.assertNotEqual(None, result)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_CACHE)
self.assertEqual(self.access_token, result.get('access_token'))

def test_acquire_token_silent_should_find_rt_under_different_alias(self):
Expand Down Expand Up @@ -360,6 +362,7 @@ def test_fresh_token_should_be_returned_from_cache(self):
post=lambda url, *args, **kwargs: # Utilize the undocumented test feature
self.fail("I/O shouldn't happen in cache hit AT scenario")
)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_CACHE)
self.assertEqual(access_token, result.get("access_token"))
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")

Expand All @@ -374,6 +377,7 @@ def mock_post(url, headers=None, *args, **kwargs):
"refresh_in": 123,
}))
result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
self.assertEqual(new_access_token, result.get("access_token"))
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")

Expand All @@ -385,6 +389,7 @@ def mock_post(url, headers=None, *args, **kwargs):
self.assertEqual("4|84,4|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
return MinimalResponse(status_code=400, text=json.dumps({"error": "foo"}))
result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_CACHE)
self.assertEqual(old_at, result.get("access_token"))

def test_expired_token_and_unavailable_aad_should_return_error(self):
Expand All @@ -409,6 +414,7 @@ def mock_post(url, headers=None, *args, **kwargs):
"refresh_in": 123,
}))
result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
self.assertEqual(new_access_token, result.get("access_token"))
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")

Expand Down Expand Up @@ -444,6 +450,7 @@ def test_maintaining_offline_state_and_sending_them(self):
post=lambda url, *args, **kwargs: # Utilize the undocumented test feature
self.fail("I/O shouldn't happen in cache hit AT scenario")
)
self.assertEqual(result[app._TOKEN_SOURCE], app._TOKEN_SOURCE_CACHE)
self.assertEqual(cached_access_token, result.get("access_token"))

error1 = "error_1"
Expand Down Expand Up @@ -477,6 +484,7 @@ def mock_post(url, headers=None, *args, **kwargs):
"The previous error should result in same success counter plus latest error info")
return MinimalResponse(status_code=200, text=json.dumps({"access_token": at}))
result = app.acquire_token_by_device_flow({"device_code": "123"}, post=mock_post)
self.assertEqual(result[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP)
self.assertEqual(at, result.get("access_token"))

def mock_post(url, headers=None, *args, **kwargs):
Expand All @@ -485,6 +493,7 @@ def mock_post(url, headers=None, *args, **kwargs):
"The previous success should reset all offline telemetry counters")
return MinimalResponse(status_code=200, text=json.dumps({"access_token": at}))
result = app.acquire_token_by_device_flow({"device_code": "123"}, post=mock_post)
self.assertEqual(result[app._TOKEN_SOURCE], app._TOKEN_SOURCE_IDP)
self.assertEqual(at, result.get("access_token"))


Expand All @@ -503,6 +512,7 @@ def mock_post(url, headers=None, *args, **kwargs):
result = self.app.acquire_token_by_auth_code_flow(
{"state": state, "code_verifier": "bar"}, {"state": state, "code": "012"},
post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
self.assertEqual(at, result.get("access_token"))

def test_acquire_token_by_refresh_token(self):
Expand All @@ -511,6 +521,7 @@ def mock_post(url, headers=None, *args, **kwargs):
self.assertEqual("4|85,1|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
return MinimalResponse(status_code=200, text=json.dumps({"access_token": at}))
result = self.app.acquire_token_by_refresh_token("rt", ["s"], post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
self.assertEqual(at, result.get("access_token"))


Expand All @@ -529,6 +540,7 @@ def mock_post(url, headers=None, *args, **kwargs):
return MinimalResponse(status_code=200, text=json.dumps({"access_token": at}))
result = self.app.acquire_token_by_device_flow(
{"device_code": "123"}, post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
self.assertEqual(at, result.get("access_token"))

def test_acquire_token_by_username_password(self):
Expand All @@ -538,6 +550,7 @@ def mock_post(url, headers=None, *args, **kwargs):
return MinimalResponse(status_code=200, text=json.dumps({"access_token": at}))
result = self.app.acquire_token_by_username_password(
"username", "password", ["scope"], post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
self.assertEqual(at, result.get("access_token"))


Expand All @@ -556,6 +569,7 @@ def mock_post(url, headers=None, *args, **kwargs):
"expires_in": 0,
}))
result = self.app.acquire_token_for_client(["scope"], post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
self.assertEqual("AT 1", result.get("access_token"), "Shall get a new token")

def mock_post(url, headers=None, *args, **kwargs):
Expand All @@ -566,13 +580,15 @@ def mock_post(url, headers=None, *args, **kwargs):
"refresh_in": -100, # A hack to make sure it will attempt refresh
}))
result = self.app.acquire_token_for_client(["scope"], post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
self.assertEqual("AT 2", result.get("access_token"), "Shall get a new token")

def mock_post(url, headers=None, *args, **kwargs):
# 1/0 # TODO: Make sure this was called
self.assertEqual("4|730,4|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
return MinimalResponse(status_code=400, text=json.dumps({"error": "foo"}))
result = self.app.acquire_token_for_client(["scope"], post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_CACHE)
self.assertEqual("AT 2", result.get("access_token"), "Shall get aging token")

def test_acquire_token_on_behalf_of(self):
Expand All @@ -581,6 +597,7 @@ def mock_post(url, headers=None, *args, **kwargs):
self.assertEqual("4|523,0|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
return MinimalResponse(status_code=200, text=json.dumps({"access_token": at}))
result = self.app.acquire_token_on_behalf_of("assertion", ["s"], post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
self.assertEqual(at, result.get("access_token"))


Expand Down

0 comments on commit b3b2195

Please sign in to comment.