Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accept authority option with or without scheme #11050

Merged
merged 3 commits into from
Apr 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions sdk/identity/azure-identity/azure/identity/_authn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)
from azure.core.pipeline.transport import RequestsTransport, HttpRequest
from ._constants import AZURE_CLI_CLIENT_ID
from ._internal import get_default_authority
from ._internal import get_default_authority, normalize_authority
from ._internal.user_agent import USER_AGENT

try:
Expand Down Expand Up @@ -62,8 +62,8 @@ def __init__(self, endpoint=None, authority=None, tenant=None, **kwargs): # pyl
else:
if not tenant:
raise ValueError("'tenant' is required")
authority = authority or get_default_authority()
self._auth_url = "https://" + "/".join((authority.strip("/"), tenant.strip("/"), "oauth2/v2.0/token"))
authority = normalize_authority(authority) if authority else get_default_authority()
self._auth_url = "/".join((authority, tenant.strip("/"), "oauth2/v2.0/token"))
self._cache = kwargs.get("cache") or TokenCache() # type: TokenCache

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os

from .._constants import EnvironmentVariables
from .._internal import get_default_authority
from .._internal import get_default_authority, normalize_authority
from .browser import InteractiveBrowserCredential
from .chained import ChainedTokenCredential
from .environment import EnvironmentCredential
Expand Down Expand Up @@ -62,7 +62,8 @@ class DefaultAzureCredential(ChainedTokenCredential):
"""

def __init__(self, **kwargs):
authority = kwargs.pop("authority", None) or get_default_authority()
authority = kwargs.pop("authority", None)
authority = normalize_authority(authority) if authority else get_default_authority()

shared_cache_username = kwargs.pop("shared_cache_username", os.environ.get(EnvironmentVariables.AZURE_USERNAME))
shared_cache_tenant_id = kwargs.pop(
Expand Down
20 changes: 19 additions & 1 deletion sdk/identity/azure-identity/azure/identity/_internal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,30 @@
# Licensed under the MIT License.
# ------------------------------------
import os
from six.moves.urllib_parse import urlparse

from .._constants import EnvironmentVariables, KnownAuthorities


def normalize_authority(authority):
# type: (str) -> str
"""Ensure authority uses https, strip trailing spaces and /"""

parsed = urlparse(authority)
if not parsed.scheme:
return "https://" + authority.rstrip(" /")
if parsed.scheme != "https":
raise ValueError(
"'{}' is an invalid authority. The value must be a TLS protected (https) URL.".format(authority)
)

return authority.rstrip(" /")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if authority is an invalid url?

Copy link
Member Author

@chlowell chlowell Apr 27, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Runtime exception depending on what authority is, probably not in this method because urlparse is tolerant. For example:

>>> c = ClientSecretCredential('_','_','_',authority='http:/foo')
>>> c.get_token('scope')
Traceback (most recent call last):
  ...
azure.core.exceptions.ServiceRequestError: Invalid URL 'http:/foo/_/oauth2/v2.0/token': No host supplied

>>> c = ClientSecretCredential('_','_','_',authority='htp:/foo')
>>> c.get_token('scope')
  ...
azure.core.exceptions.ServiceRequestError: No connection adapters were found for 'htp:/foo/_/oauth2/v2.0/to
ken'

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean should we check if the url is valid, and if not, still call get_default?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should use a value provided by the user as best we can. Using a different one because we think we know better creates opportunities for confusing failure and undesired success.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with you. Changing host w/o user's awareness is bad. But as you see, if it is an invalid url, our code raise ServiceRequestError and user needs to check the error message to know what happened. How about raise a more meaningful type of error?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In another word, I would prefer, if the error happens in url, we want it fail on the line of
authority = normalize_authority(authority) if authority else get_default_authority()

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's reasonable to expect users to read error messages, and focus our effort on making those useful. We can't communicate everything a user needs in the name of the exception. So, ServiceRequestError: Invalid URL ... seems on the mark to me. What more meaningful type do you have in mind?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not requests.exceptions.InvalidURL?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. I don't want to take a dependency on requests or raise a foreign exception.



def get_default_authority():
return os.environ.get(EnvironmentVariables.AZURE_AUTHORITY_HOST, KnownAuthorities.AZURE_PUBLIC_CLOUD)
# type: () -> str
authority = os.environ.get(EnvironmentVariables.AZURE_AUTHORITY_HOST, KnownAuthorities.AZURE_PUBLIC_CLOUD)
return normalize_authority(authority)


# pylint:disable=wrong-import-position
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from . import get_default_authority
from . import get_default_authority, normalize_authority

try:
ABC = abc.ABC
Expand All @@ -34,10 +34,10 @@ class AadClientBase(ABC):

def __init__(self, tenant_id, client_id, cache=None, **kwargs):
# type: (str, str, Optional[TokenCache], **Any) -> None
authority = kwargs.pop("authority", None) or get_default_authority()
if authority[-1] == "/":
authority = authority[:-1]
token_endpoint = "https://" + "/".join((authority, tenant_id, "oauth2/v2.0/token"))
authority = kwargs.pop("authority", None)
authority = normalize_authority(authority) if authority else get_default_authority()

token_endpoint = "/".join((authority, tenant_id, "oauth2/v2.0/token"))
config = {"token_endpoint": token_endpoint}

self._cache = cache or TokenCache()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from .exception_wrapper import wrap_exceptions
from .msal_transport_adapter import MsalTransportAdapter
from .._internal import get_default_authority
from .._internal import get_default_authority, normalize_authority

try:
ABC = abc.ABC
Expand All @@ -37,8 +37,10 @@ class MsalCredential(ABC):
def __init__(self, client_id, client_credential=None, **kwargs):
# type: (str, Optional[Union[str, Mapping[str, str]]], **Any) -> None
tenant_id = kwargs.pop("tenant_id", "organizations")
authority = kwargs.pop("authority", None) or get_default_authority()
self._base_url = "https://" + "/".join((authority.strip("/"), tenant_id.strip("/")))
authority = kwargs.pop("authority", None)
authority = normalize_authority(authority) if authority else get_default_authority()

self._base_url = "/".join((authority, tenant_id.strip("/")))
self._client_credential = client_credential
self._client_id = client_id

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
import sys

from msal import TokenCache
from six.moves.urllib_parse import urlparse

from .. import CredentialUnavailableError
from .._constants import KnownAuthorities
from .._internal import get_default_authority
from .._internal import get_default_authority, normalize_authority

try:
ABC = abc.ABC
Expand Down Expand Up @@ -87,8 +88,11 @@ class SharedTokenCacheBase(ABC):
def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument
# type: (Optional[str], **Any) -> None

self._authority = kwargs.pop("authority", None) or get_default_authority()
self._authority_aliases = KNOWN_ALIASES.get(self._authority) or frozenset((self._authority,))
authority = kwargs.pop("authority", None)
self._authority = normalize_authority(authority) if authority else get_default_authority()

environment = urlparse(self._authority).netloc
self._environment_aliases = KNOWN_ALIASES.get(environment) or frozenset((environment,))
self._username = username
self._tenant_id = kwargs.pop("tenant_id", None)

Expand Down Expand Up @@ -125,7 +129,7 @@ def _get_cache_items_for_authority(self, credential_type):
items = []
for item in self._cache.find(credential_type):
environment = item.get("environment")
if environment in self._authority_aliases:
if environment in self._environment_aliases:
items.append(item)
return items

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import TYPE_CHECKING

from ..._constants import EnvironmentVariables
from ..._internal import get_default_authority
from ..._internal import get_default_authority, normalize_authority
from .azure_cli import AzureCliCredential
from .chained import ChainedTokenCredential
from .environment import EnvironmentCredential
Expand Down Expand Up @@ -53,7 +53,8 @@ class DefaultAzureCredential(ChainedTokenCredential):
"""

def __init__(self, **kwargs):
authority = kwargs.pop("authority", None) or get_default_authority()
authority = kwargs.pop("authority", None)
authority = normalize_authority(authority) if authority else get_default_authority()

shared_cache_username = kwargs.pop("shared_cache_username", os.environ.get(EnvironmentVariables.AZURE_USERNAME))
shared_cache_tenant_id = kwargs.pop(
Expand Down
14 changes: 8 additions & 6 deletions sdk/identity/azure-identity/tests/test_aad_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,17 @@ def assert_secrets_not_exposed():
assert_secrets_not_exposed()


def test_request_url():
authority = "authority.com"
@pytest.mark.parametrize("authority", ("localhost", "https://localhost"))
def test_request_url(authority):
tenant_id = "expected_tenant"
parsed_authority = urlparse(authority)
expected_netloc = parsed_authority.netloc or authority # "localhost" parses to netloc "", path "localhost"

def send(request, **_):
scheme, netloc, path, _, _, _ = urlparse(request.url)
assert scheme == "https"
assert netloc == authority
assert path.startswith("/" + tenant_id)
actual = urlparse(request.url)
assert actual.scheme == "https"
assert actual.netloc == expected_netloc
assert actual.path.startswith("/" + tenant_id)
return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"})

client = AadClient(tenant_id, "client id", transport=Mock(send=send), authority=authority)
Expand Down
20 changes: 11 additions & 9 deletions sdk/identity/azure-identity/tests/test_aad_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,26 @@ async def test_uses_msal_correctly():


@pytest.mark.asyncio
async def test_request_url():
authority = "authority.com"
tenant = "expected_tenant"
@pytest.mark.parametrize("authority", ("localhost", "https://localhost"))
async def test_request_url(authority):
tenant_id = "expected_tenant"
parsed_authority = urlparse(authority)
expected_netloc = parsed_authority.netloc or authority # "localhost" parses to netloc "", path "localhost"

async def send(request, **_):
scheme, netloc, path, _, _, _ = urlparse(request.url)
assert scheme == "https"
assert netloc == authority
assert path.startswith("/" + tenant)
actual = urlparse(request.url)
assert actual.scheme == "https"
assert actual.netloc == expected_netloc
assert actual.path.startswith("/" + tenant_id)
return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"})

client = AadClient(tenant, "client id", transport=Mock(send=send), authority=authority)
client = AadClient(tenant_id, "client id", transport=Mock(send=send), authority=authority)

await client.obtain_token_by_authorization_code("code", "uri", "scope")
await client.obtain_token_by_refresh_token("refresh token", "scope")

# authority can be configured via environment variable
with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True):
client = AadClient(tenant_id=tenant, client_id="client id", transport=Mock(send=send))
client = AadClient(tenant_id=tenant_id, client_id="client id", transport=Mock(send=send))
await client.obtain_token_by_authorization_code("code", "uri", "scope")
await client.obtain_token_by_refresh_token("refresh token", "scope")
21 changes: 12 additions & 9 deletions sdk/identity/azure-identity/tests/test_authn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from azure.core.credentials import AccessToken
from azure.identity._authn_client import AuthnClient
from azure.identity._constants import EnvironmentVariables
import pytest
from six.moves.urllib_parse import urlparse
from helpers import mock_response

Expand Down Expand Up @@ -205,28 +206,30 @@ def mock_send(request, **kwargs):
assert not client.get_cached_token([scope_a, scope_b])


def test_request_url():
authority = "localhost"
tenant = "expected_tenant"
@pytest.mark.parametrize("authority", ("localhost", "https://localhost"))
def test_request_url(authority):
tenant_id = "expected_tenant"
parsed_authority = urlparse(authority)
expected_netloc = parsed_authority.netloc or authority # "localhost" parses to netloc "", path "localhost"

def validate_url(url):
scheme, netloc, path, _, _, _ = urlparse(url)
assert scheme == "https"
assert netloc == authority
assert path.startswith("/" + tenant)
actual = urlparse(url)
assert actual.scheme == "https"
assert actual.netloc == expected_netloc
assert actual.path.startswith("/" + tenant_id)

def mock_send(request, **kwargs):
validate_url(request.url)
return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"})

client = AuthnClient(tenant=tenant, transport=Mock(send=mock_send), authority=authority)
client = AuthnClient(tenant=tenant_id, transport=Mock(send=mock_send), authority=authority)
client.request_token(("scope",))
request = client.get_refresh_token_grant_request({"secret": "***"}, "scope")
validate_url(request.url)

# authority can be configured via environment variable
with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True):
client = AuthnClient(tenant=tenant, transport=Mock(send=mock_send))
client = AuthnClient(tenant=tenant_id, transport=Mock(send=mock_send))
client.request_token(("scope",))
request = client.get_refresh_token_grant_request({"secret": "***"}, "scope")
validate_url(request.url)
22 changes: 12 additions & 10 deletions sdk/identity/azure-identity/tests/test_authn_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,23 @@


@pytest.mark.asyncio
async def test_request_url():
authority = "authority.com"
tenant = "expected_tenant"
@pytest.mark.parametrize("authority", ("localhost", "https://localhost"))
async def test_request_url(authority):
tenant_id = "expected_tenant"
parsed_authority = urlparse(authority)
expected_netloc = parsed_authority.netloc or authority # "localhost" parses to netloc "", path "localhost"

def mock_send(request, **kwargs):
scheme, netloc, path, _, _, _ = urlparse(request.url)
assert scheme == "https"
assert netloc == authority
assert path.startswith("/" + tenant)
return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "***"})
actual = urlparse(request.url)
assert actual.scheme == "https"
assert actual.netloc == expected_netloc
assert actual.path.startswith("/" + tenant_id)
return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": "*"})

client = AsyncAuthnClient(tenant=tenant, transport=Mock(send=wrap_in_future(mock_send)), authority=authority)
client = AsyncAuthnClient(tenant=tenant_id, transport=Mock(send=wrap_in_future(mock_send)), authority=authority)
await client.request_token(("scope",))

# authority can be configured via environment variable
with patch.dict("os.environ", {EnvironmentVariables.AZURE_AUTHORITY_HOST: authority}, clear=True):
client = AsyncAuthnClient(tenant=tenant, transport=Mock(send=wrap_in_future(mock_send)))
client = AsyncAuthnClient(tenant=tenant_id, transport=Mock(send=wrap_in_future(mock_send)))
await client.request_token(("scope",))
19 changes: 10 additions & 9 deletions sdk/identity/azure-identity/tests/test_certificate_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,21 @@ def test_user_agent():
credential.get_token("scope")


@pytest.mark.parametrize("authority", ("localhost", "https://localhost"))
@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS)
def test_request_url(cert_path, cert_password):
authority = "authority.com"
def test_request_url(cert_path, cert_password, authority):
"""the credential should accept an authority, with or without scheme, as an argument or environment variable"""

tenant_id = "expected_tenant"
access_token = "***"

def validate_url(url):
parsed = urlparse(url)
assert parsed.scheme == "https"
assert parsed.netloc == authority
assert parsed.path.startswith("/" + tenant_id)
parsed_authority = urlparse(authority)
expected_netloc = parsed_authority.netloc or authority # "localhost" parses to netloc "", path "localhost"

def mock_send(request, **kwargs):
validate_url(request.url)
actual = urlparse(request.url)
assert actual.scheme == "https"
assert actual.netloc == expected_netloc
assert actual.path.startswith("/" + tenant_id)
return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": access_token})

cred = CertificateCredential(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,21 @@ async def test_user_agent():


@pytest.mark.asyncio
@pytest.mark.parametrize("authority", ("localhost", "https://localhost"))
@pytest.mark.parametrize("cert_path,cert_password", BOTH_CERTS)
async def test_request_url(cert_path, cert_password):
authority = "authority.com"
async def test_request_url(cert_path, cert_password, authority):
"""the credential should accept an authority, with or without scheme, as an argument or environment variable"""

tenant_id = "expected_tenant"
access_token = "***"

def validate_url(url):
parsed = urlparse(url)
assert parsed.scheme == "https"
assert parsed.netloc == authority
assert parsed.path.startswith("/" + tenant_id)
parsed_authority = urlparse(authority)
expected_netloc = parsed_authority.netloc or authority # "localhost" parses to netloc "", path "localhost"

async def mock_send(request, **kwargs):
validate_url(request.url)
actual = urlparse(request.url)
assert actual.scheme == "https"
assert actual.netloc == expected_netloc
assert actual.path.startswith("/" + tenant_id)
return mock_response(json_payload={"token_type": "Bearer", "expires_in": 42, "access_token": access_token})

cred = CertificateCredential(
Expand Down
Loading