diff --git a/sdk/identity/azure-identity/azure/identity/_constants.py b/sdk/identity/azure-identity/azure/identity/_constants.py index e72aa59d9772..878d7f6bce7f 100644 --- a/sdk/identity/azure-identity/azure/identity/_constants.py +++ b/sdk/identity/azure-identity/azure/identity/_constants.py @@ -48,4 +48,4 @@ class EnvironmentVariables: AZURE_REGIONAL_AUTHORITY_NAME = "AZURE_REGIONAL_AUTHORITY_NAME" AZURE_FEDERATED_TOKEN_FILE = "AZURE_FEDERATED_TOKEN_FILE" - TOKEN_EXCHANGE_VARS = (AZURE_CLIENT_ID, AZURE_TENANT_ID, AZURE_FEDERATED_TOKEN_FILE) + TOKEN_EXCHANGE_VARS = (AZURE_AUTHORITY_HOST, AZURE_TENANT_ID, AZURE_FEDERATED_TOKEN_FILE) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py index d48a9ac4c4ba..d602c4d4d4e3 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/managed_identity.py @@ -69,9 +69,13 @@ def __init__(self, **kwargs): _LOGGER.info("%s will use token exchange", self.__class__.__name__) from .token_exchange import TokenExchangeCredential + client_id = kwargs.pop("client_id", None) or os.environ.get(EnvironmentVariables.AZURE_CLIENT_ID) + if not client_id: + raise ValueError('Configure the environment with a client ID or pass a value for "client_id" argument') + self._credential = TokenExchangeCredential( tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID], - client_id=os.environ[EnvironmentVariables.AZURE_CLIENT_ID], + client_id=client_id, token_file_path=os.environ[EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE], **kwargs ) diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py index ac0e6c450f0e..dabb1be0b21f 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/managed_identity.py @@ -66,9 +66,13 @@ def __init__(self, **kwargs: "Any") -> None: _LOGGER.info("%s will use token exchange", self.__class__.__name__) from .token_exchange import TokenExchangeCredential + client_id = kwargs.pop("client_id", None) or os.environ.get(EnvironmentVariables.AZURE_CLIENT_ID) + if not client_id: + raise ValueError('Configure the environment with a client ID or pass a value for "client_id" argument') + self._credential = TokenExchangeCredential( tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID], - client_id=os.environ[EnvironmentVariables.AZURE_CLIENT_ID], + client_id=client_id, token_file_path=os.environ[EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE], **kwargs ) diff --git a/sdk/identity/azure-identity/tests/test_managed_identity.py b/sdk/identity/azure-identity/tests/test_managed_identity.py index f7a21eba2d7b..1eb31cc12448 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity.py @@ -31,6 +31,7 @@ }, {EnvironmentVariables.IDENTITY_ENDPOINT: "...", EnvironmentVariables.IMDS_ENDPOINT: "..."}, # Arc { # token exchange + EnvironmentVariables.AZURE_AUTHORITY_HOST: "https://localhost", EnvironmentVariables.AZURE_CLIENT_ID: "...", EnvironmentVariables.AZURE_TENANT_ID: "...", EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: __file__, @@ -73,24 +74,6 @@ def test_context_manager_incomplete_configuration(): pass -ALL_ENVIRONMENTS = ( - {EnvironmentVariables.MSI_ENDPOINT: "...", EnvironmentVariables.MSI_SECRET: "..."}, # App Service - {EnvironmentVariables.MSI_ENDPOINT: "..."}, # Cloud Shell - { # Service Fabric - EnvironmentVariables.IDENTITY_ENDPOINT: "...", - EnvironmentVariables.IDENTITY_HEADER: "...", - EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT: "...", - }, - {EnvironmentVariables.IDENTITY_ENDPOINT: "...", EnvironmentVariables.IMDS_ENDPOINT: "..."}, # Arc - { # token exchange - EnvironmentVariables.AZURE_CLIENT_ID: "...", - EnvironmentVariables.AZURE_TENANT_ID: "...", - EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: __file__, - }, - {}, # IMDS -) - - @pytest.mark.parametrize("environ", ALL_ENVIRONMENTS) def test_custom_hooks(environ): """The credential's pipeline should include azure-core's CustomHookPolicy""" @@ -790,10 +773,21 @@ def test_token_exchange(tmpdir): token_file.write(exchange_token) access_token = "***" authority = "https://localhost" - client_id = "client_id" + default_client_id = "default_client_id" tenant = "tenant_id" scope = "scope" + success_response = mock_response( + json_payload={ + "access_token": access_token, + "expires_in": 3600, + "ext_expires_in": 3600, + "expires_on": int(time.time()) + 3600, + "not_before": int(time.time()), + "resource": scope, + "token_type": "Bearer", + } + ) transport = validating_transport( requests=[ Request( @@ -802,38 +796,81 @@ def test_token_exchange(tmpdir): required_data={ "client_assertion": exchange_token, "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", - "client_id": client_id, + "client_id": default_client_id, "grant_type": "client_credentials", "scope": scope, }, ) ], - responses=[ - mock_response( - json_payload={ - "access_token": access_token, - "expires_in": 3600, - "ext_expires_in": 3600, - "expires_on": int(time.time()) + 3600, - "not_before": int(time.time()), - "resource": scope, - "token_type": "Bearer", - } + responses=[success_response], + ) + + mock_environ = { + EnvironmentVariables.AZURE_AUTHORITY_HOST: authority, + EnvironmentVariables.AZURE_CLIENT_ID: default_client_id, + EnvironmentVariables.AZURE_TENANT_ID: tenant, + EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: token_file.strpath, + } + # credential should default to AZURE_CLIENT_ID + with mock.patch.dict("os.environ", mock_environ, clear=True): + credential = ManagedIdentityCredential(transport=transport) + token = credential.get_token(scope) + assert token.token == access_token + + # client_id kwarg should override AZURE_CLIENT_ID + nondefault_client_id = "non" + default_client_id + transport = validating_transport( + requests=[ + Request( + base_url=authority, + method="POST", + required_data={ + "client_assertion": exchange_token, + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + "client_id": nondefault_client_id, + "grant_type": "client_credentials", + "scope": scope, + }, ) ], + responses=[success_response], + ) + + with mock.patch.dict("os.environ", mock_environ, clear=True): + credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport) + token = credential.get_token(scope) + assert token.token == access_token + + # AZURE_CLIENT_ID may not have a value, in which case client_id is required + transport = validating_transport( + requests=[ + Request( + base_url=authority, + method="POST", + required_data={ + "client_assertion": exchange_token, + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + "client_id": nondefault_client_id, + "grant_type": "client_credentials", + "scope": scope, + }, + ) + ], + responses=[success_response], ) with mock.patch.dict( "os.environ", { EnvironmentVariables.AZURE_AUTHORITY_HOST: authority, - EnvironmentVariables.AZURE_CLIENT_ID: client_id, EnvironmentVariables.AZURE_TENANT_ID: tenant, EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: token_file.strpath, }, clear=True, ): - credential = ManagedIdentityCredential(transport=transport) - token = credential.get_token(scope) + with pytest.raises(ValueError): + ManagedIdentityCredential() + credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport) + token = credential.get_token(scope) assert token.token == access_token diff --git a/sdk/identity/azure-identity/tests/test_managed_identity_async.py b/sdk/identity/azure-identity/tests/test_managed_identity_async.py index 0cbcf044f2b8..27fe2779a2c7 100644 --- a/sdk/identity/azure-identity/tests/test_managed_identity_async.py +++ b/sdk/identity/azure-identity/tests/test_managed_identity_async.py @@ -731,10 +731,21 @@ async def test_token_exchange(tmpdir): token_file.write(exchange_token) access_token = "***" authority = "https://localhost" - client_id = "client_id" + default_client_id = "default_client_id" tenant = "tenant_id" scope = "scope" + success_response = mock_response( + json_payload={ + "access_token": access_token, + "expires_in": 3600, + "ext_expires_in": 3600, + "expires_on": int(time.time()) + 3600, + "not_before": int(time.time()), + "resource": scope, + "token_type": "Bearer", + } + ) transport = async_validating_transport( requests=[ Request( @@ -743,38 +754,81 @@ async def test_token_exchange(tmpdir): required_data={ "client_assertion": exchange_token, "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", - "client_id": client_id, + "client_id": default_client_id, "grant_type": "client_credentials", "scope": scope, }, ) ], - responses=[ - mock_response( - json_payload={ - "access_token": access_token, - "expires_in": 3600, - "ext_expires_in": 3600, - "expires_on": int(time.time()) + 3600, - "not_before": int(time.time()), - "resource": scope, - "token_type": "Bearer", - } + responses=[success_response], + ) + + mock_environ = { + EnvironmentVariables.AZURE_AUTHORITY_HOST: authority, + EnvironmentVariables.AZURE_CLIENT_ID: default_client_id, + EnvironmentVariables.AZURE_TENANT_ID: tenant, + EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: token_file.strpath, + } + # credential should default to AZURE_CLIENT_ID + with mock.patch.dict("os.environ", mock_environ, clear=True): + credential = ManagedIdentityCredential(transport=transport) + token = await credential.get_token(scope) + assert token.token == access_token + + # client_id kwarg should override AZURE_CLIENT_ID + nondefault_client_id = "non" + default_client_id + transport = async_validating_transport( + requests=[ + Request( + base_url=authority, + method="POST", + required_data={ + "client_assertion": exchange_token, + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + "client_id": nondefault_client_id, + "grant_type": "client_credentials", + "scope": scope, + }, ) ], + responses=[success_response], + ) + + with mock.patch.dict("os.environ", mock_environ, clear=True): + credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport) + token = await credential.get_token(scope) + assert token.token == access_token + + # AZURE_CLIENT_ID may not have a value, in which case client_id is required + transport = async_validating_transport( + requests=[ + Request( + base_url=authority, + method="POST", + required_data={ + "client_assertion": exchange_token, + "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + "client_id": nondefault_client_id, + "grant_type": "client_credentials", + "scope": scope, + }, + ) + ], + responses=[success_response], ) with mock.patch.dict( "os.environ", { EnvironmentVariables.AZURE_AUTHORITY_HOST: authority, - EnvironmentVariables.AZURE_CLIENT_ID: client_id, EnvironmentVariables.AZURE_TENANT_ID: tenant, EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE: token_file.strpath, }, clear=True, ): - credential = ManagedIdentityCredential(transport=transport) - token = await credential.get_token(scope) + with pytest.raises(ValueError): + ManagedIdentityCredential() + credential = ManagedIdentityCredential(client_id=nondefault_client_id, transport=transport) + token = await credential.get_token(scope) assert token.token == access_token