Skip to content

Commit

Permalink
Allow overriding client_id for token exchange (#20571)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Sep 8, 2021
1 parent 6ccb4ad commit dcbd6d9
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 53 deletions.
2 changes: 1 addition & 1 deletion sdk/identity/azure-identity/azure/identity/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ class EnvironmentVariables:
AZURE_REGIONAL_AUTHORITY_NAME = "AZURE_REGIONAL_AUTHORITY_NAME"

AZURE_FEDERATED_TOKEN_FILE = "AZURE_FEDERATED_TOKEN_FILE"
TOKEN_EXCHANGE_VARS = (AZURE_CLIENT_ID, AZURE_TENANT_ID, AZURE_FEDERATED_TOKEN_FILE)
TOKEN_EXCHANGE_VARS = (AZURE_AUTHORITY_HOST, AZURE_TENANT_ID, AZURE_FEDERATED_TOKEN_FILE)
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,13 @@ def __init__(self, **kwargs):
_LOGGER.info("%s will use token exchange", self.__class__.__name__)
from .token_exchange import TokenExchangeCredential

client_id = kwargs.pop("client_id", None) or os.environ.get(EnvironmentVariables.AZURE_CLIENT_ID)
if not client_id:
raise ValueError('Configure the environment with a client ID or pass a value for "client_id" argument')

self._credential = TokenExchangeCredential(
tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID],
client_id=os.environ[EnvironmentVariables.AZURE_CLIENT_ID],
client_id=client_id,
token_file_path=os.environ[EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE],
**kwargs
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,13 @@ def __init__(self, **kwargs: "Any") -> None:
_LOGGER.info("%s will use token exchange", self.__class__.__name__)
from .token_exchange import TokenExchangeCredential

client_id = kwargs.pop("client_id", None) or os.environ.get(EnvironmentVariables.AZURE_CLIENT_ID)
if not client_id:
raise ValueError('Configure the environment with a client ID or pass a value for "client_id" argument')

self._credential = TokenExchangeCredential(
tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID],
client_id=os.environ[EnvironmentVariables.AZURE_CLIENT_ID],
client_id=client_id,
token_file_path=os.environ[EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE],
**kwargs
)
Expand Down
105 changes: 71 additions & 34 deletions sdk/identity/azure-identity/tests/test_managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
},
{EnvironmentVariables.IDENTITY_ENDPOINT: "...", EnvironmentVariables.IMDS_ENDPOINT: "..."}, # Arc
{ # token exchange
EnvironmentVariables.AZURE_AUTHORITY_HOST: "https://localhost",
EnvironmentVariables.AZURE_CLIENT_ID: "...",
EnvironmentVariables.AZURE_TENANT_ID: "...",
EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: __file__,
Expand Down Expand Up @@ -73,24 +74,6 @@ def test_context_manager_incomplete_configuration():
pass


ALL_ENVIRONMENTS = (
{EnvironmentVariables.MSI_ENDPOINT: "...", EnvironmentVariables.MSI_SECRET: "..."}, # App Service
{EnvironmentVariables.MSI_ENDPOINT: "..."}, # Cloud Shell
{ # Service Fabric
EnvironmentVariables.IDENTITY_ENDPOINT: "...",
EnvironmentVariables.IDENTITY_HEADER: "...",
EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT: "...",
},
{EnvironmentVariables.IDENTITY_ENDPOINT: "...", EnvironmentVariables.IMDS_ENDPOINT: "..."}, # Arc
{ # token exchange
EnvironmentVariables.AZURE_CLIENT_ID: "...",
EnvironmentVariables.AZURE_TENANT_ID: "...",
EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: __file__,
},
{}, # IMDS
)


@pytest.mark.parametrize("environ", ALL_ENVIRONMENTS)
def test_custom_hooks(environ):
"""The credential's pipeline should include azure-core's CustomHookPolicy"""
Expand Down Expand Up @@ -790,10 +773,21 @@ def test_token_exchange(tmpdir):
token_file.write(exchange_token)
access_token = "***"
authority = "https://localhost"
client_id = "client_id"
default_client_id = "default_client_id"
tenant = "tenant_id"
scope = "scope"

success_response = mock_response(
json_payload={
"access_token": access_token,
"expires_in": 3600,
"ext_expires_in": 3600,
"expires_on": int(time.time()) + 3600,
"not_before": int(time.time()),
"resource": scope,
"token_type": "Bearer",
}
)
transport = validating_transport(
requests=[
Request(
Expand All @@ -802,38 +796,81 @@ def test_token_exchange(tmpdir):
required_data={
"client_assertion": exchange_token,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_id": client_id,
"client_id": default_client_id,
"grant_type": "client_credentials",
"scope": scope,
},
)
],
responses=[
mock_response(
json_payload={
"access_token": access_token,
"expires_in": 3600,
"ext_expires_in": 3600,
"expires_on": int(time.time()) + 3600,
"not_before": int(time.time()),
"resource": scope,
"token_type": "Bearer",
}
responses=[success_response],
)

mock_environ = {
EnvironmentVariables.AZURE_AUTHORITY_HOST: authority,
EnvironmentVariables.AZURE_CLIENT_ID: default_client_id,
EnvironmentVariables.AZURE_TENANT_ID: tenant,
EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: token_file.strpath,
}
# credential should default to AZURE_CLIENT_ID
with mock.patch.dict("os.environ", mock_environ, clear=True):
credential = ManagedIdentityCredential(transport=transport)
token = credential.get_token(scope)
assert token.token == access_token

# client_id kwarg should override AZURE_CLIENT_ID
nondefault_client_id = "non" + default_client_id
transport = validating_transport(
requests=[
Request(
base_url=authority,
method="POST",
required_data={
"client_assertion": exchange_token,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_id": nondefault_client_id,
"grant_type": "client_credentials",
"scope": scope,
},
)
],
responses=[success_response],
)

with mock.patch.dict("os.environ", mock_environ, clear=True):
credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport)
token = credential.get_token(scope)
assert token.token == access_token

# AZURE_CLIENT_ID may not have a value, in which case client_id is required
transport = validating_transport(
requests=[
Request(
base_url=authority,
method="POST",
required_data={
"client_assertion": exchange_token,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_id": nondefault_client_id,
"grant_type": "client_credentials",
"scope": scope,
},
)
],
responses=[success_response],
)

with mock.patch.dict(
"os.environ",
{
EnvironmentVariables.AZURE_AUTHORITY_HOST: authority,
EnvironmentVariables.AZURE_CLIENT_ID: client_id,
EnvironmentVariables.AZURE_TENANT_ID: tenant,
EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: token_file.strpath,
},
clear=True,
):
credential = ManagedIdentityCredential(transport=transport)
token = credential.get_token(scope)
with pytest.raises(ValueError):
ManagedIdentityCredential()

credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport)
token = credential.get_token(scope)
assert token.token == access_token
86 changes: 70 additions & 16 deletions sdk/identity/azure-identity/tests/test_managed_identity_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,10 +731,21 @@ async def test_token_exchange(tmpdir):
token_file.write(exchange_token)
access_token = "***"
authority = "https://localhost"
client_id = "client_id"
default_client_id = "default_client_id"
tenant = "tenant_id"
scope = "scope"

success_response = mock_response(
json_payload={
"access_token": access_token,
"expires_in": 3600,
"ext_expires_in": 3600,
"expires_on": int(time.time()) + 3600,
"not_before": int(time.time()),
"resource": scope,
"token_type": "Bearer",
}
)
transport = async_validating_transport(
requests=[
Request(
Expand All @@ -743,38 +754,81 @@ async def test_token_exchange(tmpdir):
required_data={
"client_assertion": exchange_token,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_id": client_id,
"client_id": default_client_id,
"grant_type": "client_credentials",
"scope": scope,
},
)
],
responses=[
mock_response(
json_payload={
"access_token": access_token,
"expires_in": 3600,
"ext_expires_in": 3600,
"expires_on": int(time.time()) + 3600,
"not_before": int(time.time()),
"resource": scope,
"token_type": "Bearer",
}
responses=[success_response],
)

mock_environ = {
EnvironmentVariables.AZURE_AUTHORITY_HOST: authority,
EnvironmentVariables.AZURE_CLIENT_ID: default_client_id,
EnvironmentVariables.AZURE_TENANT_ID: tenant,
EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: token_file.strpath,
}
# credential should default to AZURE_CLIENT_ID
with mock.patch.dict("os.environ", mock_environ, clear=True):
credential = ManagedIdentityCredential(transport=transport)
token = await credential.get_token(scope)
assert token.token == access_token

# client_id kwarg should override AZURE_CLIENT_ID
nondefault_client_id = "non" + default_client_id
transport = async_validating_transport(
requests=[
Request(
base_url=authority,
method="POST",
required_data={
"client_assertion": exchange_token,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_id": nondefault_client_id,
"grant_type": "client_credentials",
"scope": scope,
},
)
],
responses=[success_response],
)

with mock.patch.dict("os.environ", mock_environ, clear=True):
credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport)
token = await credential.get_token(scope)
assert token.token == access_token

# AZURE_CLIENT_ID may not have a value, in which case client_id is required
transport = async_validating_transport(
requests=[
Request(
base_url=authority,
method="POST",
required_data={
"client_assertion": exchange_token,
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
"client_id": nondefault_client_id,
"grant_type": "client_credentials",
"scope": scope,
},
)
],
responses=[success_response],
)

with mock.patch.dict(
"os.environ",
{
EnvironmentVariables.AZURE_AUTHORITY_HOST: authority,
EnvironmentVariables.AZURE_CLIENT_ID: client_id,
EnvironmentVariables.AZURE_TENANT_ID: tenant,
EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: token_file.strpath,
},
clear=True,
):
credential = ManagedIdentityCredential(transport=transport)
token = await credential.get_token(scope)
with pytest.raises(ValueError):
ManagedIdentityCredential()

credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport)
token = await credential.get_token(scope)
assert token.token == access_token

0 comments on commit dcbd6d9

Please sign in to comment.