diff --git a/sdk/identity/azure-identity/azure/identity/_authn_client.py b/sdk/identity/azure-identity/azure/identity/_authn_client.py index a45eb49a81a2..2b165d0a0a52 100644 --- a/sdk/identity/azure-identity/azure/identity/_authn_client.py +++ b/sdk/identity/azure-identity/azure/identity/_authn_client.py @@ -23,7 +23,7 @@ ) from azure.core.pipeline.transport import RequestsTransport, HttpRequest from ._constants import AZURE_CLI_CLIENT_ID -from ._internal import get_default_authority +from ._internal import get_default_authority, normalize_authority from ._internal.user_agent import USER_AGENT try: @@ -62,8 +62,8 @@ def __init__(self, endpoint=None, authority=None, tenant=None, **kwargs): # pyl else: if not tenant: raise ValueError("'tenant' is required") - authority = authority or get_default_authority() - self._auth_url = "https://" + "/".join((authority.strip("/"), tenant.strip("/"), "oauth2/v2.0/token")) + authority = normalize_authority(authority) if authority else get_default_authority() + self._auth_url = "/".join((authority, tenant.strip("/"), "oauth2/v2.0/token")) self._cache = kwargs.get("cache") or TokenCache() # type: TokenCache @property diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/_credentials/default.py index de1103e0a4b9..7432867a3ac6 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/default.py @@ -6,7 +6,7 @@ import os from .._constants import EnvironmentVariables -from .._internal import get_default_authority +from .._internal import get_default_authority, normalize_authority from .browser import InteractiveBrowserCredential from .chained import ChainedTokenCredential from .environment import EnvironmentCredential @@ -62,7 +62,8 @@ class DefaultAzureCredential(ChainedTokenCredential): """ def __init__(self, **kwargs): - authority = kwargs.pop("authority", None) or get_default_authority() + authority = kwargs.pop("authority", None) + authority = normalize_authority(authority) if authority else get_default_authority() shared_cache_username = kwargs.pop("shared_cache_username", os.environ.get(EnvironmentVariables.AZURE_USERNAME)) shared_cache_tenant_id = kwargs.pop( diff --git a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py index a3cbd73d1586..da4f702d6842 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/__init__.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/__init__.py @@ -3,12 +3,30 @@ # Licensed under the MIT License. # ------------------------------------ import os +from six.moves.urllib_parse import urlparse from .._constants import EnvironmentVariables, KnownAuthorities +def normalize_authority(authority): + # type: (str) -> str + """Ensure authority uses https, strip trailing spaces and /""" + + parsed = urlparse(authority) + if not parsed.scheme: + return "https://" + authority.rstrip(" /") + if parsed.scheme != "https": + raise ValueError( + "'{}' is an invalid authority. The value must be a TLS protected (https) URL.".format(authority) + ) + + return authority.rstrip(" /") + + def get_default_authority(): - return os.environ.get(EnvironmentVariables.AZURE_AUTHORITY_HOST, KnownAuthorities.AZURE_PUBLIC_CLOUD) + # type: () -> str + authority = os.environ.get(EnvironmentVariables.AZURE_AUTHORITY_HOST, KnownAuthorities.AZURE_PUBLIC_CLOUD) + return normalize_authority(authority) # pylint:disable=wrong-import-position diff --git a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py index 2fb8ef43f0e0..10f2a0ef8ea8 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/aad_client_base.py @@ -17,7 +17,7 @@ from azure.core.credentials import AccessToken from azure.core.exceptions import ClientAuthenticationError -from . import get_default_authority +from . import get_default_authority, normalize_authority try: ABC = abc.ABC @@ -34,10 +34,10 @@ class AadClientBase(ABC): def __init__(self, tenant_id, client_id, cache=None, **kwargs): # type: (str, str, Optional[TokenCache], **Any) -> None - authority = kwargs.pop("authority", None) or get_default_authority() - if authority[-1] == "/": - authority = authority[:-1] - token_endpoint = "https://" + "/".join((authority, tenant_id, "oauth2/v2.0/token")) + authority = kwargs.pop("authority", None) + authority = normalize_authority(authority) if authority else get_default_authority() + + token_endpoint = "/".join((authority, tenant_id, "oauth2/v2.0/token")) config = {"token_endpoint": token_endpoint} self._cache = cache or TokenCache() diff --git a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py index 78552ec106d7..b7fefdf30237 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py @@ -14,7 +14,7 @@ from .exception_wrapper import wrap_exceptions from .msal_transport_adapter import MsalTransportAdapter -from .._internal import get_default_authority +from .._internal import get_default_authority, normalize_authority try: ABC = abc.ABC @@ -37,8 +37,10 @@ class MsalCredential(ABC): def __init__(self, client_id, client_credential=None, **kwargs): # type: (str, Optional[Union[str, Mapping[str, str]]], **Any) -> None tenant_id = kwargs.pop("tenant_id", "organizations") - authority = kwargs.pop("authority", None) or get_default_authority() - self._base_url = "https://" + "/".join((authority.strip("/"), tenant_id.strip("/"))) + authority = kwargs.pop("authority", None) + authority = normalize_authority(authority) if authority else get_default_authority() + + self._base_url = "/".join((authority, tenant_id.strip("/"))) self._client_credential = client_credential self._client_id = client_id diff --git a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py index d874ba5701f2..22a42bfdf509 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py @@ -7,10 +7,11 @@ import sys from msal import TokenCache +from six.moves.urllib_parse import urlparse from .. import CredentialUnavailableError from .._constants import KnownAuthorities -from .._internal import get_default_authority +from .._internal import get_default_authority, normalize_authority try: ABC = abc.ABC @@ -87,8 +88,11 @@ class SharedTokenCacheBase(ABC): def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument # type: (Optional[str], **Any) -> None - self._authority = kwargs.pop("authority", None) or get_default_authority() - self._authority_aliases = KNOWN_ALIASES.get(self._authority) or frozenset((self._authority,)) + authority = kwargs.pop("authority", None) + self._authority = normalize_authority(authority) if authority else get_default_authority() + + environment = urlparse(self._authority).netloc + self._environment_aliases = KNOWN_ALIASES.get(environment) or frozenset((environment,)) self._username = username self._tenant_id = kwargs.pop("tenant_id", None) @@ -125,7 +129,7 @@ def _get_cache_items_for_authority(self, credential_type): items = [] for item in self._cache.find(credential_type): environment = item.get("environment") - if environment in self._authority_aliases: + if environment in self._environment_aliases: items.append(item) return items diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py index 3caf443bf793..f8feca5fb0ea 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/default.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING from ..._constants import EnvironmentVariables -from ..._internal import get_default_authority +from ..._internal import get_default_authority, normalize_authority from .azure_cli import AzureCliCredential from .chained import ChainedTokenCredential from .environment import EnvironmentCredential @@ -53,7 +53,8 @@ class DefaultAzureCredential(ChainedTokenCredential): """ def __init__(self, **kwargs): - authority = kwargs.pop("authority", None) or get_default_authority() + authority = kwargs.pop("authority", None) + authority = normalize_authority(authority) if authority else get_default_authority() shared_cache_username = kwargs.pop("shared_cache_username", os.environ.get(EnvironmentVariables.AZURE_USERNAME)) shared_cache_tenant_id = kwargs.pop( diff --git a/sdk/identity/azure-identity/tests/test_aad_client.py b/sdk/identity/azure-identity/tests/test_aad_client.py index 5e3012e06061..744cb0e70939 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client.py +++ b/sdk/identity/azure-identity/tests/test_aad_client.py @@ -99,15 +99,17 @@ def assert_secrets_not_exposed(): assert_secrets_not_exposed() -def test_request_url(): - authority = "authority.com" +@pytest.mark.parametrize("authority", ("localhost", "https://localhost")) +def test_request_url(authority): tenant_id = "expected_tenant" + parsed_authority = urlparse(authority) + expected_netloc = parsed_authority.netloc or authority # "localhost" parses to netloc "", path "localhost" def send(request, **_): - scheme, netloc, path, _, _, _ = urlparse(request.url) - assert scheme == "https" - assert netloc == authority - assert path.startswith("/" + tenant_id) + actual = urlparse(request.url) + assert actual.scheme == "https" + assert actual.netloc == expected_netloc + assert actual.path.startswith("/" + tenant_id) return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"}) client = AadClient(tenant_id, "client id", transport=Mock(send=send), authority=authority) diff --git a/sdk/identity/azure-identity/tests/test_aad_client_async.py b/sdk/identity/azure-identity/tests/test_aad_client_async.py index a2d24ce8e3d9..de77ab384736 100644 --- a/sdk/identity/azure-identity/tests/test_aad_client_async.py +++ b/sdk/identity/azure-identity/tests/test_aad_client_async.py @@ -43,24 +43,26 @@ async def test_uses_msal_correctly(): @pytest.mark.asyncio -async def test_request_url(): - authority = "authority.com" - tenant = "expected_tenant" +@pytest.mark.parametrize("authority", ("localhost", "https://localhost")) +async def test_request_url(authority): + tenant_id = "expected_tenant" + parsed_authority = urlparse(authority) + expected_netloc = parsed_authority.netloc or authority # "localhost" parses to netloc "", path "localhost" async def send(request, **_): - scheme, netloc, path, _, _, _ = urlparse(request.url) - assert scheme == "https" - assert netloc == authority - assert path.startswith("/" + tenant) + actual = urlparse(request.url) + assert actual.scheme == "https" + assert actual.netloc == expected_netloc + assert actual.path.startswith("/" + tenant_id) return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"}) - client = AadClient(tenant, "client id", transport=Mock(send=send), authority=authority) + client = AadClient(tenant_id, "client id", transport=Mock(send=send), authority=authority) await client.obtain_token_by_authorization_code("code", "uri", "scope") await client.obtain_token_by_refresh_token("refresh token", "scope") # authority can be configured via environment variable with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): - client = AadClient(tenant_id=tenant, client_id="client id", transport=Mock(send=send)) + client = AadClient(tenant_id=tenant_id, client_id="client id", transport=Mock(send=send)) await client.obtain_token_by_authorization_code("code", "uri", "scope") await client.obtain_token_by_refresh_token("refresh token", "scope") diff --git a/sdk/identity/azure-identity/tests/test_authn_client.py b/sdk/identity/azure-identity/tests/test_authn_client.py index e2772e130d27..6732d43cd4dc 100644 --- a/sdk/identity/azure-identity/tests/test_authn_client.py +++ b/sdk/identity/azure-identity/tests/test_authn_client.py @@ -15,6 +15,7 @@ from azure.core.credentials import AccessToken from azure.identity._authn_client import AuthnClient from azure.identity._constants import EnvironmentVariables +import pytest from six.moves.urllib_parse import urlparse from helpers import mock_response @@ -205,28 +206,30 @@ def mock_send(request, **kwargs): assert not client.get_cached_token([scope_a, scope_b]) -def test_request_url(): - authority = "localhost" - tenant = "expected_tenant" +@pytest.mark.parametrize("authority", ("localhost", "https://localhost")) +def test_request_url(authority): + tenant_id = "expected_tenant" + parsed_authority = urlparse(authority) + expected_netloc = parsed_authority.netloc or authority # "localhost" parses to netloc "", path "localhost" def validate_url(url): - scheme, netloc, path, _, _, _ = urlparse(url) - assert scheme == "https" - assert netloc == authority - assert path.startswith("/" + tenant) + actual = urlparse(url) + assert actual.scheme == "https" + assert actual.netloc == expected_netloc + assert actual.path.startswith("/" + tenant_id) def mock_send(request, **kwargs): validate_url(request.url) return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"}) - client = AuthnClient(tenant=tenant, transport=Mock(send=mock_send), authority=authority) + client = AuthnClient(tenant=tenant_id, transport=Mock(send=mock_send), authority=authority) client.request_token(("scope",)) request = client.get_refresh_token_grant_request({"secret": "***"}, "scope") validate_url(request.url) # authority can be configured via environment variable with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): - client = AuthnClient(tenant=tenant, transport=Mock(send=mock_send)) + client = AuthnClient(tenant=tenant_id, transport=Mock(send=mock_send)) client.request_token(("scope",)) request = client.get_refresh_token_grant_request({"secret": "***"}, "scope") validate_url(request.url) diff --git a/sdk/identity/azure-identity/tests/test_authn_client_async.py b/sdk/identity/azure-identity/tests/test_authn_client_async.py index 7966787d9b52..ab94c2c236c4 100644 --- a/sdk/identity/azure-identity/tests/test_authn_client_async.py +++ b/sdk/identity/azure-identity/tests/test_authn_client_async.py @@ -15,21 +15,23 @@ @pytest.mark.asyncio -async def test_request_url(): - authority = "authority.com" - tenant = "expected_tenant" +@pytest.mark.parametrize("authority", ("localhost", "https://localhost")) +async def test_request_url(authority): + tenant_id = "expected_tenant" + parsed_authority = urlparse(authority) + expected_netloc = parsed_authority.netloc or authority # "localhost" parses to netloc "", path "localhost" def mock_send(request, **kwargs): - scheme, netloc, path, _, _, _ = urlparse(request.url) - assert scheme == "https" - assert netloc == authority - assert path.startswith("/" + tenant) - return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"}) + actual = urlparse(request.url) + assert actual.scheme == "https" + assert actual.netloc == expected_netloc + assert actual.path.startswith("/" + tenant_id) + return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "*"}) - client = AsyncAuthnClient(tenant=tenant, transport=Mock(send=wrap_in_future(mock_send)), authority=authority) + client = AsyncAuthnClient(tenant=tenant_id, transport=Mock(send=wrap_in_future(mock_send)), authority=authority) await client.request_token(("scope",)) # authority can be configured via environment variable with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): - client = AsyncAuthnClient(tenant=tenant, transport=Mock(send=wrap_in_future(mock_send))) + client = AsyncAuthnClient(tenant=tenant_id, transport=Mock(send=wrap_in_future(mock_send))) await client.request_token(("scope",)) diff --git a/sdk/identity/azure-identity/tests/test_certificate_credential.py b/sdk/identity/azure-identity/tests/test_certificate_credential.py index c861ed73d936..06b65ee54c88 100644 --- a/sdk/identity/azure-identity/tests/test_certificate_credential.py +++ b/sdk/identity/azure-identity/tests/test_certificate_credential.py @@ -63,20 +63,21 @@ def test_user_agent(): credential.get_token("scope") +@pytest.mark.parametrize("authority", ("localhost", "https://localhost")) @pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS) -def test_request_url(cert_path, cert_password): - authority = "authority.com" +def test_request_url(cert_path, cert_password, authority): + """the credential should accept an authority, with or without scheme, as an argument or environment variable""" + tenant_id = "expected_tenant" access_token = "***" - - def validate_url(url): - parsed = urlparse(url) - assert parsed.scheme == "https" - assert parsed.netloc == authority - assert parsed.path.startswith("/" + tenant_id) + parsed_authority = urlparse(authority) + expected_netloc = parsed_authority.netloc or authority # "localhost" parses to netloc "", path "localhost" def mock_send(request, **kwargs): - validate_url(request.url) + actual = urlparse(request.url) + assert actual.scheme == "https" + assert actual.netloc == expected_netloc + assert actual.path.startswith("/" + tenant_id) return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": access_token}) cred = CertificateCredential( diff --git a/sdk/identity/azure-identity/tests/test_certificate_credential_async.py b/sdk/identity/azure-identity/tests/test_certificate_credential_async.py index 1efb467a56be..948c411bddbe 100644 --- a/sdk/identity/azure-identity/tests/test_certificate_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_certificate_credential_async.py @@ -77,20 +77,21 @@ async def test_user_agent(): @pytest.mark.asyncio +@pytest.mark.parametrize("authority", ("localhost", "https://localhost")) @pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS) -async def test_request_url(cert_path, cert_password): - authority = "authority.com" +async def test_request_url(cert_path, cert_password, authority): + """the credential should accept an authority, with or without scheme, as an argument or environment variable""" + tenant_id = "expected_tenant" access_token = "***" - - def validate_url(url): - parsed = urlparse(url) - assert parsed.scheme == "https" - assert parsed.netloc == authority - assert parsed.path.startswith("/" + tenant_id) + parsed_authority = urlparse(authority) + expected_netloc = parsed_authority.netloc or authority # "localhost" parses to netloc "", path "localhost" async def mock_send(request, **kwargs): - validate_url(request.url) + actual = urlparse(request.url) + assert actual.scheme == "https" + assert actual.netloc == expected_netloc + assert actual.path.startswith("/" + tenant_id) return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": access_token}) cred = CertificateCredential( diff --git a/sdk/identity/azure-identity/tests/test_client_secret_credential.py b/sdk/identity/azure-identity/tests/test_client_secret_credential.py index e92c8cccfdb8..732ef67f97c3 100644 --- a/sdk/identity/azure-identity/tests/test_client_secret_credential.py +++ b/sdk/identity/azure-identity/tests/test_client_secret_credential.py @@ -80,17 +80,20 @@ def test_client_secret_credential(): assert token.token == access_token -def test_request_url(): - authority = "localhost" +@pytest.mark.parametrize("authority", ("localhost", "https://localhost")) +def test_request_url(authority): + """the credential should accept an authority, with or without scheme, as an argument or environment variable""" + tenant_id = "expected_tenant" access_token = "***" + parsed_authority = urlparse(authority) + expected_netloc = parsed_authority.netloc or authority # "localhost" parses to netloc "", path "localhost" def mock_send(request, **kwargs): - parsed = urlparse(request.url) - assert parsed.scheme == "https" - assert parsed.netloc == authority - assert parsed.path.startswith("/" + tenant_id) - + actual = urlparse(request.url) + assert actual.scheme == "https" + assert actual.netloc == expected_netloc + assert actual.path.startswith("/" + tenant_id) return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": access_token}) credential = ClientSecretCredential( diff --git a/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py b/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py index 2253910fafb2..b37ff09dd5bc 100644 --- a/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_client_secret_credential_async.py @@ -107,17 +107,20 @@ async def test_client_secret_credential(): @pytest.mark.asyncio -async def test_request_url(): - authority = "localhost" +@pytest.mark.parametrize("authority", ("localhost", "https://localhost")) +async def test_request_url(authority): + """the credential should accept an authority, with or without scheme, as an argument or environment variable""" + tenant_id = "expected_tenant" access_token = "***" + parsed_authority = urlparse(authority) + expected_netloc = parsed_authority.netloc or authority # "localhost" parses to netloc "", path "localhost" async def mock_send(request, **kwargs): - parsed = urlparse(request.url) - assert parsed.scheme == "https" - assert parsed.netloc == authority - assert parsed.path.startswith("/" + tenant_id) - + actual = urlparse(request.url) + assert actual.scheme == "https" + assert actual.netloc == expected_netloc + assert actual.path.startswith("/" + tenant_id) return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": access_token}) credential = ClientSecretCredential( diff --git a/sdk/identity/azure-identity/tests/test_default.py b/sdk/identity/azure-identity/tests/test_default.py index 842611afc7e6..5f0c79ae9b1d 100644 --- a/sdk/identity/azure-identity/tests/test_default.py +++ b/sdk/identity/azure-identity/tests/test_default.py @@ -15,6 +15,7 @@ from azure.identity._constants import EnvironmentVariables from azure.identity._credentials.azure_cli import AzureCliCredential from azure.identity._credentials.managed_identity import ManagedIdentityCredential +import pytest from six.moves.urllib_parse import urlparse from helpers import mock_response, Request, validating_transport @@ -45,12 +46,14 @@ def test_iterates_only_once(): assert successful_credential.get_token.call_count == n + 1 -def test__authority(): +@pytest.mark.parametrize("authority", ("localhost", "https://localhost")) +def test_authority(authority): """the credential should accept authority configuration by keyword argument or environment""" - def test_initialization(mock_credential, expect_argument): - authority = "localhost" + parsed_authority = urlparse(authority) + expected_netloc = parsed_authority.netloc or authority # "localhost" parses to netloc "", path "localhost" + def test_initialization(mock_credential, expect_argument): DefaultAzureCredential(authority=authority) assert mock_credential.call_count == 1 @@ -62,7 +65,9 @@ def test_initialization(mock_credential, expect_argument): for _, kwargs in mock_credential.call_args_list: if expect_argument: - assert kwargs["authority"] == authority + actual = urlparse(kwargs["authority"]) + assert actual.scheme == "https" + assert actual.netloc == expected_netloc else: assert "authority" not in kwargs diff --git a/sdk/identity/azure-identity/tests/test_default_async.py b/sdk/identity/azure-identity/tests/test_default_async.py index 0472d351d8fd..463aaa235256 100644 --- a/sdk/identity/azure-identity/tests/test_default_async.py +++ b/sdk/identity/azure-identity/tests/test_default_async.py @@ -40,12 +40,14 @@ async def test_iterates_only_once(): assert successful_credential.get_token.call_count == n + 1 -def test_authority(): +@pytest.mark.parametrize("authority", ("localhost", "https://localhost")) +def test_authority(authority): """the credential should accept authority configuration by keyword argument or environment""" - def test_initialization(mock_credential, expect_argument): - authority = "localhost" + parsed_authority = urlparse(authority) + expected_netloc = parsed_authority.netloc or authority # "localhost" parses to netloc "", path "localhost" + def test_initialization(mock_credential, expect_argument): DefaultAzureCredential(authority=authority) assert mock_credential.call_count == 1 @@ -57,7 +59,9 @@ def test_initialization(mock_credential, expect_argument): for _, kwargs in mock_credential.call_args_list: if expect_argument: - assert kwargs["authority"] == authority + actual = urlparse(kwargs["authority"]) + assert actual.scheme == "https" + assert actual.netloc == expected_netloc else: assert "authority" not in kwargs diff --git a/sdk/identity/azure-identity/tests/test_identity.py b/sdk/identity/azure-identity/tests/test_identity.py index 60216c0e9d23..a2afb279941c 100644 --- a/sdk/identity/azure-identity/tests/test_identity.py +++ b/sdk/identity/azure-identity/tests/test_identity.py @@ -21,12 +21,52 @@ EnvironmentCredential, ) from azure.identity._credentials.managed_identity import ImdsCredential -from azure.identity._constants import EnvironmentVariables +from azure.identity._constants import EnvironmentVariables, KnownAuthorities +from azure.identity._internal import get_default_authority, normalize_authority import pytest from helpers import mock_response, Request, validating_transport +def test_get_default_authority(): + """get_default_authority should return public cloud or the value of $AZURE_AUTHORITY_HOST, with 'https' scheme""" + + # default scheme is https + for authority in ("localhost", "https://localhost"): + with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): + assert get_default_authority() == "https://localhost" + + # default to public cloud + for environ in ({}, {EnvironmentVariables.AZURE_AUTHORITY_HOST: KnownAuthorities.AZURE_PUBLIC_CLOUD}): + with patch.dict("os.environ", environ, clear=True): + assert get_default_authority() == "https://" + KnownAuthorities.AZURE_PUBLIC_CLOUD + + # require https + with pytest.raises(ValueError): + with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: "http://localhost"}, clear=True): + get_default_authority() + + +def test_normalize_authority(): + """normalize_authority should return a URI with a scheme and no trailing spaces or forward slashes""" + + localhost = "localhost" + localhost_tls = "https://" + localhost + + # accept https if specified, default to it when no scheme specified + for uri in (localhost, localhost_tls): + assert normalize_authority(uri) == localhost_tls + + # remove trailing characters + for string in ("/", " ", "/ ", " /"): + assert normalize_authority(uri + string) == localhost_tls + + # raise for other schemes + for scheme in ("http", "file"): + with pytest.raises(ValueError): + normalize_authority(scheme + "://localhost") + + def test_client_secret_environment_credential(): client_id = "fake-client-id" secret = "fake-client-secret" diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py index e6f7572be250..daf5a73dd503 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py @@ -17,6 +17,7 @@ from azure.identity._internal.user_agent import USER_AGENT from msal import TokenCache import pytest +from six.moves.urllib_parse import urlparse try: from unittest.mock import Mock, patch @@ -64,6 +65,26 @@ def test_user_agent(): credential.get_token("scope") +@pytest.mark.parametrize("authority", ("localhost", "https://localhost")) +def test_authority(authority): + """the credential should accept an authority, with or without scheme, as an argument or environment variable""" + + parsed_authority = urlparse(authority) + expected_netloc = parsed_authority.netloc or authority # "localhost" parses to netloc "", path "localhost" + + class MockCredential(SharedTokenCacheCredential): + def _get_auth_client(self, authority=None, **kwargs): + actual = urlparse(authority) + assert actual.scheme == "https" + assert actual.netloc == expected_netloc + + transport = Mock(send=Mock(side_effect=Exception("credential shouldn't send a request"))) + MockCredential(_cache=TokenCache(), authority=authority, transport=transport) + + with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): + MockCredential(_cache=TokenCache(), authority=authority, transport=transport) + + def test_empty_cache(): """the credential should raise CredentialUnavailableError when the cache is empty""" @@ -484,6 +505,11 @@ def test_authority_environment_variable(): def get_account_event( username, uid, utid, authority=None, client_id="client-id", refresh_token="refresh-token", scopes=None ): + if authority: + endpoint = "https://" + "/".join((authority, utid, "path",)) + else: + endpoint = get_default_authority() + "/{}/{}".format(utid, "path") + return { "response": build_aad_response( uid=uid, @@ -493,7 +519,7 @@ def get_account_event( foci="1", ), "client_id": client_id, - "token_endpoint": "https://" + "/".join((authority or get_default_authority(), utid, "/path",)), + "token_endpoint": endpoint, "scope": scopes or ["scope"], } diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py index 9ee971517c98..c718eaf3481c 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ from unittest.mock import Mock, patch +from urllib.parse import urlparse from azure.core.exceptions import ClientAuthenticationError from azure.core.pipeline.policies import SansIOHTTPPolicy @@ -108,6 +109,26 @@ async def test_user_agent(): await credential.get_token("scope") +@pytest.mark.parametrize("authority", ("localhost", "https://localhost")) +def test_authority(authority): + """the credential should accept an authority, with or without scheme, as an argument or environment variable""" + + parsed_authority = urlparse(authority) + expected_netloc = parsed_authority.netloc or authority # "localhost" parses to netloc "", path "localhost" + + class MockCredential(SharedTokenCacheCredential): + def _get_auth_client(self, authority=None, **kwargs): + actual = urlparse(authority) + assert actual.scheme == "https" + assert actual.netloc == expected_netloc + + transport = Mock(send=Mock(side_effect=Exception("credential shouldn't send a request"))) + MockCredential(_cache=TokenCache(), authority=authority, transport=transport) + + with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True): + MockCredential(_cache=TokenCache(), authority=authority, transport=transport) + + @pytest.mark.asyncio async def test_empty_cache(): """the credential should raise CredentialUnavailableError when the cache is empty"""