diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index f443bf34ef8..701658bf4e4 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -14,7 +14,7 @@ /src/azure-cli-core/ @jiasli @evelyn-ys @jsntcy @kairu-ms @zhoxing-ms /src/azure-cli-core/azure/cli/core/_profile.py @jiasli @evelyn-ys -/src/azure-cli-core/azure/cli/core/adal_authentication.py @jiasli @evelyn-ys +/src/azure-cli-core/azure/cli/core/auth/ @jiasli /src/azure-cli-core/azure/cli/core/extension/ @jsntcy @kairu-ms /src/azure-cli-core/azure/cli/core/msal_authentication.py @jiasli @evelyn-ys /src/azure-cli-core/azure/cli/core/style.py @jiasli @evelyn-ys @zhoxing-ms diff --git a/scripts/ci/credscan/CredScanSuppressions.json b/scripts/ci/credscan/CredScanSuppressions.json index d21f94a5d22..6dc082be08e 100644 --- a/scripts/ci/credscan/CredScanSuppressions.json +++ b/scripts/ci/credscan/CredScanSuppressions.json @@ -412,9 +412,13 @@ "_justification": "[AppService] Test certs" }, { - "file": "src\\azure-cli-core\\azure\\cli\\core\\tests\\sp_cert.pem", + "file": "src\\azure-cli-core\\azure\\cli\\core\\auth\\tests\\sp_cert.pem", "_justification": "[Core] Test certs" }, + { + "placeholder": "test_secret", + "_justification": "[Core] Test secret" + }, { "placeholder": "0abf356884d74b4aacbd7b1ebd3da0f7", "_justification": "[AMS] hard code accessToken in test_ams_live_event_scenarios.py" diff --git a/src/azure-cli-core/azure/cli/core/_debug.py b/src/azure-cli-core/azure/cli/core/_debug.py index e66b5b1c386..b8028791ce9 100644 --- a/src/azure-cli-core/azure/cli/core/_debug.py +++ b/src/azure-cli-core/azure/cli/core/_debug.py @@ -45,8 +45,3 @@ def change_ssl_cert_verification_track2(): logger.debug("Using CA bundle file at '%s'.", ca_bundle_file) client_kwargs['connection_verify'] = ca_bundle_file return client_kwargs - - -def allow_debug_adal_connection(): - if should_disable_connection_verify(): - os.environ[ADAL_PYTHON_SSL_NO_VERIFY] = '1' diff --git a/src/azure-cli-core/azure/cli/core/_profile.py b/src/azure-cli-core/azure/cli/core/_profile.py index 9d34820dc38..fd948429a67 100644 --- a/src/azure-cli-core/azure/cli/core/_profile.py +++ b/src/azure-cli-core/azure/cli/core/_profile.py @@ -3,24 +3,20 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- -import collections -import errno -import json import os import os.path -import re -import string +import sys from copy import deepcopy from enum import Enum -from knack.log import get_logger -from knack.util import CLIError - -from azure.cli.core._environment import get_config_dir from azure.cli.core._session import ACCOUNT -from azure.cli.core.util import get_file_json, in_cloud_console, open_page_in_browser, can_launch_browser,\ - is_windows, is_wsl, scopes_to_resource, resource_to_scopes +from azure.cli.core.auth.identity import Identity, AZURE_CLI_CLIENT_ID +from azure.cli.core.auth.util import resource_to_scopes +from azure.cli.core.azclierror import AuthenticationError from azure.cli.core.cloud import get_active_cloud, set_cloud_subscription +from azure.cli.core.util import in_cloud_console, can_launch_browser +from knack.log import get_logger +from knack.util import CLIError logger = get_logger(__name__) @@ -39,6 +35,7 @@ _MANAGED_BY_TENANTS = 'managedByTenants' _USER_ENTITY = 'user' _USER_NAME = 'name' +_CLIENT_ID = 'clientId' _CLOUD_SHELL_ID = 'cloudShellID' _SUBSCRIPTIONS = 'subscriptions' _INSTALLATION_ID = 'installationId' @@ -47,25 +44,9 @@ _USER_TYPE = 'type' _USER = 'user' _SERVICE_PRINCIPAL = 'servicePrincipal' -_SERVICE_PRINCIPAL_ID = 'servicePrincipalId' -_SERVICE_PRINCIPAL_TENANT = 'servicePrincipalTenant' -_SERVICE_PRINCIPAL_CERT_FILE = 'certificateFile' -_SERVICE_PRINCIPAL_CERT_THUMBPRINT = 'thumbprint' _SERVICE_PRINCIPAL_CERT_SN_ISSUER_AUTH = 'useCertSNIssuerAuth' _TOKEN_ENTRY_USER_ID = 'userId' _TOKEN_ENTRY_TOKEN_TYPE = 'tokenType' -# This could mean either real access token, or client secret of a service principal -# This naming is no good, but can't change because xplat-cli does so. -_ACCESS_TOKEN = 'accessToken' -_REFRESH_TOKEN = 'refreshToken' - -TOKEN_FIELDS_EXCLUDED_FROM_PERSISTENCE = ['familyName', - 'givenName', - 'isUserIdDisplayable', - 'tenantId'] - -_CLIENT_ID = '04b07795-8ddb-461a-bbee-02f9e1bf7b46' -_COMMON_TENANT = 'common' _TENANT_LEVEL_ACCOUNT_NAME = 'N/A(tenant level account)' @@ -75,8 +56,6 @@ _AZ_LOGIN_MESSAGE = "Please run 'az login' to setup account." -_USE_VENDORED_SUBSCRIPTION_SDK = False - def load_subscriptions(cli_ctx, all_clouds=False, refresh=False): profile = Profile(cli_ctx=cli_ctx) @@ -86,47 +65,23 @@ def load_subscriptions(cli_ctx, all_clouds=False, refresh=False): return subscriptions -def _get_authority_url(cli_ctx, tenant): - authority_url = cli_ctx.cloud.endpoints.active_directory - is_adfs = bool(re.match('.+(/adfs|/adfs/)$', authority_url, re.I)) - if is_adfs: - authority_url = authority_url.rstrip('/') # workaround: ADAL is known to reject auth urls with trailing / - else: - authority_url = authority_url.rstrip('/') + '/' + (tenant or _COMMON_TENANT) - return authority_url, is_adfs - - -def _authentication_context_factory(cli_ctx, tenant, cache): - import adal - authority_url, is_adfs = _get_authority_url(cli_ctx, tenant) - return adal.AuthenticationContext(authority_url, cache=cache, api_version=None, validate_authority=(not is_adfs)) - - -_AUTH_CTX_FACTORY = _authentication_context_factory - - -def _load_tokens_from_file(file_path): - if os.path.isfile(file_path): - try: - return get_file_json(file_path, throw_on_empty=False) or [] - except (CLIError, ValueError) as ex: - raise CLIError("Failed to load token files. If you have a repro, please log an issue at " - "https://github.com/Azure/azure-cli/issues. At the same time, you can clean " - "up by running 'az account clear' and then 'az login'. (Inner Error: {})".format(ex)) - logger.debug("'%s' is not a file or doesn't exist.", file_path) - return [] +def _detect_adfs_authority(authority_url, tenant): + """Prepare authority and tenant for Azure Identity with ADFS support. + If `authority_url` ends with '/adfs', `tenant` will be set to 'adfs'. For example: + 'https://adfs.redmond.azurestack.corp.microsoft.com/adfs' + -> ('https://adfs.redmond.azurestack.corp.microsoft.com/', 'adfs') + """ + authority_url = authority_url.rstrip('/') + if authority_url.endswith('/adfs'): + authority_url = authority_url[:-len('/adfs')] + # The custom tenant is discarded in ADFS environment + tenant = 'adfs' -def _delete_file(file_path): - try: - os.remove(file_path) - except OSError as e: - if e.errno != errno.ENOENT: - raise + return authority_url, tenant def get_credential_types(cli_ctx): - class CredentialType(Enum): # pylint: disable=too-few-public-methods cloud = get_active_cloud(cli_ctx) management = cli_ctx.cloud.endpoints.management @@ -139,185 +94,118 @@ def _get_cloud_console_token_endpoint(): return os.environ.get('MSI_ENDPOINT') -# pylint: disable=too-many-lines,too-many-instance-attributes +def _attach_token_tenant(subscription, tenant): + """Attach the token tenant ID to the subscription as tenant_id, so that CLI knows which token should be used + to access the subscription. + + This function supports multiple APIs: + - v2016_06_01's Subscription doesn't have tenant_id + - v2019_11_01's Subscription has tenant_id representing the home tenant ID. It will mapped to home_tenant_id + """ + if hasattr(subscription, "tenant_id"): + setattr(subscription, 'home_tenant_id', subscription.tenant_id) + setattr(subscription, 'tenant_id', tenant) + + +# pylint: disable=too-many-lines,too-many-instance-attributes,unused-argument class Profile: - _global_creds_cache = None + def __init__(self, cli_ctx=None, storage=None): + """Class to manage CLI's accounts (profiles) and identities (credentials). - def __init__(self, storage=None, auth_ctx_factory=None, use_global_creds_cache=True, - async_persist=True, cli_ctx=None): + :param cli_ctx: The CLI context + :param storage: A dict to store accounts, by default persisted to ~/.azure/azureProfile.json as JSON + """ from azure.cli.core import get_default_cli self.cli_ctx = cli_ctx or get_default_cli() self._storage = storage or ACCOUNT - self.auth_ctx_factory = auth_ctx_factory or _AUTH_CTX_FACTORY - - if use_global_creds_cache: - # for perf, use global cache - if not Profile._global_creds_cache: - Profile._global_creds_cache = CredsCache(self.cli_ctx, self.auth_ctx_factory, - async_persist=async_persist) - self._creds_cache = Profile._global_creds_cache - else: - self._creds_cache = CredsCache(self.cli_ctx, self.auth_ctx_factory, async_persist=async_persist) - - self._management_resource_uri = self.cli_ctx.cloud.endpoints.management - self._ad_resource_uri = self.cli_ctx.cloud.endpoints.active_directory_resource_id - self._ad = self.cli_ctx.cloud.endpoints.active_directory - self._msi_creds = None - - def find_subscriptions_on_login(self, - interactive, - username, - password, - is_service_principal, - tenant, - scopes=None, - use_device_code=False, - allow_no_subscriptions=False, - subscription_finder=None, - use_cert_sn_issuer=None): - from azure.cli.core._debug import allow_debug_adal_connection - allow_debug_adal_connection() - subscriptions = [] - - if scopes: - auth_resource = scopes_to_resource(scopes) - else: - auth_resource = self._ad_resource_uri + self._authority = self.cli_ctx.cloud.endpoints.active_directory + self._arm_scope = resource_to_scopes(self.cli_ctx.cloud.endpoints.active_directory_resource_id) + + # Only enable token cache encryption for Windows (for now) + token_encryption_fallback = sys.platform.startswith('win32') + Identity.token_encryption = self.cli_ctx.config.getboolean('core', 'token_encryption', + fallback=token_encryption_fallback) + + # pylint: disable=too-many-branches,too-many-statements,too-many-locals + def login(self, + interactive, + username, + password, + is_service_principal, + tenant, + scopes=None, + client_id=AZURE_CLI_CLIENT_ID, + use_device_code=False, + allow_no_subscriptions=False, + use_cert_sn_issuer=None, + **kwargs): + """ + For service principal, `password` is a dict returned by ServicePrincipalAuth.build_credential + """ + if not scopes: + scopes = self._arm_scope - if not subscription_finder: - subscription_finder = SubscriptionFinder(self.cli_ctx, - self.auth_ctx_factory, - self._creds_cache.adal_token_cache) + # For ADFS, auth_tenant is 'adfs' + # https://github.com/Azure/azure-sdk-for-python/blob/661cd524e88f480c14220ed1f86de06aaff9a977/sdk/identity/azure-identity/CHANGELOG.md#L19 + authority, auth_tenant = _detect_adfs_authority(self.cli_ctx.cloud.endpoints.active_directory, tenant) + identity = Identity(authority=authority, tenant_id=auth_tenant, client_id=client_id) + + user_identity = None if interactive: - if not use_device_code and (in_cloud_console() or not can_launch_browser()): - logger.info('Detect no GUI is available, so fall back to device code') + if not use_device_code and not can_launch_browser(): + logger.info('No web browser is available. Fall back to device code.') use_device_code = True - if not use_device_code: - try: - authority_url, _ = _get_authority_url(self.cli_ctx, tenant) - subscriptions = subscription_finder.find_through_authorization_code_flow( - tenant, self._ad_resource_uri, authority_url, auth_resource=auth_resource) - except RuntimeError: - use_device_code = True - logger.warning('Not able to launch a browser to log you in, falling back to device code...') - if use_device_code: - subscriptions = subscription_finder.find_through_interactive_flow( - tenant, self._ad_resource_uri, auth_resource=auth_resource) + user_identity = identity.login_with_device_code(scopes=scopes, **kwargs) + else: + user_identity = identity.login_with_auth_code(scopes=scopes, **kwargs) else: - if is_service_principal: - if not tenant: - raise CLIError('Please supply tenant using "--tenant"') - sp_auth = ServicePrincipalAuth(password, use_cert_sn_issuer) - subscriptions = subscription_finder.find_from_service_principal_id( - username, sp_auth, tenant, self._ad_resource_uri) - + if not is_service_principal: + user_identity = identity.login_with_username_password(username, password, scopes=scopes, **kwargs) else: - subscriptions = subscription_finder.find_from_user_account( - username, password, tenant, self._ad_resource_uri) + identity.login_with_service_principal(username, password, scopes=scopes) - if not allow_no_subscriptions and not subscriptions: - if username: - msg = "No subscriptions found for {}.".format(username) - else: - # Don't show username if bare 'az login' is used - msg = "No subscriptions found." - raise CLIError(msg) + # We have finished login. Let's find all subscriptions. + if user_identity: + username = user_identity['username'] + + subscription_finder = SubscriptionFinder(self.cli_ctx) + + # Create credentials + if user_identity: + credential = identity.get_user_credential(username) + else: + credential = identity.get_service_principal_credential(username) + + if tenant: + subscriptions = subscription_finder.find_using_specific_tenant(tenant, credential) + else: + subscriptions = subscription_finder.find_using_common_tenant(username, credential) - if is_service_principal: - self._creds_cache.save_service_principal_cred(sp_auth.get_entry_to_persist(username, - tenant)) - if self._creds_cache.adal_token_cache.has_state_changed: - self._creds_cache.persist_cached_creds() + if not subscriptions and not allow_no_subscriptions: + raise CLIError("No subscriptions found for {}.".format(username)) if allow_no_subscriptions: t_list = [s.tenant_id for s in subscriptions] bare_tenants = [t for t in subscription_finder.tenants if t not in t_list] - profile = Profile(cli_ctx=self.cli_ctx) - tenant_accounts = profile._build_tenant_level_accounts(bare_tenants) # pylint: disable=protected-access + tenant_accounts = self._build_tenant_level_accounts(bare_tenants) subscriptions.extend(tenant_accounts) if not subscriptions: return [] - consolidated = self._normalize_properties(subscription_finder.user_id, subscriptions, + consolidated = self._normalize_properties(username, subscriptions, is_service_principal, bool(use_cert_sn_issuer)) self._set_subscriptions(consolidated) - # use deepcopy as we don't want to persist these changes to file. return deepcopy(consolidated) - def _normalize_properties(self, user, subscriptions, is_service_principal, cert_sn_issuer_auth=None, - user_assigned_identity_id=None): - import sys - consolidated = [] - for s in subscriptions: - display_name = s.display_name - if display_name is None: - display_name = '' - try: - display_name.encode(sys.getdefaultencoding()) - except (UnicodeEncodeError, UnicodeDecodeError): # mainly for Python 2.7 with ascii as the default encoding - display_name = re.sub(r'[^\x00-\x7f]', lambda x: '?', display_name) - - subscription_dict = { - _SUBSCRIPTION_ID: s.id.rpartition('/')[2], - _SUBSCRIPTION_NAME: display_name, - _STATE: s.state, - _USER_ENTITY: { - _USER_NAME: user, - _USER_TYPE: _SERVICE_PRINCIPAL if is_service_principal else _USER - }, - _IS_DEFAULT_SUBSCRIPTION: False, - _TENANT_ID: s.tenant_id, - _ENVIRONMENT_NAME: self.cli_ctx.cloud.name - } - - if subscription_dict[_SUBSCRIPTION_NAME] != _TENANT_LEVEL_ACCOUNT_NAME: - _transform_subscription_for_multiapi(s, subscription_dict) - - consolidated.append(subscription_dict) - - if cert_sn_issuer_auth: - consolidated[-1][_USER_ENTITY][_SERVICE_PRINCIPAL_CERT_SN_ISSUER_AUTH] = True - if user_assigned_identity_id: - consolidated[-1][_USER_ENTITY][_ASSIGNED_IDENTITY_INFO] = user_assigned_identity_id - return consolidated - - def _build_tenant_level_accounts(self, tenants): - result = [] - for t in tenants: - s = self._new_account() - s.id = '/subscriptions/' + t - s.subscription = t - s.tenant_id = t - s.display_name = _TENANT_LEVEL_ACCOUNT_NAME - result.append(s) - return result - - def _new_account(self): - """Build an empty Subscription which will be used as a tenant account. - API version doesn't matter as only specified attributes are preserved by _normalize_properties.""" - if _USE_VENDORED_SUBSCRIPTION_SDK: - # pylint: disable=no-name-in-module, import-error - from azure.cli.core.vendored_sdks.subscriptions.models import Subscription - SubscriptionType = Subscription - else: - from azure.cli.core.profiles import ResourceType, get_sdk - SubscriptionType = get_sdk(self.cli_ctx, ResourceType.MGMT_RESOURCE_SUBSCRIPTIONS, - 'Subscription', mod='models') - s = SubscriptionType() - s.state = 'Enabled' - return s - - def find_subscriptions_in_vm_with_msi(self, identity_id=None, allow_no_subscriptions=None): - # pylint: disable=too-many-statements - + def login_with_managed_identity(self, identity_id=None, allow_no_subscriptions=None): import jwt from msrestazure.tools import is_valid_resource_id - from azure.cli.core.adal_authentication import MSIAuthenticationWrapper + from azure.cli.core.auth.adal_authentication import MSIAuthenticationWrapper resource = self.cli_ctx.cloud.endpoints.active_directory_resource_id if identity_id: @@ -361,8 +249,8 @@ def find_subscriptions_in_vm_with_msi(self, identity_id=None, allow_no_subscript decode = jwt.decode(token, algorithms=['RS256'], options={"verify_signature": False}) tenant = decode['tid'] - subscription_finder = SubscriptionFinder(self.cli_ctx, self.auth_ctx_factory, None) - subscriptions = subscription_finder.find_from_raw_token(tenant, token) + subscription_finder = SubscriptionFinder(self.cli_ctx) + subscriptions = subscription_finder.find_using_specific_tenant(tenant, msi_creds) base_name = ('{}-{}'.format(identity_type, identity_id) if identity_id else identity_type) user = _USER_ASSIGNED_IDENTITY if identity_id else _SYSTEM_ASSIGNED_IDENTITY if not subscriptions: @@ -377,16 +265,19 @@ def find_subscriptions_in_vm_with_msi(self, identity_id=None, allow_no_subscript self._set_subscriptions(consolidated) return deepcopy(consolidated) - def find_subscriptions_in_cloud_console(self): + def login_in_cloud_shell(self): import jwt + from azure.cli.core.auth.adal_authentication import MSIAuthenticationWrapper - _, token, _ = self._get_token_from_cloud_shell(self.cli_ctx.cloud.endpoints.active_directory_resource_id) + msi_creds = MSIAuthenticationWrapper(resource=self.cli_ctx.cloud.endpoints.active_directory_resource_id) + token_entry = msi_creds.token + token = token_entry['access_token'] logger.info('MSI: token was retrieved. Now trying to initialize local accounts...') decode = jwt.decode(token, algorithms=['RS256'], options={"verify_signature": False}) tenant = decode['tid'] - subscription_finder = SubscriptionFinder(self.cli_ctx, self.auth_ctx_factory, None) - subscriptions = subscription_finder.find_from_raw_token(tenant, token) + subscription_finder = SubscriptionFinder(self.cli_ctx) + subscriptions = subscription_finder.find_using_specific_tenant(tenant, msi_creds) if not subscriptions: raise CLIError('No subscriptions were found in the cloud shell') user = decode.get('unique_name', 'N/A') @@ -397,12 +288,172 @@ def find_subscriptions_in_cloud_console(self): self._set_subscriptions(consolidated) return deepcopy(consolidated) - def _get_token_from_cloud_shell(self, resource): # pylint: disable=no-self-use - from azure.cli.core.adal_authentication import MSIAuthenticationWrapper - auth = MSIAuthenticationWrapper(resource=resource) - auth.set_token() - token_entry = auth.token - return (token_entry['token_type'], token_entry['access_token'], token_entry) + def logout(self, user_or_sp): + subscriptions = self.load_cached_subscriptions(all_clouds=True) + result = [x for x in subscriptions + if user_or_sp.lower() == x[_USER_ENTITY][_USER_NAME].lower()] + subscriptions = [x for x in subscriptions if x not in result] + self._storage[_SUBSCRIPTIONS] = subscriptions + + identity = Identity(self._authority) + identity.logout_user(user_or_sp) + identity.logout_service_principal(user_or_sp) + + def logout_all(self): + self._storage[_SUBSCRIPTIONS] = [] + + identity = Identity(self._authority) + identity.logout_all_users() + identity.logout_all_service_principal() + + def get_login_credentials(self, resource=None, client_id=None, subscription_id=None, aux_subscriptions=None, + aux_tenants=None): + """Get a CredentialAdaptor instance to be used with both Track 1 and Track 2 SDKs. + + :param resource: The resource ID to acquire an access token. Only provide it for Track 1 SDKs. + :param client_id: + :param subscription_id: + :param aux_subscriptions: + :param aux_tenants: + """ + resource = resource or self.cli_ctx.cloud.endpoints.active_directory_resource_id + + if aux_tenants and aux_subscriptions: + raise CLIError("Please specify only one of aux_subscriptions and aux_tenants, not both") + + account = self.get_subscription(subscription_id) + + managed_identity_type, managed_identity_id = Profile._try_parse_msi_account_name(account) + + # Cloud Shell is just a system assignment managed identity + if in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID): + managed_identity_type = MsiAccountTypes.system_assigned + + if managed_identity_type is None: + # user and service principal + external_tenants = [] + if aux_tenants: + external_tenants = [tenant for tenant in aux_tenants if tenant != account[_TENANT_ID]] + if aux_subscriptions: + ext_subs = [aux_sub for aux_sub in aux_subscriptions if aux_sub != subscription_id] + for ext_sub in ext_subs: + sub = self.get_subscription(ext_sub) + if sub[_TENANT_ID] != account[_TENANT_ID]: + external_tenants.append(sub[_TENANT_ID]) + + credential = self._create_credential(account, client_id=client_id) + external_credentials = [] + for external_tenant in external_tenants: + external_credentials.append(self._create_credential(account, external_tenant, client_id=client_id)) + from azure.cli.core.auth.credential_adaptor import CredentialAdaptor + cred = CredentialAdaptor(credential, + auxiliary_credentials=external_credentials, + resource=resource) + else: + # managed identity + cred = MsiAccountTypes.msi_auth_factory(managed_identity_type, managed_identity_id, resource) + return (cred, + str(account[_SUBSCRIPTION_ID]), + str(account[_TENANT_ID])) + + def get_raw_token(self, resource=None, scopes=None, subscription=None, tenant=None): + # Convert resource to scopes + if resource and not scopes: + scopes = resource_to_scopes(resource) + + # Use ARM as the default scopes + if not scopes: + scopes = resource_to_scopes(self.cli_ctx.cloud.endpoints.active_directory_resource_id) + + if subscription and tenant: + raise CLIError("Please specify only one of subscription and tenant, not both") + + account = self.get_subscription(subscription) + resource = resource or self.cli_ctx.cloud.endpoints.active_directory_resource_id + + identity_type, identity_id = Profile._try_parse_msi_account_name(account) + if identity_type: + # MSI + if tenant: + raise CLIError("Tenant shouldn't be specified for MSI account") + msi_creds = MsiAccountTypes.msi_auth_factory(identity_type, identity_id, resource) + msi_creds.set_token() + token_entry = msi_creds.token + creds = (token_entry['token_type'], token_entry['access_token'], token_entry) + elif in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID): + # Cloud Shell + if tenant: + raise CLIError("Tenant shouldn't be specified for Cloud Shell account") + creds = self._get_token_from_cloud_shell(resource) + else: + credential = self._create_credential(account, tenant) + token = credential.get_token(*scopes) + + import datetime + expiresOn = datetime.datetime.fromtimestamp(token.expires_on).strftime("%Y-%m-%d %H:%M:%S.%f") + + token_entry = { + 'accessToken': token.token, + 'expires_on': token.expires_on, + 'expiresOn': expiresOn + } + + # (tokenType, accessToken, tokenEntry) + creds = 'Bearer', token.token, token_entry + # (cred, subscription, tenant) + return (creds, + None if tenant else str(account[_SUBSCRIPTION_ID]), + str(tenant if tenant else account[_TENANT_ID])) + + def _normalize_properties(self, user, subscriptions, is_service_principal, cert_sn_issuer_auth=None, + user_assigned_identity_id=None): + consolidated = [] + for s in subscriptions: + subscription_dict = { + _SUBSCRIPTION_ID: s.id.rpartition('/')[2], + _SUBSCRIPTION_NAME: s.display_name, + _STATE: s.state, + _USER_ENTITY: { + _USER_NAME: user, + _USER_TYPE: _SERVICE_PRINCIPAL if is_service_principal else _USER + }, + _IS_DEFAULT_SUBSCRIPTION: False, + _TENANT_ID: s.tenant_id, + _ENVIRONMENT_NAME: self.cli_ctx.cloud.name + } + + if subscription_dict[_SUBSCRIPTION_NAME] != _TENANT_LEVEL_ACCOUNT_NAME: + _transform_subscription_for_multiapi(s, subscription_dict) + + consolidated.append(subscription_dict) + + if cert_sn_issuer_auth: + consolidated[-1][_USER_ENTITY][_SERVICE_PRINCIPAL_CERT_SN_ISSUER_AUTH] = True + if user_assigned_identity_id: + consolidated[-1][_USER_ENTITY][_ASSIGNED_IDENTITY_INFO] = user_assigned_identity_id + + return consolidated + + def _build_tenant_level_accounts(self, tenants): + result = [] + for t in tenants: + s = self._new_account() + s.id = '/subscriptions/' + t + s.subscription = t + s.tenant_id = t + s.display_name = _TENANT_LEVEL_ACCOUNT_NAME + result.append(s) + return result + + def _new_account(self): + """Build an empty Subscription which will be used as a tenant account. + API version doesn't matter as only specified attributes are preserved by _normalize_properties.""" + from azure.cli.core.profiles import ResourceType, get_sdk + SubscriptionType = get_sdk(self.cli_ctx, ResourceType.MGMT_RESOURCE_SUBSCRIPTIONS, + 'Subscription', mod='models') + s = SubscriptionType() + s.state = 'Enabled' + return s def _set_subscriptions(self, new_subscriptions, merge=True, secondary_key_name=None): @@ -423,9 +474,9 @@ def _match_account(account, subscription_id, secondary_key_name, secondary_key_v # merge with existing ones if merge: - dic = collections.OrderedDict((_get_key_name(x, secondary_key_name), x) for x in existing_ones) + dic = {_get_key_name(x, secondary_key_name): x for x in existing_ones} else: - dic = collections.OrderedDict() + dic = {} dic.update((_get_key_name(x, secondary_key_name), x) for x in new_subscriptions) subscriptions = list(dic.values()) @@ -477,19 +528,6 @@ def set_active_subscription(self, subscription): # take id or name set_cloud_subscription(self.cli_ctx, active_cloud.name, result[0][_SUBSCRIPTION_ID]) self._storage[_SUBSCRIPTIONS] = subscriptions - def logout(self, user_or_sp): - subscriptions = self.load_cached_subscriptions(all_clouds=True) - result = [x for x in subscriptions - if user_or_sp.lower() == x[_USER_ENTITY][_USER_NAME].lower()] - subscriptions = [x for x in subscriptions if x not in result] - - self._storage[_SUBSCRIPTIONS] = subscriptions - self._creds_cache.remove_cached_creds(user_or_sp) - - def logout_all(self): - self._storage[_SUBSCRIPTIONS] = [] - self._creds_cache.remove_all_cached_creds() - def load_cached_subscriptions(self, all_clouds=False): subscriptions = self._storage.get(_SUBSCRIPTIONS) or [] active_cloud = self.cli_ctx.cloud @@ -528,12 +566,6 @@ def get_subscription(self, subscription=None): # take id or name def get_subscription_id(self, subscription=None): # take id or name return self.get_subscription(subscription)[_SUBSCRIPTION_ID] - def get_access_token_for_resource(self, username, tenant, resource): - tenant = tenant or 'common' - _, access_token, _ = self._creds_cache.retrieve_token_for_user( - username, tenant, resource) - return access_token - @staticmethod def _try_parse_msi_account_name(account): msi_info, user = account[_USER_ENTITY].get(_ASSIGNED_IDENTITY_INFO), account[_USER_ENTITY].get(_USER_NAME) @@ -546,191 +578,34 @@ def _try_parse_msi_account_name(account): return parts[0], (None if len(parts) <= 1 else parts[1]) return None, None - def get_login_credentials(self, resource=None, subscription_id=None, aux_subscriptions=None, aux_tenants=None): - if aux_tenants and aux_subscriptions: - raise CLIError("Please specify only one of aux_subscriptions and aux_tenants, not both") + def _create_credential(self, account, tenant_id=None, client_id=None): + """Create a credential object driven by MSAL - account = self.get_subscription(subscription_id) - user_type = account[_USER_ENTITY][_USER_TYPE] - username_or_sp_id = account[_USER_ENTITY][_USER_NAME] - resource = resource or self.cli_ctx.cloud.endpoints.active_directory_resource_id - - identity_type, identity_id = Profile._try_parse_msi_account_name(account) - - # Make sure external_tenants_info only contains real external tenant (no current tenant). - external_tenants_info = [] - if aux_tenants: - external_tenants_info = [tenant for tenant in aux_tenants if tenant != account[_TENANT_ID]] - if aux_subscriptions: - ext_subs = [aux_sub for aux_sub in aux_subscriptions if aux_sub != subscription_id] - for ext_sub in ext_subs: - sub = self.get_subscription(ext_sub) - if sub[_TENANT_ID] != account[_TENANT_ID]: - external_tenants_info.append(sub[_TENANT_ID]) - - if external_tenants_info and \ - (in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID) or identity_type): - raise CLIError("Cross-tenant authentication is not supported by managed identity and Cloud Shell account. " - "Please run `az login` with a user account or a service principal.") - - if identity_type is None: - def _retrieve_token(token_resource): - logger.debug("Retrieving token from ADAL for resource %r", token_resource) - - if in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID): - return self._get_token_from_cloud_shell(token_resource) - if user_type == _USER: - return self._creds_cache.retrieve_token_for_user(username_or_sp_id, - account[_TENANT_ID], token_resource) - use_cert_sn_issuer = account[_USER_ENTITY].get(_SERVICE_PRINCIPAL_CERT_SN_ISSUER_AUTH) - return self._creds_cache.retrieve_token_for_service_principal(username_or_sp_id, token_resource, - account[_TENANT_ID], - use_cert_sn_issuer) - - def _retrieve_tokens_from_external_tenants(token_resource): - logger.debug("Retrieving token from ADAL for external tenants and resource %r", token_resource) - - external_tokens = [] - for sub_tenant_id in external_tenants_info: - if user_type == _USER: - external_tokens.append(self._creds_cache.retrieve_token_for_user( - username_or_sp_id, sub_tenant_id, token_resource)) - else: - external_tokens.append(self._creds_cache.retrieve_token_for_service_principal( - username_or_sp_id, token_resource, sub_tenant_id, token_resource)) - return external_tokens - - from azure.cli.core.adal_authentication import AdalAuthentication - auth_object = AdalAuthentication(_retrieve_token, - _retrieve_tokens_from_external_tenants if external_tenants_info else None, - resource=resource) - else: - if self._msi_creds is None: - self._msi_creds = MsiAccountTypes.msi_auth_factory(identity_type, identity_id, resource) - auth_object = self._msi_creds - - return (auth_object, - str(account[_SUBSCRIPTION_ID]), - str(account[_TENANT_ID])) - - def get_msal_token(self, scopes, data): - """ - This is added only for vmssh feature. - It is a temporary solution and will deprecate after MSAL adopted completely. + :param account: + :param tenant_id: If not None, override tenantId from 'account' + :param client_id: + :return: """ - account = self.get_subscription() - identity_type = account[_USER_ENTITY][_USER_TYPE] + user_type = account[_USER_ENTITY][_USER_TYPE] username_or_sp_id = account[_USER_ENTITY][_USER_NAME] - tenant = account[_TENANT_ID] - - import posixpath - authority = posixpath.join(self.cli_ctx.cloud.endpoints.active_directory, tenant) - - # Raise error for managed identity and Cloud Shell - not_support_message = "VM SSH currently doesn't support {}." - - # managed identity - managed_identity_type, _ = Profile._try_parse_msi_account_name(account) - if managed_identity_type: - raise CLIError(not_support_message.format("managed identity")) - - # Cloud Shell - if in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID): - raise CLIError(not_support_message.format("Cloud Shell")) - - # user - if identity_type == _USER: - # Use ARM as resource to get the refresh token from ADAL token cache - resource = self.cli_ctx.cloud.endpoints.active_directory_resource_id - _, _, token_entry = self._creds_cache.retrieve_token_for_user( - username_or_sp_id, account[_TENANT_ID], resource) - refresh_token = token_entry.get(_REFRESH_TOKEN) - - from azure.cli.core.msal_authentication import UserCredential - cred = UserCredential(_CLIENT_ID, authority=authority) - result = cred.acquire_token_by_refresh_token(refresh_token, scopes, data=data) - - # In case of being rejected by Conditional Access, launch browser automatically to retry - # with VM SSH as resource. - if 'error' in result: - logger.warning(result['error_description']) - - token_entry = self._login_with_authorization_code_flow(tenant, scopes_to_resource(scopes)) - result = cred.acquire_token_by_refresh_token(token_entry['refreshToken'], scopes, data=data) + tenant_id = tenant_id if tenant_id else account[_TENANT_ID] + identity = Identity(client_id=client_id, authority=self._authority, tenant_id=tenant_id) - # service principal - elif identity_type == _SERVICE_PRINCIPAL: - from azure.cli.core.msal_authentication import ServicePrincipalCredential + # User + if user_type == _USER: + return identity.get_user_credential(username_or_sp_id) - sp_id = username_or_sp_id - sp_credential = self._creds_cache.retrieve_cred_for_service_principal(sp_id) - cred = ServicePrincipalCredential(sp_id, secret_or_certificate=sp_credential, authority=authority) - result = cred.get_token(scopes=scopes, data=data) - - else: - raise CLIError("Unknown identity type {}".format(identity_type)) - - if 'error' in result: - from azure.cli.core.auth.util import aad_error_handler - aad_error_handler(result) - - return username_or_sp_id, result["access_token"] - - def get_raw_token(self, resource=None, subscription=None, tenant=None): - logger.debug("Profile.get_raw_token invoked with resource=%r, subscription=%r, tenant=%r", - resource, subscription, tenant) - if subscription and tenant: - raise CLIError("Please specify only one of subscription and tenant, not both") - account = self.get_subscription(subscription) - user_type = account[_USER_ENTITY][_USER_TYPE] - username_or_sp_id = account[_USER_ENTITY][_USER_NAME] - resource = resource or self.cli_ctx.cloud.endpoints.active_directory_resource_id + # Service Principal + if user_type == _SERVICE_PRINCIPAL: + return identity.get_service_principal_credential(username_or_sp_id) - identity_type, identity_id = Profile._try_parse_msi_account_name(account) - if identity_type: - # MSI - if tenant: - raise CLIError("Tenant shouldn't be specified for MSI account") - msi_creds = MsiAccountTypes.msi_auth_factory(identity_type, identity_id, resource) - msi_creds.set_token() - token_entry = msi_creds.token - creds = (token_entry['token_type'], token_entry['access_token'], token_entry) - elif in_cloud_console() and account[_USER_ENTITY].get(_CLOUD_SHELL_ID): - # Cloud Shell - if tenant: - raise CLIError("Tenant shouldn't be specified for Cloud Shell account") - creds = self._get_token_from_cloud_shell(resource) - else: - tenant_dest = tenant if tenant else account[_TENANT_ID] - import adal - try: - if user_type == _USER: - # User - creds = self._creds_cache.retrieve_token_for_user(username_or_sp_id, - tenant_dest, resource) - else: - # Service Principal - use_cert_sn_issuer = bool(account[_USER_ENTITY].get(_SERVICE_PRINCIPAL_CERT_SN_ISSUER_AUTH)) - creds = self._creds_cache.retrieve_token_for_service_principal(username_or_sp_id, - resource, - tenant_dest, - use_cert_sn_issuer) - except adal.AdalError as ex: - from azure.cli.core.adal_authentication import adal_error_handler - adal_error_handler(ex, scopes=resource_to_scopes(resource)) - return (creds, - None if tenant else str(account[_SUBSCRIPTION_ID]), - str(tenant if tenant else account[_TENANT_ID])) + raise NotImplementedError - def refresh_accounts(self, subscription_finder=None): + def refresh_accounts(self): subscriptions = self.load_cached_subscriptions() to_refresh = subscriptions - from azure.cli.core._debug import allow_debug_adal_connection - allow_debug_adal_connection() - subscription_finder = subscription_finder or SubscriptionFinder(self.cli_ctx, - self.auth_ctx_factory, - self._creds_cache.adal_token_cache) + subscription_finder = SubscriptionFinder(self.cli_ctx) refreshed_list = set() result = [] for s in to_refresh: @@ -742,13 +617,12 @@ def refresh_accounts(self, subscription_finder=None): tenant = s[_TENANT_ID] subscriptions = [] try: + identity_credential = self._create_credential(s, tenant) if is_service_principal: - sp_auth = ServicePrincipalAuth(self._creds_cache.retrieve_cred_for_service_principal(user_name)) - subscriptions = subscription_finder.find_from_service_principal_id(user_name, sp_auth, tenant, - self._ad_resource_uri) + subscriptions = subscription_finder.find_using_specific_tenant(tenant, identity_credential) else: - subscriptions = subscription_finder.find_from_user_account(user_name, None, None, - self._ad_resource_uri) + # pylint: disable=protected-access + subscriptions = subscription_finder.find_using_common_tenant(user_name, identity_credential) except Exception as ex: # pylint: disable=broad-except logger.warning("Refreshing for '%s' failed with an error '%s'. The existing accounts were not " "modified. You can run 'az login' later to explicitly refresh them", user_name, ex) @@ -767,56 +641,8 @@ def refresh_accounts(self, subscription_finder=None): is_service_principal) result += consolidated - if self._creds_cache.adal_token_cache.has_state_changed: - self._creds_cache.persist_cached_creds() - self._set_subscriptions(result, merge=False) - def get_sp_auth_info(self, subscription_id=None, name=None, password=None, cert_file=None): - from collections import OrderedDict - account = self.get_subscription(subscription_id) - - # is the credential created through command like 'create-for-rbac'? - result = OrderedDict() - if name and (password or cert_file): - result['clientId'] = name - if password: - result['clientSecret'] = password - else: - result['clientCertificate'] = cert_file - result['subscriptionId'] = subscription_id or account[_SUBSCRIPTION_ID] - else: # has logged in through cli - user_type = account[_USER_ENTITY].get(_USER_TYPE) - if user_type == _SERVICE_PRINCIPAL: - result['clientId'] = account[_USER_ENTITY][_USER_NAME] - sp_auth = ServicePrincipalAuth(self._creds_cache.retrieve_cred_for_service_principal( - account[_USER_ENTITY][_USER_NAME])) - secret = getattr(sp_auth, 'secret', None) - if secret: - result['clientSecret'] = secret - else: - # we can output 'clientCertificateThumbprint' if asked - result['clientCertificate'] = sp_auth.certificate_file - result['subscriptionId'] = account[_SUBSCRIPTION_ID] - else: - raise CLIError('SDK Auth file is only applicable when authenticated using a service principal') - - result[_TENANT_ID] = account[_TENANT_ID] - endpoint_mappings = OrderedDict() # use OrderedDict to control the output sequence - endpoint_mappings['active_directory'] = 'activeDirectoryEndpointUrl' - endpoint_mappings['resource_manager'] = 'resourceManagerEndpointUrl' - endpoint_mappings['active_directory_graph_resource_id'] = 'activeDirectoryGraphResourceId' - endpoint_mappings['sql_management'] = 'sqlManagementEndpointUrl' - endpoint_mappings['gallery'] = 'galleryEndpointUrl' - endpoint_mappings['management'] = 'managementEndpointUrl' - from azure.cli.core.cloud import CloudEndpointNotSetException - for e in endpoint_mappings: - try: - result[endpoint_mappings[e]] = getattr(get_active_cloud(self.cli_ctx).endpoints, e) - except CloudEndpointNotSetException: - result[endpoint_mappings[e]] = None - return result - def get_installation_id(self): installation_id = self._storage.get(_INSTALLATION_ID) if not installation_id: @@ -825,17 +651,12 @@ def get_installation_id(self): self._storage[_INSTALLATION_ID] = installation_id return installation_id - def _login_with_authorization_code_flow(self, tenant, resource): - authority_url, _ = _get_authority_url(self.cli_ctx, tenant) - results = _get_authorization_code(resource, authority_url) - - if not results.get('code'): - raise CLIError('Login failed') - - context = _authentication_context_factory(self.cli_ctx, tenant, self._creds_cache.adal_token_cache) - token_entry = context.acquire_token_with_authorization_code( - results['code'], results['reply_url'], resource, _CLIENT_ID) - return token_entry + def _get_token_from_cloud_shell(self, resource): # pylint: disable=no-self-use + from azure.cli.core.auth.adal_authentication import MSIAuthenticationWrapper + auth = MSIAuthenticationWrapper(resource=resource) + auth.set_token() + token_entry = auth.token + return (token_entry['token_type'], token_entry['access_token'], token_entry) class MsiAccountTypes: @@ -852,7 +673,7 @@ def valid_msi_account_types(): @staticmethod def msi_auth_factory(cli_account_name, identity, resource): - from azure.cli.core.adal_authentication import MSIAuthenticationWrapper + from azure.cli.core.auth.adal_authentication import MSIAuthenticationWrapper if cli_account_name == MsiAccountTypes.system_assigned: return MSIAuthenticationWrapper(resource=resource) if cli_account_name == MsiAccountTypes.user_assigned_client_id: @@ -865,138 +686,57 @@ def msi_auth_factory(cli_account_name, identity, resource): class SubscriptionFinder: - '''finds all subscriptions for a user or service principal''' - - def __init__(self, cli_ctx, auth_context_factory, adal_token_cache, arm_client_factory=None): + # An ARM client. It finds subscriptions for a user or service principal. It shouldn't do any + # authentication work, but only find subscriptions + def __init__(self, cli_ctx): - self._adal_token_cache = adal_token_cache - self._auth_context_factory = auth_context_factory self.user_id = None # will figure out after log user in self.cli_ctx = cli_ctx - - def create_arm_client_factory(credentials): - if arm_client_factory: - return arm_client_factory(credentials) - from azure.cli.core.profiles import ResourceType, get_api_version - from azure.cli.core.commands.client_factory import _prepare_client_kwargs_track2 - - client_type = self._get_subscription_client_class() - if client_type is None: - from azure.cli.core.azclierror import CLIInternalError - raise CLIInternalError("Unable to get '{}' in profile '{}'" - .format(ResourceType.MGMT_RESOURCE_SUBSCRIPTIONS, cli_ctx.cloud.profile)) - api_version = get_api_version(cli_ctx, ResourceType.MGMT_RESOURCE_SUBSCRIPTIONS) - client_kwargs = _prepare_client_kwargs_track2(cli_ctx) - # We don't need to change credential_scopes as 'scopes' is ignored by BasicTokenCredential anyway - client = client_type(credentials, api_version=api_version, - base_url=self.cli_ctx.cloud.endpoints.resource_manager, **client_kwargs) - return client - - self._arm_client_factory = create_arm_client_factory + self.secret = None + self._arm_resource_id = cli_ctx.cloud.endpoints.active_directory_resource_id + self.authority = self.cli_ctx.cloud.endpoints.active_directory self.tenants = [] - def find_from_user_account(self, username, password, tenant, resource): - context = self._create_auth_context(tenant) - if password: - token_entry = context.acquire_token_with_username_password(resource, username, password, _CLIENT_ID) - else: # when refresh account, we will leverage local cached tokens - token_entry = context.acquire_token(resource, username, _CLIENT_ID) - - if not token_entry: - return [] - self.user_id = token_entry[_TOKEN_ENTRY_USER_ID] - - if tenant is None: - result = self._find_using_common_tenant(token_entry[_ACCESS_TOKEN], resource) - else: - result = self._find_using_specific_tenant(tenant, token_entry[_ACCESS_TOKEN]) - return result - - def find_through_authorization_code_flow(self, tenant, resource, authority_url, auth_resource=None): - # launch browser and get the code - results = _get_authorization_code(auth_resource or resource, authority_url) - - if not results.get('code'): - raise CLIError('Login failed') # error detail is already displayed through previous steps - - # exchange the code for the token - context = self._create_auth_context(tenant) - token_entry = context.acquire_token_with_authorization_code(results['code'], results['reply_url'], - resource, _CLIENT_ID, None) - self.user_id = token_entry[_TOKEN_ENTRY_USER_ID] - logger.warning("You have logged in. Now let us find all the subscriptions to which you have access...") - if tenant is None: - result = self._find_using_common_tenant(token_entry[_ACCESS_TOKEN], resource) - else: - result = self._find_using_specific_tenant(tenant, token_entry[_ACCESS_TOKEN]) - return result - - def find_through_interactive_flow(self, tenant, resource, auth_resource=None): - context = self._create_auth_context(tenant) - code = context.acquire_user_code(auth_resource or resource, _CLIENT_ID) - logger.warning(code['message']) - token_entry = context.acquire_token_with_device_code(resource, code, _CLIENT_ID) - self.user_id = token_entry[_TOKEN_ENTRY_USER_ID] - if tenant is None: - result = self._find_using_common_tenant(token_entry[_ACCESS_TOKEN], resource) - else: - result = self._find_using_specific_tenant(tenant, token_entry[_ACCESS_TOKEN]) - return result - - def find_from_service_principal_id(self, client_id, sp_auth, tenant, resource): - context = self._create_auth_context(tenant, False) - token_entry = sp_auth.acquire_token(context, resource, client_id) - self.user_id = client_id - result = self._find_using_specific_tenant(tenant, token_entry[_ACCESS_TOKEN]) - self.tenants = [tenant] - return result - - # only occur inside cloud console or VM with identity - def find_from_raw_token(self, tenant, token): - # decode the token, so we know the tenant - result = self._find_using_specific_tenant(tenant, token) - self.tenants = [tenant] - return result - - def _create_auth_context(self, tenant, use_token_cache=True): - token_cache = self._adal_token_cache if use_token_cache else None - return self._auth_context_factory(self.cli_ctx, tenant, token_cache) - - def _find_using_common_tenant(self, access_token, resource): - import adal - from azure.cli.core.adal_authentication import BasicTokenCredential - + def find_using_common_tenant(self, username, credential=None): + # pylint: disable=too-many-statements all_subscriptions = [] empty_tenants = [] mfa_tenants = [] - token_credential = BasicTokenCredential(access_token) - client = self._arm_client_factory(token_credential) + + client = self._create_subscription_client(credential) tenants = client.tenants.list() + for t in tenants: tenant_id = t.tenant_id - logger.debug("Finding subscriptions under tenant %s", tenant_id) # display_name is available since /tenants?api-version=2018-06-01, # not available in /tenants?api-version=2016-06-01 if not hasattr(t, 'display_name'): t.display_name = None - temp_context = self._create_auth_context(tenant_id) + + t.tenant_id_name = tenant_id + if t.display_name: + # e.g. '72f988bf-86f1-41af-91ab-2d7cd011db47 Microsoft' + t.tenant_id_name = "{} '{}'".format(tenant_id, t.display_name) + + logger.info("Finding subscriptions under tenant %s", t.tenant_id_name) + + identity = Identity(self.authority, tenant_id) + + specific_tenant_credential = identity.get_user_credential(username) + try: - logger.debug("Acquiring a token with tenant=%s, resource=%s", tenant_id, resource) - temp_credentials = temp_context.acquire_token(resource, self.user_id, _CLIENT_ID) - except adal.AdalError as ex: + subscriptions = self.find_using_specific_tenant(tenant_id, specific_tenant_credential) + except AuthenticationError as ex: # because user creds went through the 'common' tenant, the error here must be # tenant specific, like the account was disabled. For such errors, we will continue # with other tenants. - msg = (getattr(ex, 'error_response', None) or {}).get('error_description') or '' + msg = ex.error_msg if 'AADSTS50076' in msg: # The tenant requires MFA and can't be accessed with home tenant's refresh token mfa_tenants.append(t) else: logger.warning("Failed to authenticate '%s' due to error '%s'", t, ex) continue - subscriptions = self._find_using_specific_tenant( - tenant_id, - temp_credentials[_ACCESS_TOKEN]) if not subscriptions: empty_tenants.append(t) @@ -1021,380 +761,43 @@ def _find_using_common_tenant(self, access_token, resource): logger.warning("The following tenants don't contain accessible subscriptions. " "Use 'az login --allow-no-subscriptions' to have tenant level access.") for t in empty_tenants: - if t.display_name: - logger.warning("%s '%s'", t.tenant_id, t.display_name) - else: - logger.warning("%s", t.tenant_id) + logger.warning("%s", t.tenant_id_name) # Show warning for MFA tenants if mfa_tenants: logger.warning("The following tenants require Multi-Factor Authentication (MFA). " "Use 'az login --tenant TENANT_ID' to explicitly login to a tenant.") for t in mfa_tenants: - if t.display_name: - logger.warning("%s '%s'", t.tenant_id, t.display_name) - else: - logger.warning("%s", t.tenant_id) + logger.warning("%s", t.tenant_id_name) return all_subscriptions - def _find_using_specific_tenant(self, tenant, access_token): - from azure.cli.core.adal_authentication import BasicTokenCredential - - token_credential = BasicTokenCredential(access_token) - client = self._arm_client_factory(token_credential) + def find_using_specific_tenant(self, tenant, credential): + client = self._create_subscription_client(credential) subscriptions = client.subscriptions.list() all_subscriptions = [] for s in subscriptions: - # map tenantId from REST API to homeTenantId - if hasattr(s, "tenant_id"): - setattr(s, 'home_tenant_id', s.tenant_id) - setattr(s, 'tenant_id', tenant) + _attach_token_tenant(s, tenant) all_subscriptions.append(s) self.tenants.append(tenant) return all_subscriptions - def _get_subscription_client_class(self): # pylint: disable=no-self-use - """Get the subscription client class. It can come from either the vendored SDK or public SDK, depending - on the design of architecture. - """ - if _USE_VENDORED_SUBSCRIPTION_SDK: - # Use vendered subscription SDK to decouple from `resource` command module - # pylint: disable=no-name-in-module, import-error - from azure.cli.core.vendored_sdks.subscriptions import SubscriptionClient - client_type = SubscriptionClient - else: - # Use the public SDK - from azure.cli.core.profiles import ResourceType - from azure.cli.core.profiles._shared import get_client_class - client_type = get_client_class(ResourceType.MGMT_RESOURCE_SUBSCRIPTIONS) - return client_type - - -class CredsCache: - '''Caches AAD tokena and service principal secrets, and persistence will - also be handled - ''' - - def __init__(self, cli_ctx, auth_ctx_factory=None, async_persist=True): - # AZURE_ACCESS_TOKEN_FILE is used by Cloud Console and not meant to be user configured - self._token_file = (os.environ.get('AZURE_ACCESS_TOKEN_FILE', None) or - os.path.join(get_config_dir(), 'accessTokens.json')) - self._service_principal_creds = [] - self._auth_ctx_factory = auth_ctx_factory - self._adal_token_cache_attr = None - self._should_flush_to_disk = False - self._async_persist = async_persist - self._ctx = cli_ctx - if async_persist: - import atexit - atexit.register(self.flush_to_disk) - - def persist_cached_creds(self): - self._should_flush_to_disk = True - if not self._async_persist: - self.flush_to_disk() - self.adal_token_cache.has_state_changed = False - - def flush_to_disk(self): - if self._should_flush_to_disk: - with os.fdopen(os.open(self._token_file, os.O_RDWR | os.O_CREAT | os.O_TRUNC, 0o600), - 'w+') as cred_file: - items = self.adal_token_cache.read_items() - all_creds = [entry for _, entry in items] - - # trim away useless fields (needed for cred sharing with xplat) - for i in all_creds: - for key in TOKEN_FIELDS_EXCLUDED_FROM_PERSISTENCE: - i.pop(key, None) - - all_creds.extend(self._service_principal_creds) - cred_file.write(json.dumps(all_creds)) - - def retrieve_token_for_user(self, username, tenant, resource): - context = self._auth_ctx_factory(self._ctx, tenant, cache=self.adal_token_cache) - token_entry = context.acquire_token(resource, username, _CLIENT_ID) - if not token_entry: - raise CLIError("Could not retrieve token from local cache.{}".format( - " Please run 'az login'." if not in_cloud_console() else '')) - - if self.adal_token_cache.has_state_changed: - self.persist_cached_creds() - return (token_entry[_TOKEN_ENTRY_TOKEN_TYPE], token_entry[_ACCESS_TOKEN], token_entry) - - def retrieve_token_for_service_principal(self, sp_id, resource, tenant, use_cert_sn_issuer=False): - self.load_adal_token_cache() - matched = [x for x in self._service_principal_creds if sp_id == x[_SERVICE_PRINCIPAL_ID]] - if not matched: - raise CLIError("Could not retrieve credential from local cache for service principal {}. " - "Please run 'az login' for this service principal." - .format(sp_id)) - matched_with_tenant = [x for x in matched if tenant == x[_SERVICE_PRINCIPAL_TENANT]] - if matched_with_tenant: - cred = matched_with_tenant[0] - else: - logger.warning("Could not retrieve credential from local cache for service principal %s under tenant %s. " - "Trying credential under tenant %s, assuming that is an app credential.", - sp_id, tenant, matched[0][_SERVICE_PRINCIPAL_TENANT]) - cred = matched[0] - - context = self._auth_ctx_factory(self._ctx, tenant, None) - sp_auth = ServicePrincipalAuth(cred.get(_ACCESS_TOKEN, None) or - cred.get(_SERVICE_PRINCIPAL_CERT_FILE, None), - use_cert_sn_issuer) - token_entry = sp_auth.acquire_token(context, resource, sp_id) - return (token_entry[_TOKEN_ENTRY_TOKEN_TYPE], token_entry[_ACCESS_TOKEN], token_entry) - - def retrieve_cred_for_service_principal(self, sp_id): - """Returns the secret or certificate of the specified service principal.""" - self.load_adal_token_cache() - matched = [x for x in self._service_principal_creds if sp_id == x[_SERVICE_PRINCIPAL_ID]] - if not matched: - raise CLIError("No matched service principal found") - cred = matched[0] - return cred.get(_ACCESS_TOKEN) or cred.get(_SERVICE_PRINCIPAL_CERT_FILE) - - @property - def adal_token_cache(self): - return self.load_adal_token_cache() - - def load_adal_token_cache(self): - if self._adal_token_cache_attr is None: - import adal - all_entries = _load_tokens_from_file(self._token_file) - self._load_service_principal_creds(all_entries) - real_token = [x for x in all_entries if x not in self._service_principal_creds] - self._adal_token_cache_attr = adal.TokenCache(json.dumps(real_token)) - return self._adal_token_cache_attr - - def save_service_principal_cred(self, sp_entry): - self.load_adal_token_cache() - matched = [x for x in self._service_principal_creds - if sp_entry[_SERVICE_PRINCIPAL_ID] == x[_SERVICE_PRINCIPAL_ID] and - sp_entry[_SERVICE_PRINCIPAL_TENANT] == x[_SERVICE_PRINCIPAL_TENANT]] - state_changed = False - if matched: - # pylint: disable=line-too-long - if (sp_entry.get(_ACCESS_TOKEN, None) != matched[0].get(_ACCESS_TOKEN, None) or - sp_entry.get(_SERVICE_PRINCIPAL_CERT_FILE, None) != matched[0].get(_SERVICE_PRINCIPAL_CERT_FILE, None)): - self._service_principal_creds.remove(matched[0]) - self._service_principal_creds.append(sp_entry) - state_changed = True - else: - self._service_principal_creds.append(sp_entry) - state_changed = True - - if state_changed: - self.persist_cached_creds() - - def _load_service_principal_creds(self, creds): - for c in creds: - if c.get(_SERVICE_PRINCIPAL_ID): - self._service_principal_creds.append(c) - return self._service_principal_creds - - def remove_cached_creds(self, user_or_sp): - state_changed = False - # clear AAD tokens - tokens = self.adal_token_cache.find({_TOKEN_ENTRY_USER_ID: user_or_sp}) - if tokens: - state_changed = True - self.adal_token_cache.remove(tokens) - - # clear service principal creds - matched = [x for x in self._service_principal_creds - if x[_SERVICE_PRINCIPAL_ID] == user_or_sp] - if matched: - state_changed = True - self._service_principal_creds = [x for x in self._service_principal_creds - if x not in matched] - - if state_changed: - self.persist_cached_creds() - - def remove_all_cached_creds(self): - # we can clear file contents, but deleting it is simpler - _delete_file(self._token_file) - - -class ServicePrincipalAuth: - - def __init__(self, password_arg_value, use_cert_sn_issuer=None): - if not password_arg_value: - raise CLIError('missing secret or certificate in order to ' - 'authenticate through a service principal') - if os.path.isfile(password_arg_value): - certificate_file = password_arg_value - from OpenSSL.crypto import load_certificate, FILETYPE_PEM, Error - self.certificate_file = certificate_file - self.public_certificate = None - try: - with open(certificate_file, 'r') as file_reader: - self.cert_file_string = file_reader.read() - cert = load_certificate(FILETYPE_PEM, self.cert_file_string) - self.thumbprint = cert.digest("sha1").decode() - if use_cert_sn_issuer: - # low-tech but safe parsing based on - # https://github.com/libressl-portable/openbsd/blob/master/src/lib/libcrypto/pem/pem.h - match = re.search(r'\-+BEGIN CERTIFICATE.+\-+(?P[^-]+)\-+END CERTIFICATE.+\-+', - self.cert_file_string, re.I) - self.public_certificate = match.group('public').strip() - except (UnicodeDecodeError, Error): - raise CLIError('Invalid certificate, please use a valid PEM file.') - else: - self.secret = password_arg_value - - def acquire_token(self, authentication_context, resource, client_id): - if hasattr(self, 'secret'): - return authentication_context.acquire_token_with_client_credentials(resource, client_id, self.secret) - return authentication_context.acquire_token_with_client_certificate(resource, client_id, self.cert_file_string, - self.thumbprint, self.public_certificate) - - def get_entry_to_persist(self, sp_id, tenant): - entry = { - _SERVICE_PRINCIPAL_ID: sp_id, - _SERVICE_PRINCIPAL_TENANT: tenant, - } - if hasattr(self, 'secret'): - entry[_ACCESS_TOKEN] = self.secret - else: - entry[_SERVICE_PRINCIPAL_CERT_FILE] = self.certificate_file - entry[_SERVICE_PRINCIPAL_CERT_THUMBPRINT] = self.thumbprint - - return entry - - -def _get_authorization_code_worker(authority_url, resource, results): - # pylint: disable=too-many-statements - import socket - import random - import http.server - - class ClientRedirectServer(http.server.HTTPServer): # pylint: disable=too-few-public-methods - query_params = {} - - class ClientRedirectHandler(http.server.BaseHTTPRequestHandler): - # pylint: disable=line-too-long - - def do_GET(self): - try: - from urllib.parse import parse_qs - except ImportError: - from urlparse import parse_qs # pylint: disable=import-error - - if self.path.endswith('/favicon.ico'): # deal with legacy IE - self.send_response(204) - return - - query = self.path.split('?', 1)[-1] - query = parse_qs(query, keep_blank_values=True) - self.server.query_params = query - - self.send_response(200) - self.send_header('Content-type', 'text/html') - self.end_headers() - - landing_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'auth_landing_pages', - 'ok.html' if 'code' in query else 'fail.html') - with open(landing_file, 'rb') as html_file: - self.wfile.write(html_file.read()) - - def log_message(self, format, *args): # pylint: disable=redefined-builtin,unused-argument,no-self-use - pass # this prevent http server from dumping messages to stdout - - reply_url = None - - # On Windows, HTTPServer by default doesn't throw error if the port is in-use - # https://github.com/Azure/azure-cli/issues/10578 - if is_windows(): - logger.debug('Windows is detected. Set HTTPServer.allow_reuse_address to False') - ClientRedirectServer.allow_reuse_address = False - elif is_wsl(): - logger.debug('WSL is detected. Set HTTPServer.allow_reuse_address to False') - ClientRedirectServer.allow_reuse_address = False - - for port in range(8400, 9000): - try: - web_server = ClientRedirectServer(('localhost', port), ClientRedirectHandler) - reply_url = "http://localhost:{}".format(port) - break - except socket.error as ex: - logger.warning("Port '%s' is taken with error '%s'. Trying with the next one", port, ex) - except UnicodeDecodeError: - logger.warning("Please make sure there is no international (Unicode) character in the computer name " - r"or C:\Windows\System32\drivers\etc\hosts file's 127.0.0.1 entries. " - "For more details, please see https://github.com/Azure/azure-cli/issues/12957") - break - - if reply_url is None: - logger.warning("Error: can't reserve a port for authentication reply url") - return - - try: - request_state = ''.join(random.SystemRandom().choice(string.ascii_lowercase + string.digits) for _ in range(20)) - except NotImplementedError: - request_state = 'code' - - # launch browser: - url = ('{0}/oauth2/authorize?response_type=code&client_id={1}' - '&redirect_uri={2}&state={3}&resource={4}&prompt=select_account') - url = url.format(authority_url, _CLIENT_ID, reply_url, request_state, resource) - logger.info('Open browser with url: %s', url) - succ = open_page_in_browser(url) - if succ is False: - web_server.server_close() - results['no_browser'] = True - return - - # Emit a warning to inform that a browser is opened. - # Only show the path part of the URL and hide the query string. - logger.warning("The default web browser has been opened at %s. Please continue the login in the web browser. " - "If no web browser is available or if the web browser fails to open, use device code flow " - "with `az login --use-device-code`.", url.split('?')[0]) - - # wait for callback from browser. - while True: - web_server.handle_request() - if 'error' in web_server.query_params or 'code' in web_server.query_params: - break - - if 'error' in web_server.query_params: - logger.warning('Authentication Error: "%s". Description: "%s" ', web_server.query_params['error'], - web_server.query_params.get('error_description')) - return - - if 'code' in web_server.query_params: - code = web_server.query_params['code'] - else: - logger.warning('Authentication Error: Authorization code was not captured in query strings "%s"', - web_server.query_params) - return - - if 'state' in web_server.query_params: - response_state = web_server.query_params['state'][0] - if response_state != request_state: - raise RuntimeError("mismatched OAuth state") - else: - raise RuntimeError("missing OAuth state") - - results['code'] = code[0] - results['reply_url'] = reply_url - - -def _get_authorization_code(resource, authority_url): - import threading - import time - results = {} - t = threading.Thread(target=_get_authorization_code_worker, - args=(authority_url, resource, results)) - t.daemon = True - t.start() - while True: - time.sleep(2) # so that ctrl+c can stop the command - if not t.is_alive(): - break # done - if results.get('no_browser'): - raise RuntimeError() - return results + def _create_subscription_client(self, credential): + from azure.cli.core.profiles import ResourceType, get_api_version + from azure.cli.core.profiles._shared import get_client_class + from azure.cli.core.commands.client_factory import _prepare_mgmt_client_kwargs_track2 + + client_type = get_client_class(ResourceType.MGMT_RESOURCE_SUBSCRIPTIONS) + if client_type is None: + from azure.cli.core.azclierror import CLIInternalError + raise CLIInternalError("Unable to get '{}' in profile '{}'" + .format(ResourceType.MGMT_RESOURCE_SUBSCRIPTIONS, self.cli_ctx.cloud.profile)) + api_version = get_api_version(self.cli_ctx, ResourceType.MGMT_RESOURCE_SUBSCRIPTIONS) + client_kwargs = _prepare_mgmt_client_kwargs_track2(self.cli_ctx, credential) + + client = client_type(credential, api_version=api_version, + base_url=self.cli_ctx.cloud.endpoints.resource_manager, + **client_kwargs) + return client def _transform_subscription_for_multiapi(s, s_dict): diff --git a/src/azure-cli-core/azure/cli/core/adal_authentication.py b/src/azure-cli-core/azure/cli/core/adal_authentication.py deleted file mode 100644 index 0f0febd8e69..00000000000 --- a/src/azure-cli-core/azure/cli/core/adal_authentication.py +++ /dev/null @@ -1,256 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. -# -------------------------------------------------------------------------------------------- - -import requests -import adal - -from msrest.authentication import Authentication -from msrestazure.azure_active_directory import MSIAuthentication -from azure.core.credentials import AccessToken -from azure.cli.core.util import in_cloud_console, scopes_to_resource, resource_to_scopes - -from knack.util import CLIError -from knack.log import get_logger - -logger = get_logger(__name__) - - -class AdalAuthentication(Authentication): # pylint: disable=too-few-public-methods - - def __init__(self, token_retriever, external_tenant_token_retriever=None, resource=None): - # DO NOT call _token_retriever from outside azure-cli-core. It is only available for user or - # Service Principal credential (AdalAuthentication), but not for Managed Identity credential - # (MSIAuthenticationWrapper). - # To retrieve a raw token, either call - # - Profile.get_raw_token, which is more direct - # - AdalAuthentication.get_token, which is designed for Track 2 SDKs - self._token_retriever = token_retriever - self._external_tenant_token_retriever = external_tenant_token_retriever - self._resource = resource - - def _get_token(self, sdk_resource=None): - """ - :param sdk_resource: `resource` converted from Track 2 SDK's `scopes` - """ - - # When called by - # - Track 1 SDK, use `resource` specified by CLI - # - Track 2 SDK, use `sdk_resource` specified by SDK and ignore `resource` specified by CLI - token_resource = sdk_resource or self._resource - - external_tenant_tokens = None - try: - scheme, token, token_entry = self._token_retriever(token_resource) - if self._external_tenant_token_retriever: - external_tenant_tokens = self._external_tenant_token_retriever(token_resource) - except CLIError as err: - if in_cloud_console(): - AdalAuthentication._log_hostname() - raise err - except adal.AdalError as err: - if in_cloud_console(): - AdalAuthentication._log_hostname() - adal_error_handler(err, scopes=resource_to_scopes(token_resource)) - except requests.exceptions.SSLError as err: - from .util import SSLERROR_TEMPLATE - raise CLIError(SSLERROR_TEMPLATE.format(str(err))) - except requests.exceptions.ConnectionError as err: - raise CLIError('Please ensure you have network connection. Error detail: ' + str(err)) - - # scheme: str. The token scheme. Should always be 'Bearer'. - # token: str. The raw access token. - # token_entry: dict. The full token entry. - # external_tenant_tokens: [(scheme: str, token: str, token_entry: dict), ...] - return scheme, token, token_entry, external_tenant_tokens - - def get_all_tokens(self, *scopes): - return self._get_token(_try_scopes_to_resource(scopes)) - - # This method is exposed for Azure Core. - def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument - logger.debug("AdalAuthentication.get_token invoked by Track 2 SDK with scopes=%s", scopes) - - _, token, token_entry, _ = self._get_token(_try_scopes_to_resource(scopes)) - - # NEVER use expiresIn (expires_in) as the token is cached and expiresIn will be already out-of date - # when being retrieved. - - # User token entry sample: - # { - # "tokenType": "Bearer", - # "expiresOn": "2020-11-13 14:44:42.492318", - # "resource": "https://management.core.windows.net/", - # "userId": "test@azuresdkteam.onmicrosoft.com", - # "accessToken": "eyJ0eXAiOiJKV...", - # "refreshToken": "0.ATcAImuCVN...", - # "_clientId": "04b07795-8ddb-461a-bbee-02f9e1bf7b46", - # "_authority": "https://login.microsoftonline.com/54826b22-38d6-4fb2-bad9-b7b93a3e9c5a", - # "isMRRT": True, - # "expiresIn": 3599 - # } - - # Service Principal token entry sample: - # { - # "tokenType": "Bearer", - # "expiresIn": 3599, - # "expiresOn": "2020-11-12 13:50:47.114324", - # "resource": "https://management.core.windows.net/", - # "accessToken": "eyJ0eXAiOiJKV...", - # "isMRRT": True, - # "_clientId": "22800c35-46c2-4210-b8a7-d8c3ec3b526f", - # "_authority": "https://login.microsoftonline.com/54826b22-38d6-4fb2-bad9-b7b93a3e9c5a" - # } - if 'expiresOn' in token_entry: - import datetime - expires_on_timestamp = int(_timestamp( - datetime.datetime.strptime(token_entry['expiresOn'], '%Y-%m-%d %H:%M:%S.%f'))) - return AccessToken(token, expires_on_timestamp) - - # Cloud Shell (Managed Identity) token entry sample: - # { - # "access_token": "eyJ0eXAiOiJKV...", - # "refresh_token": "", - # "expires_in": "2106", - # "expires_on": "1605686811", - # "not_before": "1605682911", - # "resource": "https://management.core.windows.net/", - # "token_type": "Bearer" - # } - if 'expires_on' in token_entry: - return AccessToken(token, int(token_entry['expires_on'])) - - from azure.cli.core.azclierror import CLIInternalError - raise CLIInternalError("No expiresOn or expires_on is available in the token entry.") - - # This method is exposed for msrest. - def signed_session(self, session=None): # pylint: disable=arguments-differ - logger.debug("AdalAuthentication.signed_session invoked by Track 1 SDK") - session = session or super(AdalAuthentication, self).signed_session() - - scheme, token, _, external_tenant_tokens = self._get_token() - - header = "{} {}".format(scheme, token) - session.headers['Authorization'] = header - if external_tenant_tokens: - aux_tokens = ';'.join(['{} {}'.format(scheme2, tokens2) for scheme2, tokens2, _ in external_tenant_tokens]) - session.headers['x-ms-authorization-auxiliary'] = aux_tokens - return session - - @staticmethod - def _log_hostname(): - import socket - logger.warning("A Cloud Shell credential problem occurred. When you report the issue with the error " - "below, please mention the hostname '%s'", socket.gethostname()) - - -class MSIAuthenticationWrapper(MSIAuthentication): - # This method is exposed for Azure Core. Add *scopes, **kwargs to fit azure.core requirement - def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument - logger.debug("MSIAuthenticationWrapper.get_token invoked by Track 2 SDK with scopes=%s", scopes) - resource = _try_scopes_to_resource(scopes) - if resource: - # If available, use resource provided by SDK - self.resource = resource - self.set_token() - # Managed Identity token entry sample: - # { - # "access_token": "eyJ0eXAiOiJKV...", - # "client_id": "da95e381-d7ab-4fdc-8047-2457909c723b", - # "expires_in": "86386", - # "expires_on": "1605238724", - # "ext_expires_in": "86399", - # "not_before": "1605152024", - # "resource": "https://management.azure.com/", - # "token_type": "Bearer" - # } - return AccessToken(self.token['access_token'], int(self.token['expires_on'])) - - def set_token(self): - import traceback - from azure.cli.core.azclierror import AzureConnectionError, AzureResponseError - try: - super(MSIAuthenticationWrapper, self).set_token() - except requests.exceptions.ConnectionError as err: - logger.debug('throw requests.exceptions.ConnectionError when doing MSIAuthentication: \n%s', - traceback.format_exc()) - raise AzureConnectionError('Failed to connect to MSI. Please make sure MSI is configured correctly ' - 'and check the network connection.\nError detail: {}'.format(str(err))) - except requests.exceptions.HTTPError as err: - logger.debug('throw requests.exceptions.HTTPError when doing MSIAuthentication: \n%s', - traceback.format_exc()) - try: - raise AzureResponseError('Failed to connect to MSI. Please make sure MSI is configured correctly.\n' - 'Get Token request returned http error: {}, reason: {}' - .format(err.response.status, err.response.reason)) - except AttributeError: - raise AzureResponseError('Failed to connect to MSI. Please make sure MSI is configured correctly.\n' - 'Get Token request returned: {}'.format(err.response)) - except TimeoutError as err: - logger.debug('throw TimeoutError when doing MSIAuthentication: \n%s', - traceback.format_exc()) - raise AzureConnectionError('MSI endpoint is not responding. Please make sure MSI is configured correctly.\n' - 'Error detail: {}'.format(str(err))) - - def signed_session(self, session=None): - logger.debug("MSIAuthenticationWrapper.signed_session invoked by Track 1 SDK") - super().signed_session(session) - - -def _try_scopes_to_resource(scopes): - """Wrap scopes_to_resource to workaround some SDK issues.""" - - # Track 2 SDKs generated before https://github.com/Azure/autorest.python/pull/239 don't maintain - # credential_scopes and call `get_token` with empty scopes. - # As a workaround, return None so that the CLI-managed resource is used. - if not scopes: - logger.debug("No scope is provided by the SDK, use the CLI-managed resource.") - return None - - # Track 2 SDKs generated before https://github.com/Azure/autorest.python/pull/745 extend default - # credential_scopes with custom credential_scopes. Instead, credential_scopes should be replaced by - # custom credential_scopes. https://github.com/Azure/azure-sdk-for-python/issues/12947 - # As a workaround, remove the first one if there are multiple scopes provided. - if len(scopes) > 1: - logger.debug("Multiple scopes are provided by the SDK, discarding the first one: %s", scopes[0]) - return scopes_to_resource(scopes[1:]) - - # Exactly only one scope is provided - return scopes_to_resource(scopes) - - -class BasicTokenCredential: - # pylint:disable=too-few-public-methods - """A Track 2 implementation of msrest.authentication.BasicTokenAuthentication. - This credential shouldn't be used by any command module, expect azure-cli-core. - """ - def __init__(self, access_token): - self.access_token = access_token - - def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument - # Because get_token can't refresh the access token, always mark the token as unexpired - import time - return AccessToken(self.access_token, int(time.time() + 3600)) - - -def _timestamp(dt): - # datetime.datetime can't be patched: - # TypeError: can't set attributes of built-in/extension type 'datetime.datetime' - # So we wrap datetime.datetime.timestamp with this function. - # https://docs.python.org/3/library/unittest.mock-examples.html#partial-mocking - # https://williambert.online/2011/07/how-to-unit-testing-in-django-with-mocking-and-patching/ - return dt.timestamp() - - -def adal_error_handler(err: adal.AdalError, **kwargs): - """ Handle AdalError. """ - try: - from azure.cli.core.auth.util import aad_error_handler - aad_error_handler(err.error_response, **kwargs) - except AttributeError: - # In case of AdalError created as - # AdalError('More than one token matches the criteria. The result is ambiguous.') - # https://github.com/Azure/azure-cli/issues/15320 - from azure.cli.core.azclierror import UnknownError - raise UnknownError(str(err), recommendation="Please run `az account clear`, then `az login`.") diff --git a/src/azure-cli-core/azure/cli/core/auth/adal_authentication.py b/src/azure-cli-core/azure/cli/core/auth/adal_authentication.py new file mode 100644 index 00000000000..8b8252679e7 --- /dev/null +++ b/src/azure-cli-core/azure/cli/core/auth/adal_authentication.py @@ -0,0 +1,66 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import requests +from azure.core.credentials import AccessToken +from knack.log import get_logger +from msrestazure.azure_active_directory import MSIAuthentication + +from .util import _normalize_scopes, scopes_to_resource + +logger = get_logger(__name__) + + +class MSIAuthenticationWrapper(MSIAuthentication): + # This method is exposed for Azure Core. Add *scopes, **kwargs to fit azure.core requirement + def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument + logger.debug("MSIAuthenticationWrapper.get_token invoked by Track 2 SDK with scopes=%s", scopes) + resource = scopes_to_resource(_normalize_scopes(scopes)) + if resource: + # If available, use resource provided by SDK + self.resource = resource + self.set_token() + # Managed Identity token entry sample: + # { + # "access_token": "eyJ0eXAiOiJKV...", + # "client_id": "da95e381-d7ab-4fdc-8047-2457909c723b", + # "expires_in": "86386", + # "expires_on": "1605238724", + # "ext_expires_in": "86399", + # "not_before": "1605152024", + # "resource": "https://management.azure.com/", + # "token_type": "Bearer" + # } + return AccessToken(self.token['access_token'], int(self.token['expires_on'])) + + def set_token(self): + import traceback + from azure.cli.core.azclierror import AzureConnectionError, AzureResponseError + try: + super(MSIAuthenticationWrapper, self).set_token() + except requests.exceptions.ConnectionError as err: + logger.debug('throw requests.exceptions.ConnectionError when doing MSIAuthentication: \n%s', + traceback.format_exc()) + raise AzureConnectionError('Failed to connect to MSI. Please make sure MSI is configured correctly ' + 'and check the network connection.\nError detail: {}'.format(str(err))) + except requests.exceptions.HTTPError as err: + logger.debug('throw requests.exceptions.HTTPError when doing MSIAuthentication: \n%s', + traceback.format_exc()) + try: + raise AzureResponseError('Failed to connect to MSI. Please make sure MSI is configured correctly.\n' + 'Get Token request returned http error: {}, reason: {}' + .format(err.response.status, err.response.reason)) + except AttributeError: + raise AzureResponseError('Failed to connect to MSI. Please make sure MSI is configured correctly.\n' + 'Get Token request returned: {}'.format(err.response)) + except TimeoutError as err: + logger.debug('throw TimeoutError when doing MSIAuthentication: \n%s', + traceback.format_exc()) + raise AzureConnectionError('MSI endpoint is not responding. Please make sure MSI is configured correctly.\n' + 'Error detail: {}'.format(str(err))) + + def signed_session(self, session=None): + logger.debug("MSIAuthenticationWrapper.signed_session invoked by Track 1 SDK") + super().signed_session(session) diff --git a/src/azure-cli-core/azure/cli/core/auth/credential_adaptor.py b/src/azure-cli-core/azure/cli/core/auth/credential_adaptor.py new file mode 100644 index 00000000000..01ab8637d39 --- /dev/null +++ b/src/azure-cli-core/azure/cli/core/auth/credential_adaptor.py @@ -0,0 +1,66 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import requests +from knack.log import get_logger +from knack.util import CLIError + +from .util import resource_to_scopes, _normalize_scopes + +logger = get_logger(__name__) + + +class CredentialAdaptor: + def __init__(self, credential, resource=None, auxiliary_credentials=None): + """ + Adaptor to both + - Track 1: msrest.authentication.Authentication, which exposes signed_session + - Track 2: azure.core.credentials.TokenCredential, which exposes get_token + + :param credential: Main credential from .msal_authentication + :param resource: AAD resource for Track 1 only + :param auxiliary_credentials: Credentials from .msal_authentication for cross tenant authentication. + Details about cross tenant authentication: + https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/authenticate-multi-tenant + """ + + self._credential = credential + self._auxiliary_credentials = auxiliary_credentials + self._resource = resource + + def _get_token(self, scopes=None, **kwargs): + external_tenant_tokens = [] + # If scopes is not provided, use CLI-managed resource + scopes = scopes or resource_to_scopes(self._resource) + try: + token = self._credential.get_token(*scopes, **kwargs) + if self._auxiliary_credentials: + external_tenant_tokens = [cred.get_token(*scopes) for cred in self._auxiliary_credentials] + return token, external_tenant_tokens + except requests.exceptions.SSLError as err: + from azure.cli.core.util import SSLERROR_TEMPLATE + raise CLIError(SSLERROR_TEMPLATE.format(str(err))) + + def signed_session(self, session=None): + logger.debug("CredentialAdaptor.get_token") + session = session or requests.Session() + token, external_tenant_tokens = self._get_token() + header = "{} {}".format('Bearer', token.token) + session.headers['Authorization'] = header + if external_tenant_tokens: + aux_tokens = ';'.join(['{} {}'.format('Bearer', tokens2.token) for tokens2 in external_tenant_tokens]) + session.headers['x-ms-authorization-auxiliary'] = aux_tokens + return session + + def get_token(self, *scopes, **kwargs): + logger.debug("CredentialAdaptor.get_token: scopes=%r, kwargs=%r", scopes, kwargs) + scopes = _normalize_scopes(scopes) + token, _ = self._get_token(scopes, **kwargs) + return token + + def get_auxiliary_tokens(self, *scopes, **kwargs): + if self._auxiliary_credentials: + return [cred.get_token(*scopes, **kwargs) for cred in self._auxiliary_credentials] + return None diff --git a/src/azure-cli-core/azure/cli/core/auth/identity.py b/src/azure-cli-core/azure/cli/core/auth/identity.py new file mode 100644 index 00000000000..37e065b0334 --- /dev/null +++ b/src/azure-cli-core/azure/cli/core/auth/identity.py @@ -0,0 +1,325 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import json +import os +import re + +from azure.cli.core._environment import get_config_dir +from knack.log import get_logger +from knack.util import CLIError + +from .msal_authentication import UserCredential, ServicePrincipalCredential +from .util import check_result + +# Service principal entry properties +from .msal_authentication import _CLIENT_ID, _TENANT, _CLIENT_SECRET, _CERTIFICATE, _CLIENT_ASSERTION,\ + _USE_CERT_SN_ISSUER + +AZURE_CLI_CLIENT_ID = '04b07795-8ddb-461a-bbee-02f9e1bf7b46' + + +logger = get_logger(__name__) + + +class Identity: # pylint: disable=too-many-instance-attributes + """Class to manage identities: + - user + - service principal + - TODO: managed identity + """ + # Whether token and secrets should be encrypted. Change its value to turn on/off token encryption. + token_encryption = False + + # HTTP cache for MSAL's tenant discovery, retry-after error cache, etc. + # It must follow singleton pattern. Otherwise, a new dbm.dumb http_cache can read out-of-sync dat and dir. + # https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/407 + http_cache = None + + def __init__(self, authority=None, tenant_id=None, client_id=None): + """ + + :param authority: AAD endpoint, like https://login.microsoftonline.com/ + :param tenant_id: Tenant GUID, like 00000000-0000-0000-0000-000000000000 + :param client_id: Client ID of the CLI application. + """ + self.authority = authority + self.tenant_id = tenant_id or "organizations" + # Build the authority in MSAL style, like https://login.microsoftonline.com/your_tenant + self.msal_authority = "{}/{}".format(self.authority, self.tenant_id) + self.client_id = client_id or AZURE_CLI_CLIENT_ID + + config_dir = get_config_dir() + self._token_cache_file = os.path.join(config_dir, "msal_token_cache") + self._secret_file = os.path.join(config_dir, "service_principal_entries") + self._http_cache_file = os.path.join(config_dir, "msal_http_cache") + + # Prepare HTTP cache. + # https://github.com/AzureAD/microsoft-authentication-library-for-python/pull/407 + # if not Identity.http_cache: + # Identity.http_cache = self._load_http_cache() + + self._msal_app_instance = None + # Store for Service principal credential persistence + self._msal_secret_store = ServicePrincipalStore(self._secret_file, self.token_encryption) + self._msal_app_kwargs = { + "authority": self.msal_authority, + "token_cache": self._load_msal_cache() + # "http_cache": Identity.http_cache + } + + def _load_msal_cache(self): + from .persistence import load_persisted_token_cache + # Store for user token persistence + cache = load_persisted_token_cache(self._token_cache_file, self.token_encryption) + return cache + + def _load_http_cache(self): + import atexit + import pickle + + try: + with open(self._http_cache_file, 'rb') as f: + persisted_http_cache = pickle.load(f) # Take a snapshot + except: # pylint: disable=bare-except + persisted_http_cache = {} # Ignore a non-exist or corrupted http_cache + atexit.register(lambda: pickle.dump( + # When exit, flush it back to the file. + # If 2 processes write at the same time, the cache will be corrupted, + # but that is fine. Subsequent runs would reach eventual consistency. + persisted_http_cache, open(self._http_cache_file, 'wb'))) + + return persisted_http_cache + + def _build_persistent_msal_app(self): + # Initialize _msal_app for login and logout + from msal import PublicClientApplication + msal_app = PublicClientApplication(self.client_id, **self._msal_app_kwargs) + return msal_app + + @property + def msal_app(self): + if not self._msal_app_instance: + self._msal_app_instance = self._build_persistent_msal_app() + return self._msal_app_instance + + def login_with_auth_code(self, scopes=None, **kwargs): + # Emit a warning to inform that a browser is opened. + # Only show the path part of the URL and hide the query string. + logger.warning("The default web browser has been opened at %s/oauth2/v2.0/authorize. " + "Please continue the login in the web browser. " + "If no web browser is available or if the web browser fails to open, use device code flow " + "with `az login --use-device-code`.", self.msal_authority) + + success_template, error_template = _read_response_templates() + + result = self.msal_app.acquire_token_interactive( + scopes, prompt='select_account', success_template=success_template, error_template=error_template, **kwargs) + return check_result(result) + + def login_with_device_code(self, scopes=None, **kwargs): + flow = self.msal_app.initiate_device_flow(scopes, **kwargs) + if "user_code" not in flow: + raise ValueError( + "Fail to create device flow. Err: %s" % json.dumps(flow, indent=4)) + logger.warning(flow["message"]) + result = self.msal_app.acquire_token_by_device_flow(flow, **kwargs) # By default it will block + return check_result(result) + + def login_with_username_password(self, username, password, scopes=None, **kwargs): + result = self.msal_app.acquire_token_by_username_password(username, password, scopes, **kwargs) + return check_result(result) + + def login_with_service_principal(self, client_id, credential, scopes=None): + """ + `credential` is a dict returned by ServicePrincipalAuth.build_credential + """ + sp_auth = ServicePrincipalAuth.build_from_credential(self.tenant_id, client_id, credential) + + # This cred means SDK credential object + cred = ServicePrincipalCredential(sp_auth, **self._msal_app_kwargs) + result = cred.acquire_token_for_client(scopes) + check_result(result) + + # Only persist the service principal after a successful login + entry = sp_auth.get_entry_to_persist() + self._msal_secret_store.save_entry(entry) + + def login_with_managed_identity(self, scopes, identity_id=None): # pylint: disable=too-many-statements + raise NotImplementedError + + def login_in_cloud_shell(self, scopes): + raise NotImplementedError + + def logout_user(self, user): + accounts = self.msal_app.get_accounts(user) + for account in accounts: + self.msal_app.remove_account(account) + + def logout_all_users(self): + try: + os.remove(self._token_cache_file) + except FileNotFoundError: + pass + + def logout_service_principal(self, sp): + # remove service principal secrets + self._msal_secret_store.remove_entry(sp) + + def logout_all_service_principal(self): + # remove service principal secrets + self._msal_secret_store.remove_all_entries() + + def get_user(self, user=None): + accounts = self.msal_app.get_accounts(user) if user else self.msal_app.get_accounts() + return accounts + + def get_user_credential(self, username): + return UserCredential(self.client_id, username, **self._msal_app_kwargs) + + def get_service_principal_credential(self, client_id): + entry = self._msal_secret_store.load_entry(client_id, self.tenant_id) + sp_auth = ServicePrincipalAuth(entry) + return ServicePrincipalCredential(sp_auth, **self._msal_app_kwargs) + + def get_managed_identity_credential(self, client_id=None): + raise NotImplementedError + + +class ServicePrincipalAuth: + + def __init__(self, entry): + self.__dict__.update(entry) + + if _CERTIFICATE in entry: + from OpenSSL.crypto import load_certificate, FILETYPE_PEM, Error + self.public_certificate = None + try: + with open(self.certificate, 'r') as file_reader: + self.certificate_string = file_reader.read() + cert = load_certificate(FILETYPE_PEM, self.certificate_string) + self.thumbprint = cert.digest("sha1").decode().replace(':', '') + if entry.get(_USE_CERT_SN_ISSUER): + # low-tech but safe parsing based on + # https://github.com/libressl-portable/openbsd/blob/master/src/lib/libcrypto/pem/pem.h + match = re.search(r'-----BEGIN CERTIFICATE-----(?P[^-]+)-----END CERTIFICATE-----', + self.certificate_string, re.I) + self.public_certificate = match.group() + except (UnicodeDecodeError, Error) as ex: + raise CLIError('Invalid certificate, please use a valid PEM file. Error detail: {}'.format(ex)) + + @classmethod + def build_from_credential(cls, tenant_id, client_id, credential): + entry = { + _TENANT: tenant_id, + _CLIENT_ID: client_id + } + entry.update(credential) + return ServicePrincipalAuth(entry) + + @classmethod + def build_credential(cls, secret_or_certificate=None, client_assertion=None, use_cert_sn_issuer=None): + """Build credential from user input. The credential looks like below, but only one key can exist. + { + 'client_secret': 'my_secret', + 'certificate': '/path/to/cert.pem', + 'client_assertion': 'my_federated_token' + } + """ + entry = {} + if secret_or_certificate: + if os.path.isfile(secret_or_certificate): + entry[_CERTIFICATE] = secret_or_certificate + if use_cert_sn_issuer: + entry[_USE_CERT_SN_ISSUER] = use_cert_sn_issuer + else: + entry[_CLIENT_SECRET] = secret_or_certificate + elif client_assertion: + entry[_CLIENT_ASSERTION] = client_assertion + return entry + + def get_entry_to_persist(self): + persisted_keys = [_CLIENT_ID, _TENANT, _CLIENT_SECRET, _CERTIFICATE, _USE_CERT_SN_ISSUER, _CLIENT_ASSERTION] + return {k: v for k, v in self.__dict__.items() if k in persisted_keys} + + +class ServicePrincipalStore: + """Save secrets in MSAL custom secret store for Service Principal authentication. + """ + + def __init__(self, secret_file, encrypt): + from .persistence import load_secret_store + self._secret_store = load_secret_store(secret_file, encrypt) + self._secret_file = secret_file + self._entries = [] + + def load_entry(self, sp_id, tenant): + self._load_persistence() + matched = [x for x in self._entries if sp_id == x[_CLIENT_ID]] + if not matched: + raise CLIError("Could not retrieve credential from local cache for service principal {}. " + "Please run `az login` for this service principal." + .format(sp_id)) + matched_with_tenant = [x for x in matched if tenant == x[_TENANT]] + if matched_with_tenant: + cred = matched_with_tenant[0] + else: + logger.warning("Could not retrieve credential from local cache for service principal %s under tenant %s. " + "Trying credential under tenant %s, assuming that is an app credential.", + sp_id, tenant, matched[0][_TENANT]) + cred = matched[0] + + return cred + + def save_entry(self, sp_entry): + self._load_persistence() + + self._entries = [ + x for x in self._entries + if not (sp_entry[_CLIENT_ID] == x[_CLIENT_ID] and + sp_entry[_TENANT] == x[_TENANT])] + + self._entries.append(sp_entry) + self._save_persistence() + + def remove_entry(self, sp_id): + self._load_persistence() + state_changed = False + + # clear service principal creds + matched = [x for x in self._entries + if x[_CLIENT_ID] == sp_id] + if matched: + state_changed = True + self._entries = [x for x in self._entries + if x not in matched] + + if state_changed: + self._save_persistence() + + def remove_all_entries(self): + try: + os.remove(self._secret_file) + except FileNotFoundError: + pass + + def _save_persistence(self): + self._secret_store.save(self._entries) + + def _load_persistence(self): + self._entries = self._secret_store.load() + + +def _read_response_templates(): + """Read from success.html and error.html to strings and pass them to MSAL. """ + success_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'landing_pages', 'success.html') + with open(success_file) as f: + success_template = f.read() + + error_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'landing_pages', 'error.html') + with open(error_file) as f: + error_template = f.read() + + return success_template, error_template diff --git a/src/azure-cli-core/azure/cli/core/auth_landing_pages/fail.html b/src/azure-cli-core/azure/cli/core/auth/landing_pages/error.html similarity index 76% rename from src/azure-cli-core/azure/cli/core/auth_landing_pages/fail.html rename to src/azure-cli-core/azure/cli/core/auth/landing_pages/error.html index e4635b0ea58..a7998994c43 100644 --- a/src/azure-cli-core/azure/cli/core/auth_landing_pages/fail.html +++ b/src/azure-cli-core/azure/cli/core/auth/landing_pages/error.html @@ -5,7 +5,8 @@ Login failed -

Some failures occurred during the authentication

+

Authentication failed

+

$error: $error_description. ($error_uri)

You can log an issue at Azure CLI GitHub Repository and we will assist you in resolving it.

diff --git a/src/azure-cli-core/azure/cli/core/auth_landing_pages/ok.html b/src/azure-cli-core/azure/cli/core/auth/landing_pages/success.html similarity index 98% rename from src/azure-cli-core/azure/cli/core/auth_landing_pages/ok.html rename to src/azure-cli-core/azure/cli/core/auth/landing_pages/success.html index 8d506ffce17..c39bcdaf7a6 100644 --- a/src/azure-cli-core/azure/cli/core/auth_landing_pages/ok.html +++ b/src/azure-cli-core/azure/cli/core/auth/landing_pages/success.html @@ -9,4 +9,4 @@

You have logged into Microsoft Azure!

You can close this window, or we will redirect you to the Azure CLI documents in 10 seconds.

- \ No newline at end of file + diff --git a/src/azure-cli-core/azure/cli/core/auth/msal_authentication.py b/src/azure-cli-core/azure/cli/core/auth/msal_authentication.py new file mode 100644 index 00000000000..799aabfc568 --- /dev/null +++ b/src/azure-cli-core/azure/cli/core/auth/msal_authentication.py @@ -0,0 +1,106 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +""" +Credentials defined in this module are alternative implementations of credentials provided by Azure Identity. + +These credentials implements azure.core.credentials.TokenCredential by exposing get_token method for Track 2 +SDK invocation. +""" + +from azure.core.credentials import AccessToken +from knack.log import get_logger +from knack.util import CLIError +from msal import PublicClientApplication, ConfidentialClientApplication + +from .util import check_result + +# OAuth 2.0 client credentials flow parameter +# https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-client-creds-grant-flow +_TENANT = 'tenant' +_CLIENT_ID = 'client_id' +_CLIENT_SECRET = 'client_secret' +_CERTIFICATE = 'certificate' +_CLIENT_ASSERTION = 'client_assertion' +_USE_CERT_SN_ISSUER = 'use_cert_sn_issuer' + +logger = get_logger(__name__) + + +class UserCredential(PublicClientApplication): + + def __init__(self, client_id, username=None, **kwargs): + super().__init__(client_id, **kwargs) + if username: + accounts = self.get_accounts(username) + + if not accounts: + raise CLIError("User {} doesn't exist in the credential cache. The user could have been logged out by " + "another application that uses Single Sign-On. " + "Please run `az login` to re-login.".format(username)) + + if len(accounts) > 1: + raise CLIError("Found multiple accounts with the same username. Please report to us via Github: " + "https://github.com/Azure/azure-cli/issues/new") + + account = accounts[0] + self.account = account + else: + self.account = None + + def get_token(self, *scopes, **kwargs): + # scopes = ['https://pas.windows.net/CheckMyAccess/Linux/.default'] + logger.debug("UserCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs) + + result = self.acquire_token_silent_with_error(list(scopes), self.account, **kwargs) + check_result(result, scopes=scopes) + return _build_sdk_access_token(result) + + +class ServicePrincipalCredential(ConfidentialClientApplication): + + def __init__(self, service_principal_auth, **kwargs): + + client_credential = None + + # client_secret + client_secret = getattr(service_principal_auth, _CLIENT_SECRET, None) + if client_secret: + client_credential = client_secret + + # certificate + certificate = getattr(service_principal_auth, _CERTIFICATE, None) + if certificate: + client_credential = { + "private_key": getattr(service_principal_auth, 'certificate_string'), + "thumbprint": getattr(service_principal_auth, 'thumbprint') + } + public_certificate = getattr(service_principal_auth, 'public_certificate', None) + if public_certificate: + client_credential['public_certificate'] = public_certificate + + # client_assertion + client_assertion = getattr(service_principal_auth, _CLIENT_ASSERTION, None) + if client_assertion: + client_credential = {'client_assertion': client_assertion} + + super().__init__(service_principal_auth.client_id, client_credential=client_credential, **kwargs) + + def get_token(self, *scopes, **kwargs): + logger.debug("ServicePrincipalCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs) + + scopes = list(scopes) + result = self.acquire_token_silent(scopes, None, **kwargs) + if not result: + result = self.acquire_token_for_client(scopes, **kwargs) + check_result(result) + return _build_sdk_access_token(result) + + +def _build_sdk_access_token(token_entry): + import time + request_time = int(time.time()) + + return AccessToken(token_entry["access_token"], request_time + token_entry["expires_in"]) diff --git a/src/azure-cli-core/azure/cli/core/auth/persistence.py b/src/azure-cli-core/azure/cli/core/auth/persistence.py new file mode 100644 index 00000000000..1620abf892f --- /dev/null +++ b/src/azure-cli-core/azure/cli/core/auth/persistence.py @@ -0,0 +1,69 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +# This file is modified from +# https://github.com/AzureAD/microsoft-authentication-extensions-for-python/blob/dev/sample/token_cache_sample.py + +import json +import sys + +from msal_extensions import (FilePersistenceWithDataProtection, KeychainPersistence, LibsecretPersistence, + FilePersistence, PersistedTokenCache, CrossPlatLock) +from msal_extensions.persistence import PersistenceNotFound + +from knack.util import CLIError +from knack.log import get_logger + +logger = get_logger(__name__) + + +def load_persisted_token_cache(location, encrypt): + persistence = build_persistence(location, encrypt) + return PersistedTokenCache(persistence) + + +def load_secret_store(location, encrypt): + persistence = build_persistence(location, encrypt) + return SecretStore(persistence) + + +def build_persistence(location, encrypt): + """Build a suitable persistence instance based your current OS""" + if encrypt: + location += '.bin' + if sys.platform.startswith('win'): + return FilePersistenceWithDataProtection(location) + if sys.platform.startswith('darwin'): + return KeychainPersistence(location, "my_service_name", "my_account_name") + if sys.platform.startswith('linux'): + return LibsecretPersistence( + location, + schema_name="my_schema_name", + attributes={"my_attr1": "foo", "my_attr2": "bar"} + ) + else: + location += '.json' + return FilePersistence(location) + + +class SecretStore: + def __init__(self, persistence): + self._lock_file = persistence.get_location() + ".lockfile" + self._persistence = persistence + + def save(self, content): + with CrossPlatLock(self._lock_file): + self._persistence.save(json.dumps(content, indent=4)) + + def load(self): + with CrossPlatLock(self._lock_file): + try: + return json.loads(self._persistence.load()) + except PersistenceNotFound: + return [] + except Exception as ex: + raise CLIError("Failed to load token files. If you can reproduce, please log an issue at " + "https://github.com/Azure/azure-cli/issues. At the same time, you can clean " + "up by running 'az account clear' and then 'az login'. (Inner Error: {})".format(ex)) diff --git a/src/azure-cli-core/azure/cli/core/auth/tests/__init__.py b/src/azure-cli-core/azure/cli/core/auth/tests/__init__.py new file mode 100644 index 00000000000..34913fb394d --- /dev/null +++ b/src/azure-cli-core/azure/cli/core/auth/tests/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/src/azure-cli-core/azure/cli/core/tests/err_sp_cert.pem b/src/azure-cli-core/azure/cli/core/auth/tests/err_sp_cert.pem similarity index 100% rename from src/azure-cli-core/azure/cli/core/tests/err_sp_cert.pem rename to src/azure-cli-core/azure/cli/core/auth/tests/err_sp_cert.pem diff --git a/src/azure-cli-core/azure/cli/core/tests/sp_cert.pem b/src/azure-cli-core/azure/cli/core/auth/tests/sp_cert.pem similarity index 100% rename from src/azure-cli-core/azure/cli/core/tests/sp_cert.pem rename to src/azure-cli-core/azure/cli/core/auth/tests/sp_cert.pem diff --git a/src/azure-cli-core/azure/cli/core/auth/tests/test_credential_adaptor.py b/src/azure-cli-core/azure/cli/core/auth/tests/test_credential_adaptor.py new file mode 100644 index 00000000000..e955f492cd2 --- /dev/null +++ b/src/azure-cli-core/azure/cli/core/auth/tests/test_credential_adaptor.py @@ -0,0 +1,18 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +# pylint: disable=protected-access +import os +import json +import unittest +from unittest import mock + + +class TestIdentity(unittest.TestCase): + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/src/azure-cli-core/azure/cli/core/auth/tests/test_identity.py b/src/azure-cli-core/azure/cli/core/auth/tests/test_identity.py new file mode 100644 index 00000000000..e38fd8d9eb5 --- /dev/null +++ b/src/azure-cli-core/azure/cli/core/auth/tests/test_identity.py @@ -0,0 +1,155 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import os +import unittest +from unittest import mock + +from azure.cli.core.auth.identity import Identity, ServicePrincipalAuth, ServicePrincipalStore +from knack.util import CLIError + + +class TestIdentity(unittest.TestCase): + + def test_login_with_service_principal_certificate_cert_err(self): + import os + identity = Identity() + current_dir = os.path.dirname(os.path.realpath(__file__)) + test_cert_file = os.path.join(current_dir, 'err_sp_cert.pem') + + with self.assertRaisesRegex(CLIError, "Invalid certificate"): + identity.login_with_service_principal("00000000-0000-0000-0000-000000000000", + {"certificate": test_cert_file}) + + +class TestServicePrincipalAuth(unittest.TestCase): + + def test_service_principal_auth_client_secret(self): + sp_auth = ServicePrincipalAuth.build_from_credential('tenant1', 'sp_id1', {'client_secret': "test_secret"}) + result = sp_auth.get_entry_to_persist() + + assert result == { + 'client_id': 'sp_id1', + 'tenant': 'tenant1', + 'client_secret': 'test_secret' + } + + def test_service_principal_auth_client_cert(self): + curr_dir = os.path.dirname(os.path.realpath(__file__)) + test_cert_file = os.path.join(curr_dir, 'sp_cert.pem') + sp_auth = ServicePrincipalAuth.build_from_credential('tenant1', 'sp_id1', {'certificate': test_cert_file}) + + result = sp_auth.get_entry_to_persist() + # To compute the thumb print: + # openssl x509 -in sp_cert.pem -noout -fingerprint + assert sp_auth.thumbprint == 'F06A53848BBE714A4290D69D335279C1D01073FD' + assert result == { + 'client_id': 'sp_id1', + 'tenant': 'tenant1', + 'certificate': test_cert_file + } + + def test_build_credential(self): + # secret + cred = ServicePrincipalAuth.build_credential("test_secret") + assert cred == {"client_secret": "test_secret"} + + # certificate + current_dir = os.path.dirname(os.path.realpath(__file__)) + test_cert_file = os.path.join(current_dir, 'sp_cert.pem') + cred = ServicePrincipalAuth.build_credential(test_cert_file) + assert cred == {'certificate': test_cert_file} + + cred = ServicePrincipalAuth.build_credential(test_cert_file, use_cert_sn_issuer=True) + assert cred == {'certificate': test_cert_file, 'use_cert_sn_issuer': True} + + # client assertion + cred = ServicePrincipalAuth.build_credential(client_assertion="test_jwt") + assert cred == {"client_assertion": "test_jwt"} + + +class TestMsalSecretStore(unittest.TestCase): + + test_sp = { + 'client_id': 'myapp', + 'tenant': 'mytenant', + 'client_secret': 'test_secret' + } + + @mock.patch('azure.cli.core.auth.persistence.load_secret_store') + def test_load_entry(self, load_secret_store_mock): + store = MemoryStore() + load_secret_store_mock.return_value = store + + secret_store = ServicePrincipalStore(None, None) + store._content = [self.test_sp] + + entry = secret_store.load_entry("myapp", "mytenant") + self.assertEqual(entry['client_secret'], "test_secret") + + @mock.patch('azure.cli.core.auth.persistence.load_secret_store') + def test_save_entry(self, load_secret_store_mock): + store = MemoryStore() + load_secret_store_mock.return_value = store + + secret_store = ServicePrincipalStore(None, None) + secret_store.save_entry(self.test_sp) + + assert store._content == [self.test_sp] + + @mock.patch('azure.cli.core.auth.persistence.load_secret_store') + def test_save_entry_add_new(self, load_secret_store_mock): + store = MemoryStore() + load_secret_store_mock.return_value = store + + test_sp2 = { + 'client_id': "myapp2", + 'tenant': "mytenant2", + 'client_secret': "test_secret2" + } + + store._content = [self.test_sp] + secret_store = ServicePrincipalStore(None, None) + secret_store.save_entry(test_sp2) + assert store._content == [self.test_sp, test_sp2] + + @mock.patch('azure.cli.core.auth.persistence.load_secret_store') + def test_save_entry_update_existing(self, load_secret_store_mock): + store = MemoryStore() + load_secret_store_mock.return_value = store + + store._content = [self.test_sp] + new_creds = self.test_sp.copy() + new_creds['client_secret'] = 'test_secret' + + secret_store = ServicePrincipalStore(None, None) + secret_store.save_entry(new_creds) + assert store._content == [new_creds] + + @mock.patch('azure.cli.core.auth.persistence.load_secret_store') + def test_remove_entry(self, load_secret_store_mock): + store = MemoryStore() + load_secret_store_mock.return_value = store + + store._content = [self.test_sp] + secret_store = ServicePrincipalStore(None, None) + secret_store.remove_entry('myapp') + assert store._content == [] + + +class MemoryStore: + + def __init__(self): + self._content = [] + + def save(self, content): + self._content = content + + def load(self): + return self._content + + +if __name__ == '__main__': + unittest.main() diff --git a/src/azure-cli-core/azure/cli/core/auth/tests/test_util.py b/src/azure-cli-core/azure/cli/core/auth/tests/test_util.py new file mode 100644 index 00000000000..f5db382d736 --- /dev/null +++ b/src/azure-cli-core/azure/cli/core/auth/tests/test_util.py @@ -0,0 +1,78 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +# pylint: disable=protected-access + +import unittest +from ..util import scopes_to_resource, resource_to_scopes, _normalize_scopes, _generate_login_command + + +class TestUtil(unittest.TestCase): + + def test_scopes_to_resource(self): + # scopes as a list + self.assertEqual(scopes_to_resource(['https://management.core.windows.net//.default']), + 'https://management.core.windows.net/') + # scopes as a tuple + self.assertEqual(scopes_to_resource(('https://storage.azure.com/.default',)), + 'https://storage.azure.com') + + # resource with trailing slash + self.assertEqual(scopes_to_resource(('https://management.azure.com//.default',)), + 'https://management.azure.com/') + self.assertEqual(scopes_to_resource(['https://datalake.azure.net//.default']), + 'https://datalake.azure.net/') + + # resource without trailing slash + self.assertEqual(scopes_to_resource(('https://managedhsm.azure.com/.default',)), + 'https://managedhsm.azure.com') + + # VM SSH + self.assertEqual(scopes_to_resource(["https://pas.windows.net/CheckMyAccess/Linux/.default"]), + 'https://pas.windows.net/CheckMyAccess/Linux') + self.assertEqual(scopes_to_resource(["https://pas.windows.net/CheckMyAccess/Linux/user_impersonation"]), + 'https://pas.windows.net/CheckMyAccess/Linux') + + def test_resource_to_scopes(self): + # resource converted to a scopes list + self.assertEqual(resource_to_scopes('https://management.core.windows.net/'), + ['https://management.core.windows.net//.default']) + + # resource with trailing slash + self.assertEqual(resource_to_scopes('https://management.azure.com/'), + ['https://management.azure.com//.default']) + self.assertEqual(resource_to_scopes('https://datalake.azure.net/'), + ['https://datalake.azure.net//.default']) + + # resource without trailing slash + self.assertEqual(resource_to_scopes('https://managedhsm.azure.com'), + ['https://managedhsm.azure.com/.default']) + + def test_normalize_scopes(self): + # Test no scopes + self.assertIsNone(_normalize_scopes(())) + self.assertIsNone(_normalize_scopes([])) + self.assertIsNone(_normalize_scopes(None)) + + # Test multiple scopes, with the first one discarded + scopes = _normalize_scopes(("https://management.core.windows.net//.default", + "https://management.core.chinacloudapi.cn//.default")) + self.assertEqual(list(scopes), ["https://management.core.chinacloudapi.cn//.default"]) + + # Test single scopes (the correct usage) + scopes = _normalize_scopes(("https://management.core.chinacloudapi.cn//.default",)) + self.assertEqual(list(scopes), ["https://management.core.chinacloudapi.cn//.default"]) + + def test_generate_login_command(self): + # No parameter is given + assert _generate_login_command() == 'az login' + + # scopes + actual = _generate_login_command(scopes=["https://management.core.windows.net//.default"]) + assert actual == 'az login --scope https://management.core.windows.net//.default' + + +if __name__ == '__main__': + unittest.main() diff --git a/src/azure-cli-core/azure/cli/core/auth/util.py b/src/azure-cli-core/azure/cli/core/auth/util.py index 372dfd0bc81..5d79080b907 100644 --- a/src/azure-cli-core/azure/cli/core/auth/util.py +++ b/src/azure-cli-core/azure/cli/core/auth/util.py @@ -3,12 +3,23 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from knack.log import get_logger + +logger = get_logger(__name__) + def aad_error_handler(error, **kwargs): """ Handle the error from AAD server returned by ADAL or MSAL. """ # https://docs.microsoft.com/en-us/azure/active-directory/develop/reference-aadsts-error-codes # Search for an error code at https://login.microsoftonline.com/error + + from azure.cli.core.util import in_cloud_console + if in_cloud_console(): + import socket + logger.warning("A Cloud Shell credential problem occurred. When you report the issue with the error " + "below, please mention the hostname '%s'", socket.gethostname()) + msg = error.get('error_description') login_message = _generate_login_message(**kwargs) @@ -19,6 +30,7 @@ def aad_error_handler(error, **kwargs): def _generate_login_command(scopes=None): login_command = ['az login'] + # Rejected by Conditional Access policy, like MFA if scopes: login_command.append('--scope {}'.format(' '.join(scopes))) @@ -29,10 +41,96 @@ def _generate_login_message(**kwargs): from azure.cli.core.util import in_cloud_console login_command = _generate_login_command(**kwargs) - msg = "To re-authenticate, please {}" .format( + login_msg = "To re-authenticate, please {}" .format( "refresh Azure Portal." if in_cloud_console() else "run:\n{}".format(login_command)) - return msg + contact_admin_msg = "If the problem persists, please contact your tenant administrator." + return "{}\n\n{}".format(login_msg, contact_admin_msg) + + +def resource_to_scopes(resource): + """Convert the ADAL resource ID to MSAL scopes by appending the /.default suffix and return a list. + For example: + 'https://management.core.windows.net/' -> ['https://management.core.windows.net//.default'] + 'https://managedhsm.azure.com' -> ['https://managedhsm.azure.com/.default'] + + :param resource: The ADAL resource ID + :return: A list of scopes + """ + # https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-permissions-and-consent#trailing-slash-and-default + # We should not trim the trailing slash, like in https://management.azure.com/ + # In other word, the trailing slash should be preserved and scope should be https://management.azure.com//.default + scope = resource + '/.default' + return [scope] + + +def scopes_to_resource(scopes): + """Convert MSAL scopes to ADAL resource by stripping the /.default suffix and return a str. + For example: + ['https://management.core.windows.net//.default'] -> 'https://management.core.windows.net/' + ['https://managedhsm.azure.com/.default'] -> 'https://managedhsm.azure.com' + + :param scopes: The MSAL scopes. It can be a list or tuple of string + :return: The ADAL resource + :rtype: str + """ + if not scopes: + return None + + scope = scopes[0] + suffixes = ['/.default', '/user_impersonation'] + for s in suffixes: + if scope.endswith(s): + return scope[:-len(s)] + + return scope + + +def _normalize_scopes(scopes): + """Normalize scopes to workaround some SDK issues.""" + + # Track 2 SDKs generated before https://github.com/Azure/autorest.python/pull/239 don't maintain + # credential_scopes and call `get_token` with empty scopes. + # As a workaround, return None so that the CLI-managed resource is used. + if not scopes: + logger.debug("No scope is provided by the SDK, use the CLI-managed resource.") + return None + + # Track 2 SDKs generated before https://github.com/Azure/autorest.python/pull/745 extend default + # credential_scopes with custom credential_scopes. Instead, credential_scopes should be replaced by + # custom credential_scopes. https://github.com/Azure/azure-sdk-for-python/issues/12947 + # As a workaround, remove the first one if there are multiple scopes provided. + if len(scopes) > 1: + logger.debug("Multiple scopes are provided by the SDK, discarding the first one: %s", scopes[0]) + return scopes[1:] + + return scopes + + +def check_result(result, **kwargs): + """ + 1. Check if the MSAL result contains a valid access token. + 2. If there is error, handle the error and show re-login message. + 3. For user login, return the username and tenant_id in a dict. + """ + from azure.cli.core.azclierror import AuthenticationError + + if not result: + raise AuthenticationError("Can't find token from MSAL cache.", + recommendation="To re-authenticate, please run:\naz login") + if 'error' in result: + aad_error_handler(result, **kwargs) + + # For user authentication + if 'id_token_claims' in result: + idt = result['id_token_claims'] + return { + # AAD returns "preferred_username", ADFS returns "upn" + 'username': idt.get("preferred_username") or idt["upn"], + 'tenant_id': idt['tid'] + } + + return None def decode_access_token(access_token): diff --git a/src/azure-cli-core/azure/cli/core/commands/client_factory.py b/src/azure-cli-core/azure/cli/core/commands/client_factory.py index a267da25f42..4f13bc14916 100644 --- a/src/azure-cli-core/azure/cli/core/commands/client_factory.py +++ b/src/azure-cli-core/azure/cli/core/commands/client_factory.py @@ -4,11 +4,11 @@ # -------------------------------------------------------------------------------------------- import azure.cli.core._debug as _debug +from azure.cli.core.auth.util import resource_to_scopes from azure.cli.core.extension import EXTENSIONS_MOD_PREFIX -from azure.cli.core.profiles._shared import get_client_class, SDKProfile from azure.cli.core.profiles import ResourceType, CustomResourceType, get_api_version, get_sdk +from azure.cli.core.profiles._shared import get_client_class, SDKProfile from azure.cli.core.util import get_az_user_agent, is_track2 - from knack.log import get_logger from knack.util import CLIError @@ -167,21 +167,26 @@ def _prepare_mgmt_client_kwargs_track2(cli_ctx, cred): """Prepare kwargs for Track 2 SDK mgmt client.""" client_kwargs = _prepare_client_kwargs_track2(cli_ctx) - from azure.cli.core.util import resource_to_scopes + # Enable CAE support in mgmt SDK + from azure.core.pipeline.policies import BearerTokenCredentialPolicy + # Track 2 SDK maintains `scopes` and passes `scopes` to get_token. scopes = resource_to_scopes(cli_ctx.cloud.endpoints.active_directory_resource_id) + policy = BearerTokenCredentialPolicy(cred, *scopes) client_kwargs['credential_scopes'] = scopes + client_kwargs['authentication_policy'] = policy # Track 2 currently lacks the ability to take external credentials. # https://github.com/Azure/azure-sdk-for-python/issues/8313 # As a temporary workaround, manually add external tokens to 'x-ms-authorization-auxiliary' header. # https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/authenticate-multi-tenant - if getattr(cred, "_external_tenant_token_retriever", None): - *_, external_tenant_tokens = cred.get_all_tokens(*scopes) - # Hard-code scheme to 'Bearer' as _BearerTokenCredentialPolicyBase._update_headers does. - client_kwargs['headers']['x-ms-authorization-auxiliary'] = \ - ', '.join("Bearer {}".format(t[1]) for t in external_tenant_tokens) + if hasattr(cred, "get_auxiliary_tokens"): + aux_tokens = cred.get_auxiliary_tokens(*scopes) + if aux_tokens: + # Hard-code scheme to 'Bearer' as _BearerTokenCredentialPolicyBase._update_headers does. + client_kwargs['headers']['x-ms-authorization-auxiliary'] = \ + ', '.join("Bearer {}".format(token.token) for token in aux_tokens) return client_kwargs @@ -199,6 +204,9 @@ def _get_mgmt_service_client(cli_ctx, **kwargs): from azure.cli.core._profile import Profile logger.debug('Getting management service client client_type=%s', client_type.__name__) + + # Track 1 SDK doesn't maintain the `resource`. The `resource` of the token is the one passed to + # get_login_credentials. resource = resource or cli_ctx.cloud.endpoints.active_directory_resource_id profile = Profile(cli_ctx=cli_ctx) cred, subscription_id, _ = profile.get_login_credentials(subscription_id=subscription_id, resource=resource, diff --git a/src/azure-cli-core/azure/cli/core/commands/tests/__init__.py b/src/azure-cli-core/azure/cli/core/commands/tests/__init__.py new file mode 100644 index 00000000000..34913fb394d --- /dev/null +++ b/src/azure-cli-core/azure/cli/core/commands/tests/__init__.py @@ -0,0 +1,4 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- diff --git a/src/azure-cli-core/azure/cli/core/commands/tests/test_client_factory.py b/src/azure-cli-core/azure/cli/core/commands/tests/test_client_factory.py new file mode 100644 index 00000000000..637b1a1d0d4 --- /dev/null +++ b/src/azure-cli-core/azure/cli/core/commands/tests/test_client_factory.py @@ -0,0 +1,93 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +# pylint: disable=protected-access +import unittest +from unittest import mock +import os + +from azure.cli.core.commands.client_factory import get_mgmt_service_client +from azure.cli.core.mock import DummyCli +from azure.cli.core.profiles import ResourceType +from azure.cli.testsdk import ScenarioTest, LiveScenarioTest +from knack.util import CLIError +from azure.cli.testsdk import live_only, MOCKED_USER_NAME +from azure.cli.testsdk.constants import AUX_SUBSCRIPTION, AUX_TENANT + +from azure_devtools.scenario_tests.const import MOCKED_SUBSCRIPTION_ID, MOCKED_TENANT_ID + +mock_subscriptions = [ + { + "id": MOCKED_SUBSCRIPTION_ID, + "state": "Enabled", + "name": "Example", + "tenantId": MOCKED_TENANT_ID, + "isDefault": True, + "user": { + "name": MOCKED_USER_NAME, + "type": "user" + } + }, + { + "id": AUX_SUBSCRIPTION, + "state": "Enabled", + "name": "Auxiliary Subscription", + "tenantId": AUX_TENANT, + "isDefault": False, + "user": { + "name": MOCKED_USER_NAME, + "type": "user" + } + } +] + + +class CredentialMock: + def __init__(self, *args, **kwargs): + super().__init__() + self._authority = kwargs.get('authority') + + def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument + from azure.core.credentials import AccessToken + import time + now = int(time.time()) + return AccessToken("access_token_from_" + self._authority, now + 3600) + + +class TestClientFactory(unittest.TestCase): + def test_get_mgmt_service_client(self): + cli = DummyCli() + client = get_mgmt_service_client(cli, ResourceType.MGMT_RESOURCE_RESOURCES) + assert client + + @mock.patch("azure.cli.core.auth.identity.UserCredential", CredentialMock) + @mock.patch('azure.cli.core._profile.Profile.load_cached_subscriptions', return_value=mock_subscriptions) + def test_get_mgmt_service_client_with_aux_subs_and_tenants(self, load_cached_subscriptions_mock): + cli = DummyCli() + + def _verify_client_aux_token(client_to_check): + aux_tokens = client_to_check._config.headers_policy.headers.get('x-ms-authorization-auxiliary') + assert aux_tokens + assert aux_tokens.startswith("Bearer ") + assert AUX_TENANT in aux_tokens + + # Specify aux_subscriptions + client = get_mgmt_service_client(cli, ResourceType.MGMT_RESOURCE_RESOURCES, + aux_subscriptions=[AUX_SUBSCRIPTION]) + _verify_client_aux_token(client) + + # Specify aux_tenants + client = get_mgmt_service_client(cli, ResourceType.MGMT_RESOURCE_RESOURCES, + aux_tenants=[AUX_TENANT]) + _verify_client_aux_token(client) + + # But not both + with self.assertRaisesRegex(CLIError, "only one of aux_subscriptions and aux_tenants"): + get_mgmt_service_client(cli, ResourceType.MGMT_RESOURCE_RESOURCES, + aux_subscriptions=[AUX_SUBSCRIPTION], aux_tenants=[AUX_TENANT]) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/azure-cli-core/azure/cli/core/msal_authentication.py b/src/azure-cli-core/azure/cli/core/msal_authentication.py deleted file mode 100644 index ffaaba0d927..00000000000 --- a/src/azure-cli-core/azure/cli/core/msal_authentication.py +++ /dev/null @@ -1,50 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. -# -------------------------------------------------------------------------------------------- - -""" -Credentials defined in this module are alternative implementations of credentials provided by Azure Identity. - -These credentials implements azure.core.credentials.TokenCredential by exposing get_token method for Track 2 -SDK invocation. -""" - -import os - -from knack.log import get_logger -from msal import PublicClientApplication, ConfidentialClientApplication - -logger = get_logger(__name__) - - -class UserCredential(PublicClientApplication): - - def get_token(self, scopes, **kwargs): - raise NotImplementedError - - -class ServicePrincipalCredential(ConfidentialClientApplication): - - def __init__(self, client_id, secret_or_certificate=None, **kwargs): - - # If certificate file path is provided, transfer it to MSAL input - if os.path.isfile(secret_or_certificate): - cert_file = secret_or_certificate - with open(cert_file, 'r') as f: - cert_str = f.read() - - # Compute the thumbprint - from OpenSSL.crypto import load_certificate, FILETYPE_PEM - cert = load_certificate(FILETYPE_PEM, cert_str) - thumbprint = cert.digest("sha1").decode().replace(' ', '').replace(':', '') - - client_credential = {"private_key": cert_str, "thumbprint": thumbprint} - else: - client_credential = secret_or_certificate - - super().__init__(client_id, client_credential=client_credential, **kwargs) - - def get_token(self, scopes, **kwargs): - logger.debug("ServicePrincipalCredential.get_token: scopes=%r, kwargs=%r", scopes, kwargs) - return self.acquire_token_for_client(scopes=scopes, **kwargs) diff --git a/src/azure-cli-core/azure/cli/core/tests/test_adal_authentication.py b/src/azure-cli-core/azure/cli/core/tests/test_adal_authentication.py index 1d176a61b77..e69de29bb2d 100644 --- a/src/azure-cli-core/azure/cli/core/tests/test_adal_authentication.py +++ b/src/azure-cli-core/azure/cli/core/tests/test_adal_authentication.py @@ -1,87 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. -# -------------------------------------------------------------------------------------------- - -# pylint: disable=line-too-long -import datetime -import unittest -from unittest import mock -from unittest.mock import MagicMock - -from azure.cli.core.adal_authentication import AdalAuthentication, _try_scopes_to_resource - - -class TestUtils(unittest.TestCase): - - def test_try_scopes_to_resource(self): - # Test no scopes - self.assertIsNone(_try_scopes_to_resource(())) - self.assertIsNone(_try_scopes_to_resource([])) - self.assertIsNone(_try_scopes_to_resource(None)) - - # Test multiple scopes, with the first one discarded - resource = _try_scopes_to_resource(("https://management.core.windows.net//.default", - "https://management.core.chinacloudapi.cn//.default")) - self.assertEqual(resource, "https://management.core.chinacloudapi.cn/") - - # Test single scopes (the correct usage) - resource = _try_scopes_to_resource(("https://management.core.chinacloudapi.cn//.default",)) - self.assertEqual(resource, "https://management.core.chinacloudapi.cn/") - - -class TestAdalAuthentication(unittest.TestCase): - - def test_get_token(self): - user_full_token = ( - 'Bearer', - 'access_token_user_mock', - { - 'tokenType': 'Bearer', - 'expiresIn': 3599, - 'expiresOn': '2020-11-18 15:35:17.512862', # Local time - 'resource': 'https://management.core.windows.net/', - 'accessToken': 'access_token_user_mock', - 'refreshToken': 'refresh_token_user_mock', - 'oid': '6d97229a-391f-473a-893f-f0608b592d7b', 'userId': 'rolelivetest@azuresdkteam.onmicrosoft.com', - 'isMRRT': True, '_clientId': '04b07795-8ddb-461a-bbee-02f9e1bf7b46', - '_authority': 'https://login.microsoftonline.com/54826b22-38d6-4fb2-bad9-b7b93a3e9c5a' - }) - cloud_shell_full_token = ( - 'Bearer', - 'access_token_cloud_shell_mock', - { - 'access_token': 'access_token_cloud_shell_mock', - 'refresh_token': '', - 'expires_in': '2732', - 'expires_on': '1605683384', - 'not_before': '1605679484', - 'resource': 'https://management.core.windows.net/', - 'token_type': 'Bearer' - }) - token_retriever = MagicMock() - cred = AdalAuthentication(token_retriever) - - def utc_to_timestamp(dt): - # Obtain the POSIX timestamp from a naive datetime instance representing UTC time - # https://docs.python.org/3/library/datetime.html#datetime.datetime.timestamp - return dt.replace(tzinfo=datetime.timezone.utc).timestamp() - - # Test expiresOn is used and converted to epoch time - # Force expiresOn to be treated as UTC to make the test pass on both local machine (such as UTC+8) - # and CI (UTC). - with mock.patch("azure.cli.core.adal_authentication._timestamp", utc_to_timestamp): - token_retriever.return_value = user_full_token - access_token = cred.get_token("https://management.core.windows.net//.default") - self.assertEqual(access_token.token, "access_token_user_mock") - self.assertEqual(access_token.expires_on, 1605713717) - - # Test expires_on is used as epoch directly - token_retriever.return_value = cloud_shell_full_token - access_token = cred.get_token("https://management.core.windows.net//.default") - self.assertEqual(access_token.token, "access_token_cloud_shell_mock") - self.assertEqual(access_token.expires_on, 1605683384) - - -if __name__ == '__main__': - unittest.main() diff --git a/src/azure-cli-core/azure/cli/core/tests/test_profile.py b/src/azure-cli-core/azure/cli/core/tests/test_profile.py index 519c75e1534..3b5b5abdd01 100644 --- a/src/azure-cli-core/azure/cli/core/tests/test_profile.py +++ b/src/azure-cli-core/azure/cli/core/tests/test_profile.py @@ -10,26 +10,68 @@ import unittest from unittest import mock import re +import datetime from copy import deepcopy -from adal import AdalError +from azure.core.credentials import AccessToken -from azure.cli.core._profile import (Profile, CredsCache, SubscriptionFinder, - ServicePrincipalAuth, _AUTH_CTX_FACTORY, _USE_VENDORED_SUBSCRIPTION_SDK, +from azure.cli.core._profile import (Profile, SubscriptionFinder, + _detect_adfs_authority, _attach_token_tenant, _transform_subscription_for_multiapi) -if _USE_VENDORED_SUBSCRIPTION_SDK: - from azure.cli.core.vendored_sdks.subscriptions.models import \ - (SubscriptionState, Subscription, SubscriptionPolicies, SpendingLimit, ManagedByTenant) -else: - from azure.mgmt.resource.subscriptions.models import \ - (SubscriptionState, Subscription, SubscriptionPolicies, SpendingLimit, ManagedByTenant) + +from azure.mgmt.resource.subscriptions.models import \ + (Subscription, SubscriptionPolicies, SpendingLimit, ManagedByTenant, TenantIdDescription) from azure.cli.core.mock import DummyCli +from azure.identity import AuthenticationRecord from knack.util import CLIError +MOCK_ACCESS_TOKEN = "mock_access_token" +MOCK_EXPIRES_ON = 1630920323 +BEARER = 'Bearer' + + +class MockCredential: + + def __init__(self, *args, **kwargs): + super().__init__() + + def get_token(self, *scopes, **kwargs): + from azure.core.credentials import AccessToken + import time + now = int(time.time()) + # Mock sdk/identity/azure-identity/azure/identity/_internal/msal_credentials.py:230 + return AccessToken(MOCK_ACCESS_TOKEN, MOCK_EXPIRES_ON) + + +class MSRestAzureAuthStub: + def __init__(self, *args, **kwargs): + self._token = { + 'token_type': 'Bearer', + 'access_token': TestProfile.test_msi_access_token + } + self.set_token_invoked_count = 0 + self.token_read_count = 0 + self.client_id = kwargs.get('client_id') + self.object_id = kwargs.get('object_id') + self.msi_res_id = kwargs.get('msi_res_id') + + def set_token(self): + self.set_token_invoked_count += 1 + + @property + def token(self): + self.token_read_count += 1 + return self._token + + @token.setter + def token(self, value): + self._token = value + + class TestProfile(unittest.TestCase): @classmethod @@ -38,7 +80,12 @@ def setUpClass(cls): cls.user1 = 'foo@foo.com' cls.id1 = 'subscriptions/1' cls.display_name1 = 'foo account' - cls.state1 = SubscriptionState.enabled + cls.home_account_id = "00000003-0000-0000-0000-000000000000.00000003-0000-0000-0000-000000000000" + cls.client_id = "00000003-0000-0000-0000-000000000000" + cls.authentication_record = AuthenticationRecord(cls.tenant_id, cls.client_id, + "https://login.microsoftonline.com", cls.home_account_id, + cls.user1) + cls.state1 = 'Enabled' cls.managed_by_tenants = [ManagedByTenantStub('00000003-0000-0000-0000-000000000000'), ManagedByTenantStub('00000004-0000-0000-0000-000000000000')] # Dummy Subscription from SDK azure.mgmt.resource.subscriptions.v2019_06_01.operations._subscriptions_operations.SubscriptionsOperations.list @@ -49,9 +96,23 @@ def setUpClass(cls): cls.state1, tenant_id=cls.tenant_id, managed_by_tenants=cls.managed_by_tenants) + + cls.subscription1_output = [{'environmentName': 'AzureCloud', + 'homeTenantId': 'microsoft.com', + 'id': '1', + 'isDefault': True, + 'managedByTenants': [{'tenantId': '00000003-0000-0000-0000-000000000000'}, + {'tenantId': '00000004-0000-0000-0000-000000000000'}], + 'name': 'foo account', + 'state': 'Enabled', + 'tenantId': 'microsoft.com', + 'user': { + 'name': 'foo@foo.com', + 'type': 'user' + }}] + # Dummy result of azure.cli.core._profile.SubscriptionFinder._find_using_specific_tenant - # home_tenant_id is mapped from tenant_id - # tenant_id denotes token tenant + # It has home_tenant_id which is mapped from tenant_id. tenant_id now denotes token tenant. cls.subscription1 = SubscriptionStub(cls.id1, cls.display_name1, cls.state1, @@ -63,7 +124,7 @@ def setUpClass(cls): 'environmentName': 'AzureCloud', 'id': '1', 'name': cls.display_name1, - 'state': cls.state1.value, + 'state': cls.state1, 'user': { 'name': cls.user1, 'type': 'user' @@ -95,11 +156,12 @@ def setUpClass(cls): "accessToken": cls.raw_token1, "userId": cls.user1 } - + import time + cls.access_token = AccessToken(cls.raw_token1, int(cls.token_entry1['expiresIn'] + time.time())) cls.user2 = 'bar@bar.com' cls.id2 = 'subscriptions/2' cls.display_name2 = 'bar account' - cls.state2 = SubscriptionState.past_due + cls.state2 = 'PastDue' cls.subscription2_raw = SubscriptionStub(cls.id2, cls.display_name2, cls.state2, @@ -113,7 +175,7 @@ def setUpClass(cls): 'environmentName': 'AzureCloud', 'id': '2', 'name': cls.display_name2, - 'state': cls.state2.value, + 'state': cls.state2, 'user': { 'name': cls.user2, 'type': 'user' @@ -146,404 +208,664 @@ def setUpClass(cls): 'e-lOym1sH5iOcxfIjXF0Tp2y0f3zM7qCq8Cp1ZxEwz6xYIgByoxjErNXrOME5Ld1WizcsaWxTXpwxJn_' 'Q8U2g9kXHrbYFeY2gJxF_hnfLvNKxUKUBnftmyYxZwKi0GDS0BvdJnJnsqSRSpxUx__Ra9QJkG1IaDzj' 'ZcSZPHK45T6ohK9Hk9ktZo0crVl7Tmw') - cls.arm_resource = 'https://management.core.windows.net/' + cls.test_user_msi_access_token = ('eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsIng1dCI6IlNzWnNCTmhaY0YzUTlTNHRycFFCVE' + 'J5TlJSSSIsImtpZCI6IlNzWnNCTmhaY0YzUTlTNHRycFFCVEJ5TlJSSSJ9.eyJhdWQiOiJodHR' + 'wczovL21hbmFnZW1lbnQuY29yZS53aW5kb3dzLm5ldCIsImlzcyI6Imh0dHBzOi8vc3RzLndpbm' + 'Rvd3MubmV0LzU0ODI2YjIyLTM4ZDYtNGZiMi1iYWQ5LWI3YjkzYTNlOWM1YS8iLCJpYXQiOjE1O' + 'TE3ODM5MDQsIm5iZiI6MTU5MTc4MzkwNCwiZXhwIjoxNTkxODcwNjA0LCJhaW8iOiI0MmRnWUZE' + 'd2JsZmR0WmYxck8zeGlMcVdtOU5MQVE9PSIsImFwcGlkIjoiNjJhYzQ5ZTYtMDQzOC00MTJjLWJ' + 'kZjUtNDg0ZTdkNDUyOTM2IiwiYXBwaWRhY3IiOiIyIiwiaWRwIjoiaHR0cHM6Ly9zdHMud2luZG' + '93cy5uZXQvNTQ4MjZiMjItMzhkNi00ZmIyLWJhZDktYjdiOTNhM2U5YzVhLyIsIm9pZCI6ImQ4M' + 'zRjNjZmLTNhZjgtNDBiNy1iNDYzLWViZGNlN2YzYTgyNyIsInN1YiI6ImQ4MzRjNjZmLTNhZjgt' + 'NDBiNy1iNDYzLWViZGNlN2YzYTgyNyIsInRpZCI6IjU0ODI2YjIyLTM4ZDYtNGZiMi1iYWQ5LWI' + '3YjkzYTNlOWM1YSIsInV0aSI6Ild2YjFyVlBQT1V5VjJDYmNyeHpBQUEiLCJ2ZXIiOiIxLjAiLC' + 'J4bXNfbWlyaWQiOiIvc3Vic2NyaXB0aW9ucy8wYjFmNjQ3MS0xYmYwLTRkZGEtYWVjMy1jYjkyNz' + 'JmMDk1OTAvcmVzb3VyY2Vncm91cHMvcWlhbndlbnMvcHJvdmlkZXJzL01pY3Jvc29mdC5NYW5hZ2' + 'VkSWRlbnRpdHkvdXNlckFzc2lnbmVkSWRlbnRpdGllcy9xaWFud2VuaWRlbnRpdHkifQ.nAxWA5_' + 'qTs_uwGoziKtDFAqxlmYSlyPGqAKZ8YFqFfm68r5Ouo2x2PztAv2D71L-j8B3GykNgW-2yhbB-z2' + 'h53dgjG2TVoeZjhV9DOpSJ06kLAeH-nskGxpBFf7se1qohlU7uyctsUMQWjXVUQbTEanJzj_IH-Y' + '47O3lvM4Yrliz5QUApm63VF4EhqNpNvb5w0HkuB72SJ0MKJt5VdQqNcG077NQNoiTJ34XVXkyNDp' + 'I15y0Cj504P_xw-Dpvg-hmEbykjFMIaB8RoSrp3BzYjNtJh2CHIuWhXF0ngza2SwN2CXK0Vpn5Za' + 'EvZdD57j3h8iGE0Tw5IzG86uNS2AQ0A') + + cls.msal_accounts = [ + { + 'home_account_id': '182c0000-0000-0000-0000-000000000000.54820000-0000-0000-0000-000000000000', + 'environment': 'login.microsoftonline.com', + 'realm': 'organizations', + 'local_account_id': '182c0000-0000-0000-0000-000000000000', + 'username': cls.user1, + 'authority_type': 'MSSTS' + }, { + 'home_account_id': '182c0000-0000-0000-0000-000000000000.54820000-0000-0000-0000-000000000000', + 'environment': 'login.microsoftonline.com', + 'realm': '54820000-0000-0000-0000-000000000000', + 'local_account_id': '182c0000-0000-0000-0000-000000000000', + 'username': cls.user1, + 'authority_type': 'MSSTS' + }, { + 'home_account_id': 'c7970000-0000-0000-0000-000000000000.54820000-0000-0000-0000-000000000000', + 'environment': 'login.microsoftonline.com', + 'realm': 'organizations', + 'local_account_id': 'c7970000-0000-0000-0000-000000000000', + 'username': cls.user2, + 'authority_type': 'MSSTS' + }, { + 'home_account_id': 'c7970000-0000-0000-0000-000000000000.54820000-0000-0000-0000-000000000000', + 'environment': 'login.microsoftonline.com', + 'realm': '54820000-0000-0000-0000-000000000000', + 'local_account_id': 'c7970000-0000-0000-0000-000000000000', + 'username': cls.user2, + 'authority_type': 'MSSTS' + }] + + cls.msal_scopes = ['https://foo/.default'] + + cls.service_principal_id = "00000001-0000-0000-0000-000000000000" + cls.service_principal_secret = "test_secret" + cls.service_principal_tenant_id = "00000001-0000-0000-0000-000000000000" + + @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) + @mock.patch('azure.cli.core.auth.identity.Identity.get_user_credential', autospec=True) + @mock.patch('azure.cli.core.auth.identity.Identity.login_with_auth_code', autospec=True) + @mock.patch('azure.cli.core._profile.can_launch_browser', autospec=True, return_value=True) + def test_login_with_auth_code(self, can_launch_browser_mock, login_with_auth_code_mock, get_user_credential_mock, + create_subscription_client_mock): + user_identity_mock = { + 'username': self.user1, + 'tenantId': self.tenant_id + } + login_with_auth_code_mock.return_value = user_identity_mock - def test_normalize(self): cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, - [self.subscription1], - False) - expected = self.subscription1_normalized - self.assertEqual(expected, consolidated[0]) - # verify serialization works - self.assertIsNotNone(json.dumps(consolidated[0])) + mock_subscription_client = mock.MagicMock() + mock_subscription_client.tenants.list.return_value = [TenantStub(self.tenant_id)] + mock_subscription_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] + create_subscription_client_mock.return_value = mock_subscription_client - def test_normalize_with_unicode_in_subscription_name(self): - cli = DummyCli() storage_mock = {'subscriptions': None} - test_display_name = 'sub' + chr(255) - polished_display_name = 'sub?' - test_subscription = SubscriptionStub('subscriptions/sub1', - test_display_name, - SubscriptionState.enabled, - 'tenant1') - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, - [test_subscription], - False) - self.assertTrue(consolidated[0]['name'] in [polished_display_name, test_display_name]) + profile = Profile(cli_ctx=cli, storage=storage_mock) + subs = profile.login(True, None, None, False, None, use_device_code=False, allow_no_subscriptions=False) - def test_normalize_with_none_subscription_name(self): - cli = DummyCli() - storage_mock = {'subscriptions': None} - test_display_name = None - polished_display_name = '' - test_subscription = SubscriptionStub('subscriptions/sub1', - test_display_name, - SubscriptionState.enabled, - 'tenant1') - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, - [test_subscription], - False) - self.assertTrue(consolidated[0]['name'] == polished_display_name) + # assert + login_with_auth_code_mock.assert_called_once() + get_user_credential_mock.assert_called() + self.assertEqual(self.subscription1_output, subs) + + @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) + @mock.patch('azure.cli.core.auth.identity.Identity.get_user_credential', autospec=True) + @mock.patch('azure.cli.core.auth.identity.Identity.login_with_device_code', autospec=True) + def test_login_with_device_code(self, login_with_device_code_mock, get_user_credential_mock, + create_subscription_client_mock): + user_identity_mock = { + 'username': self.user1, + 'tenantId': self.tenant_id + } + login_with_device_code_mock.return_value = user_identity_mock - def test_update_add_two_different_subscriptions(self): cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) + mock_subscription_client = mock.MagicMock() + mock_subscription_client.tenants.list.return_value = [TenantStub(self.tenant_id)] + mock_subscription_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] + create_subscription_client_mock.return_value = mock_subscription_client - # add the first and verify - consolidated = profile._normalize_properties(self.user1, - [self.subscription1], - False) - profile._set_subscriptions(consolidated) - - self.assertEqual(len(storage_mock['subscriptions']), 1) - subscription1 = storage_mock['subscriptions'][0] - subscription1_is_default = deepcopy(self.subscription1_normalized) - subscription1_is_default['isDefault'] = True - self.assertEqual(subscription1, subscription1_is_default) - - # add the second and verify - consolidated = profile._normalize_properties(self.user2, - [self.subscription2], - False) - profile._set_subscriptions(consolidated) - - self.assertEqual(len(storage_mock['subscriptions']), 2) - subscription2 = storage_mock['subscriptions'][1] - subscription2_is_default = deepcopy(self.subscription2_normalized) - subscription2_is_default['isDefault'] = True - self.assertEqual(subscription2, subscription2_is_default) + storage_mock = {'subscriptions': None} + profile = Profile(cli_ctx=cli, storage=storage_mock) + subs = profile.login(True, None, None, False, None, use_device_code=True, allow_no_subscriptions=False) - # verify the old one stays, but no longer active - self.assertEqual(storage_mock['subscriptions'][0]['name'], - subscription1['name']) - self.assertFalse(storage_mock['subscriptions'][0]['isDefault']) + # assert + self.assertEqual(self.subscription1_output, subs) + + @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) + @mock.patch('azure.cli.core.auth.identity.Identity.get_user_credential', autospec=True) + @mock.patch('azure.cli.core.auth.identity.Identity.login_with_device_code', autospec=True) + def test_login_with_device_code_for_tenant(self, login_with_device_code_mock, get_user_credential_mock, + create_subscription_client_mock): + user_identity_mock = { + 'username': self.user1, + 'tenantId': self.tenant_id + } + login_with_device_code_mock.return_value = user_identity_mock - def test_update_with_same_subscription_added_twice(self): cli = DummyCli() + mock_subscription_client = mock.MagicMock() + mock_subscription_client.tenants.list.return_value = [TenantStub(self.tenant_id)] + mock_subscription_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] + create_subscription_client_mock.return_value = mock_subscription_client + storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) + profile = Profile(cli_ctx=cli, storage=storage_mock) + subs = profile.login(True, None, None, False, self.tenant_id, use_device_code=True, + allow_no_subscriptions=False) - # add one twice and verify we will have one but with new token - consolidated = profile._normalize_properties(self.user1, - [self.subscription1], - False) - profile._set_subscriptions(consolidated) + # assert + self.assertEqual(self.subscription1_output, subs) + + @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) + @mock.patch('azure.cli.core.auth.identity.Identity.get_user_credential', autospec=True) + @mock.patch('azure.cli.core.auth.identity.Identity.login_with_username_password', autospec=True) + def test_login_with_username_password_for_tenant(self, login_with_username_password_mock, get_user_credential_mock, + create_subscription_client_mock): + user_identity_mock = { + 'username': self.user1, + 'tenantId': self.tenant_id + } + login_with_username_password_mock.return_value = user_identity_mock - new_subscription1 = SubscriptionStub(self.id1, - self.display_name1, - self.state1, - self.tenant_id) - consolidated = profile._normalize_properties(self.user1, - [new_subscription1], - False) - profile._set_subscriptions(consolidated) + cli = DummyCli() + mock_subscription_client = mock.MagicMock() + mock_subscription_client.tenants.list.return_value = [TenantStub(self.tenant_id)] + mock_subscription_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] + create_subscription_client_mock.return_value = mock_subscription_client - self.assertEqual(len(storage_mock['subscriptions']), 1) - self.assertTrue(storage_mock['subscriptions'][0]['isDefault']) + storage_mock = {'subscriptions': None} + profile = Profile(cli_ctx=cli, storage=storage_mock) + subs = profile.login(False, '1234', 'my-secret', False, self.tenant_id, use_device_code=False, + allow_no_subscriptions=False) + + self.assertEqual(self.subscription1_output, subs) + + @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) + @mock.patch('azure.cli.core.auth.identity.Identity.get_service_principal_credential', autospec=True) + @mock.patch('azure.cli.core.auth.identity.Identity.login_with_service_principal', autospec=True) + def test_login_with_service_principal(self, login_with_service_principal_mock, + get_service_principal_credential_mock, + create_subscription_client_mock): + cli = DummyCli() + mock_subscription_client = mock.MagicMock() + mock_subscription_client.tenants.list.return_value = [TenantStub(self.tenant_id)] + mock_subscription_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] + create_subscription_client_mock.return_value = mock_subscription_client - def test_set_active_subscription(self): + storage_mock = {'subscriptions': None} + profile = Profile(cli_ctx=cli, storage=storage_mock) + subs = profile.login(False, 'my app', 'my secret', True, self.tenant_id, use_device_code=True, + allow_no_subscriptions=False) + output = [{'environmentName': 'AzureCloud', + 'homeTenantId': 'microsoft.com', + 'id': '1', + 'isDefault': True, + 'managedByTenants': [{'tenantId': '00000003-0000-0000-0000-000000000000'}, + {'tenantId': '00000004-0000-0000-0000-000000000000'}], + 'name': 'foo account', + 'state': 'Enabled', + 'tenantId': 'microsoft.com', + 'user': { + 'name': 'my app', + 'type': 'servicePrincipal'}}] + self.assertEqual(output, subs) + + @unittest.skip("Not supported by Azure Identity.") + def test_login_with_service_principal_cert_sn_issuer(self, get_token_mock): cli = DummyCli() + mock_arm_client = mock.MagicMock() + mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] + finder = SubscriptionFinder(cli, lambda _: mock_arm_client) + curr_dir = os.path.dirname(os.path.realpath(__file__)) + test_cert_file = os.path.join(curr_dir, 'sp_cert.pem') storage_mock = {'subscriptions': None} profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) + subs = profile.login(False, 'my app', test_cert_file, True, self.tenant_id, use_device_code=True, + allow_no_subscriptions=False, subscription_finder=finder, use_cert_sn_issuer=True) + output = [{'environmentName': 'AzureCloud', + 'homeTenantId': 'microsoft.com', + 'id': '1', + 'isDefault': True, + 'managedByTenants': [{'tenantId': '00000003-0000-0000-0000-000000000000'}, + {'tenantId': '00000004-0000-0000-0000-000000000000'}], + 'name': 'foo account', + 'state': 'Enabled', + 'tenantId': 'microsoft.com', + 'user': { + 'name': 'my app', + 'type': 'servicePrincipal', + 'useCertSNIssuerAuth': True}}] + # assert + self.assertEqual(output, subs) - consolidated = profile._normalize_properties(self.user1, - [self.subscription1], - False) - profile._set_subscriptions(consolidated) - - consolidated = profile._normalize_properties(self.user2, - [self.subscription2], - False) - profile._set_subscriptions(consolidated) + @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) + @mock.patch('azure.cli.core.auth.adal_authentication.MSIAuthenticationWrapper', autospec=True) + def test_login_in_cloud_shell(self, msi_auth_mock, create_subscription_client_mock): + msi_auth_mock.return_value = MSRestAzureAuthStub() - self.assertTrue(storage_mock['subscriptions'][1]['isDefault']) + cli = DummyCli() + mock_subscription_client = mock.MagicMock() + mock_subscription_client.tenants.list.return_value = [TenantStub(self.tenant_id)] + mock_subscription_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] + create_subscription_client_mock.return_value = mock_subscription_client - profile.set_active_subscription(storage_mock['subscriptions'][0]['id']) - self.assertFalse(storage_mock['subscriptions'][1]['isDefault']) - self.assertTrue(storage_mock['subscriptions'][0]['isDefault']) + profile = Profile(cli_ctx=cli, storage={'subscriptions': None}) - def test_default_active_subscription_to_non_disabled_one(self): - cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) + subscriptions = profile.login_in_cloud_shell() - subscriptions = profile._normalize_properties( - self.user2, [self.subscription2, self.subscription1], False) + # Check correct token is used + assert create_subscription_client_mock.call_args[0][1].token['access_token'] == TestProfile.test_msi_access_token - profile._set_subscriptions(subscriptions) + self.assertEqual(len(subscriptions), 1) + s = subscriptions[0] + self.assertEqual(s['user']['name'], 'admin3@AzureSDKTeam.onmicrosoft.com') + self.assertEqual(s['tenantId'], '54826b22-38d6-4fb2-bad9-b7b93a3e9c5a') + self.assertEqual(s['user']['cloudShellID'], True) + self.assertEqual(s['user']['type'], 'user') + self.assertEqual(s['name'], self.display_name1) + self.assertEqual(s['id'], self.id1.split('/')[-1]) - # verify we skip the overdued subscription and default to the 2nd one in the list - self.assertEqual(storage_mock['subscriptions'][1]['name'], self.subscription1.display_name) - self.assertTrue(storage_mock['subscriptions'][1]['isDefault']) + @mock.patch('requests.get', autospec=True) + @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) + def test_find_subscriptions_in_vm_with_msi_system_assigned(self, create_subscription_client_mock, mock_get): + mock_subscription_client = mock.MagicMock() + mock_subscription_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] + create_subscription_client_mock.return_value = mock_subscription_client - def test_get_subscription(self): cli = DummyCli() storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - - consolidated = profile._normalize_properties(self.user1, - [self.subscription1], - False) - profile._set_subscriptions(consolidated) - - self.assertEqual(self.display_name1, profile.get_subscription()['name']) - self.assertEqual(self.display_name1, - profile.get_subscription(subscription=self.display_name1)['name']) + profile = Profile(cli_ctx=cli, storage=storage_mock) - sub_id = self.id1.split('/')[-1] - self.assertEqual(sub_id, profile.get_subscription()['id']) - self.assertEqual(sub_id, profile.get_subscription(subscription=sub_id)['id']) - self.assertRaises(CLIError, profile.get_subscription, "random_id") + test_token_entry = { + 'token_type': 'Bearer', + 'access_token': TestProfile.test_msi_access_token + } + encoded_test_token = json.dumps(test_token_entry).encode() + good_response = mock.MagicMock() + good_response.status_code = 200 + good_response.content = encoded_test_token + mock_get.return_value = good_response - def test_get_auth_info_fail_on_user_account(self): - cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) + subscriptions = profile.login_with_managed_identity() - consolidated = profile._normalize_properties(self.user1, - [self.subscription1], - False) - profile._set_subscriptions(consolidated) + # assert + self.assertEqual(len(subscriptions), 1) + s = subscriptions[0] + self.assertEqual(s['user']['name'], 'systemAssignedIdentity') + self.assertEqual(s['user']['type'], 'servicePrincipal') + self.assertEqual(s['user']['assignedIdentityInfo'], 'MSI') + self.assertEqual(s['name'], self.display_name1) + self.assertEqual(s['id'], self.id1.split('/')[-1]) + self.assertEqual(s['tenantId'], '54826b22-38d6-4fb2-bad9-b7b93a3e9c5a') - # testing dump of existing logged in account - self.assertRaises(CLIError, profile.get_sp_auth_info) + @mock.patch('requests.get', autospec=True) + @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) + def test_find_subscriptions_in_vm_with_msi_no_subscriptions(self, create_subscription_client_mock, mock_get): + mock_subscription_client = mock.MagicMock() + mock_subscription_client.subscriptions.list.return_value = [] + create_subscription_client_mock.return_value = mock_subscription_client - @mock.patch('azure.cli.core.profiles.get_api_version', autospec=True) - def test_subscription_finder_constructor(self, get_api_mock): cli = DummyCli() - get_api_mock.return_value = '2016-06-01' - cli.cloud.endpoints.resource_manager = 'http://foo_arm' - finder = SubscriptionFinder(cli, None, None, arm_client_factory=None) - result = finder._arm_client_factory(mock.MagicMock()) - self.assertEqual(result._client._base_url, 'http://foo_arm') + storage_mock = {'subscriptions': None} + profile = Profile(cli_ctx=cli, storage=storage_mock) - @mock.patch('azure.cli.core._profile.SubscriptionFinder._get_subscription_client_class', autospec=True) - def test_subscription_finder_fail_on_arm_client_factory(self, get_client_class_mock): - cli = DummyCli() - get_client_class_mock.return_value = None - finder = SubscriptionFinder(cli, None, None, arm_client_factory=None) - from azure.cli.core.azclierror import CLIInternalError - with self.assertRaisesRegexp(CLIInternalError, 'Unable to get'): - finder._arm_client_factory(mock.MagicMock()) - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_get_auth_info_for_logged_in_service_principal(self, mock_auth_context): - cli = DummyCli() - mock_auth_context.acquire_token_with_client_credentials.return_value = self.token_entry1 - mock_arm_client = mock.MagicMock() - mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) + test_token_entry = { + 'token_type': 'Bearer', + 'access_token': TestProfile.test_msi_access_token + } + encoded_test_token = json.dumps(test_token_entry).encode() + good_response = mock.MagicMock() + good_response.status_code = 200 + good_response.content = encoded_test_token + mock_get.return_value = good_response - storage_mock = {'subscriptions': []} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - profile._management_resource_uri = 'https://management.core.windows.net/' - profile.find_subscriptions_on_login(False, '1234', 'my-secret', True, self.tenant_id, use_device_code=False, - allow_no_subscriptions=False, subscription_finder=finder) - # action - extended_info = profile.get_sp_auth_info() - # assert - self.assertEqual(self.id1.split('/')[-1], extended_info['subscriptionId']) - self.assertEqual('1234', extended_info['clientId']) - self.assertEqual('my-secret', extended_info['clientSecret']) - self.assertEqual('https://login.microsoftonline.com', extended_info['activeDirectoryEndpointUrl']) - self.assertEqual('https://management.azure.com/', extended_info['resourceManagerEndpointUrl']) + subscriptions = profile.login_with_managed_identity(allow_no_subscriptions=True) - def test_get_auth_info_for_newly_created_service_principal(self): - cli = DummyCli() - storage_mock = {'subscriptions': []} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, [self.subscription1], False) - profile._set_subscriptions(consolidated) - # action - extended_info = profile.get_sp_auth_info(name='1234', cert_file='/tmp/123.pem') # assert - self.assertEqual(self.id1.split('/')[-1], extended_info['subscriptionId']) - self.assertEqual(self.tenant_id, extended_info['tenantId']) - self.assertEqual('1234', extended_info['clientId']) - self.assertEqual('/tmp/123.pem', extended_info['clientCertificate']) - self.assertIsNone(extended_info.get('clientSecret', None)) - self.assertEqual('https://login.microsoftonline.com', extended_info['activeDirectoryEndpointUrl']) - self.assertEqual('https://management.azure.com/', extended_info['resourceManagerEndpointUrl']) - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_create_account_without_subscriptions_thru_service_principal(self, mock_auth_context): - mock_auth_context.acquire_token_with_client_credentials.return_value = self.token_entry1 - cli = DummyCli() - mock_arm_client = mock.MagicMock() - mock_arm_client.subscriptions.list.return_value = [] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - - storage_mock = {'subscriptions': []} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - profile._management_resource_uri = 'https://management.core.windows.net/' + self.assertEqual(len(subscriptions), 1) + s = subscriptions[0] - # action - result = profile.find_subscriptions_on_login(False, - '1234', - 'my-secret', - True, - self.tenant_id, - use_device_code=False, - allow_no_subscriptions=True, - subscription_finder=finder) - # assert - self.assertEqual(1, len(result)) - self.assertEqual(result[0]['id'], self.tenant_id) - self.assertEqual(result[0]['state'], 'Enabled') - self.assertEqual(result[0]['tenantId'], self.tenant_id) - self.assertEqual(result[0]['name'], 'N/A(tenant level account)') - self.assertTrue(profile.is_tenant_level_account()) + self.assertEqual(s['name'], 'N/A(tenant level account)') + self.assertEqual(s['id'], self.test_msi_tenant) + self.assertEqual(s['tenantId'], self.test_msi_tenant) - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_create_account_with_subscriptions_allow_no_subscriptions_thru_service_principal(self, mock_auth_context): - """test subscription is returned even with --allow-no-subscriptions. """ - mock_auth_context.acquire_token_with_client_credentials.return_value = self.token_entry1 - cli = DummyCli() - mock_arm_client = mock.MagicMock() - mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) + self.assertEqual(s['user']['name'], 'systemAssignedIdentity') + self.assertEqual(s['user']['type'], 'servicePrincipal') + self.assertEqual(s['user']['assignedIdentityInfo'], 'MSI') - storage_mock = {'subscriptions': []} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - profile._management_resource_uri = 'https://management.core.windows.net/' + @mock.patch('requests.get', autospec=True) + @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) + def test_find_subscriptions_in_vm_with_msi_user_assigned_with_client_id(self, create_subscription_client_mock, mock_get): + mock_subscription_client = mock.MagicMock() + mock_subscription_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] + create_subscription_client_mock.return_value = mock_subscription_client - # action - result = profile.find_subscriptions_on_login(False, - '1234', - 'my-secret', - True, - self.tenant_id, - use_device_code=False, - allow_no_subscriptions=True, - subscription_finder=finder) - # assert - self.assertEqual(1, len(result)) - self.assertEqual(result[0]['id'], self.id1.split('/')[-1]) - self.assertEqual(result[0]['state'], 'Enabled') - self.assertEqual(result[0]['tenantId'], self.tenant_id) - self.assertEqual(result[0]['name'], self.display_name1) - self.assertFalse(profile.is_tenant_level_account()) - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_create_account_without_subscriptions_thru_common_tenant(self, mock_auth_context): - mock_auth_context.acquire_token.return_value = self.token_entry1 - mock_auth_context.acquire_token_with_username_password.return_value = self.token_entry1 cli = DummyCli() - tenant_object = mock.MagicMock() - tenant_object.id = "foo-bar" - tenant_object.tenant_id = self.tenant_id - mock_arm_client = mock.MagicMock() - mock_arm_client.subscriptions.list.return_value = [] - mock_arm_client.tenants.list.return_value = (x for x in [tenant_object]) - - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) + storage_mock = {'subscriptions': None} + profile = Profile(cli_ctx=cli, storage=storage_mock) - storage_mock = {'subscriptions': []} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - profile._management_resource_uri = 'https://management.core.windows.net/' + test_token_entry = { + 'token_type': 'Bearer', + 'access_token': TestProfile.test_msi_access_token + } + test_client_id = '54826b22-38d6-4fb2-bad9-b7b93a3e9999' + encoded_test_token = json.dumps(test_token_entry).encode() + good_response = mock.MagicMock() + good_response.status_code = 200 + good_response.content = encoded_test_token + mock_get.return_value = good_response - # action - result = profile.find_subscriptions_on_login(False, - '1234', - 'my-secret', - False, - None, - use_device_code=False, - allow_no_subscriptions=True, - subscription_finder=finder) + subscriptions = profile.login_with_managed_identity(identity_id=test_client_id) - # assert - self.assertEqual(1, len(result)) - self.assertEqual(result[0]['id'], self.tenant_id) - self.assertEqual(result[0]['state'], 'Enabled') - self.assertEqual(result[0]['tenantId'], self.tenant_id) - self.assertEqual(result[0]['name'], 'N/A(tenant level account)') - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_create_account_without_subscriptions_without_tenant(self, mock_auth_context): + self.assertEqual(len(subscriptions), 1) + s = subscriptions[0] + self.assertEqual(s['name'], self.display_name1) + self.assertEqual(s['id'], self.id1.split('/')[-1]) + self.assertEqual(s['tenantId'], '54826b22-38d6-4fb2-bad9-b7b93a3e9c5a') + + self.assertEqual(s['user']['name'], 'userAssignedIdentity') + self.assertEqual(s['user']['type'], 'servicePrincipal') + self.assertEqual(s['user']['assignedIdentityInfo'], 'MSIClient-{}'.format(test_client_id)) + + @mock.patch('azure.cli.core.auth.adal_authentication.MSIAuthenticationWrapper', autospec=True) + @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) + def test_find_subscriptions_in_vm_with_msi_user_assigned_with_object_id(self, create_subscription_client_mock, + mock_msi_auth): + mock_subscription_client = mock.MagicMock() + mock_subscription_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] + create_subscription_client_mock.return_value = mock_subscription_client + + from azure.cli.core.azclierror import AzureResponseError + class AuthStub: + def __init__(self, **kwargs): + self.token = None + self.client_id = kwargs.get('client_id') + self.object_id = kwargs.get('object_id') + # since msrestazure 0.4.34, set_token in init + self.set_token() + + def set_token(self): + # here we will reject the 1st sniffing of trying with client_id and then acccept the 2nd + if self.object_id: + self.token = { + 'token_type': 'Bearer', + 'access_token': TestProfile.test_msi_access_token + } + else: + raise AzureResponseError('Failed to connect to MSI. Please make sure MSI is configured correctly.\n' + 'Get Token request returned http error: 400, reason: Bad Request') + + profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}) + + mock_msi_auth.side_effect = AuthStub + test_object_id = '54826b22-38d6-4fb2-bad9-b7b93a3e9999' + + subscriptions = profile.login_with_managed_identity(identity_id=test_object_id) + + s = subscriptions[0] + self.assertEqual(s['user']['name'], 'userAssignedIdentity') + self.assertEqual(s['user']['type'], 'servicePrincipal') + self.assertEqual(s['user']['assignedIdentityInfo'], 'MSIObject-{}'.format(test_object_id)) + + @mock.patch('requests.get', autospec=True) + @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) + def test_find_subscriptions_in_vm_with_msi_user_assigned_with_res_id(self, create_subscription_client_mock, + mock_get): + + mock_subscription_client = mock.MagicMock() + mock_subscription_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] + create_subscription_client_mock.return_value = mock_subscription_client + + cli = DummyCli() + storage_mock = {'subscriptions': None} + profile = Profile(cli_ctx=cli, storage=storage_mock) + + test_token_entry = { + 'token_type': 'Bearer', + 'access_token': TestProfile.test_msi_access_token + } + test_res_id = ('/subscriptions/0b1f6471-1bf0-4dda-aec3-cb9272f09590/resourcegroups/g1/' + 'providers/Microsoft.ManagedIdentity/userAssignedIdentities/id1') + + encoded_test_token = json.dumps(test_token_entry).encode() + good_response = mock.MagicMock() + good_response.status_code = 200 + good_response.content = encoded_test_token + mock_get.return_value = good_response + + subscriptions = profile.login_with_managed_identity(identity_id=test_res_id) + + s = subscriptions[0] + self.assertEqual(s['user']['name'], 'userAssignedIdentity') + self.assertEqual(s['user']['type'], 'servicePrincipal') + self.assertEqual(subscriptions[0]['user']['assignedIdentityInfo'], 'MSIResource-{}'.format(test_res_id)) + + @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) + @mock.patch('azure.cli.core.auth.identity.Identity.get_user_credential', autospec=True) + @mock.patch('azure.cli.core.auth.identity.Identity.login_with_auth_code', autospec=True) + @mock.patch('azure.cli.core._profile.can_launch_browser', autospec=True, return_value=True) + def test_login_no_subscription(self, can_launch_browser_mock, + login_with_auth_code_mock, get_user_credential_mock, + create_subscription_client_mock): + user_identity_mock = { + 'username': self.user1, + 'tenantId': self.tenant_id + } + login_with_auth_code_mock.return_value = user_identity_mock + + cli = DummyCli() + mock_subscription_client = mock.MagicMock() + mock_subscription_client.tenants.list.return_value = [TenantStub(self.tenant_id)] + mock_subscription_client.subscriptions.list.return_value = [] + create_subscription_client_mock.return_value = mock_subscription_client + + storage_mock = {'subscriptions': None} + profile = Profile(cli_ctx=cli, storage=storage_mock) + subs = profile.login(True, None, None, False, None, use_device_code=False, allow_no_subscriptions=True) + + self.assertEqual(1, len(subs)) + self.assertEqual(subs[0]['id'], self.tenant_id) + self.assertEqual(subs[0]['state'], 'Enabled') + self.assertEqual(subs[0]['tenantId'], self.tenant_id) + self.assertEqual(subs[0]['name'], 'N/A(tenant level account)') + self.assertTrue(profile.is_tenant_level_account()) + + @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) + @mock.patch('azure.cli.core.auth.identity.Identity.get_user_credential', autospec=True) + @mock.patch('azure.cli.core.auth.identity.Identity.login_with_auth_code', autospec=True) + @mock.patch('azure.cli.core._profile.can_launch_browser', autospec=True, return_value=True) + def test_login_no_tenant(self, can_launch_browser_mock, + login_with_auth_code_mock, get_user_credential_mock, + create_subscription_client_mock): + user_identity_mock = { + 'username': self.user1, + 'tenantId': self.tenant_id + } + login_with_auth_code_mock.return_value = user_identity_mock + + cli = DummyCli() + mock_subscription_client = mock.MagicMock() + mock_subscription_client.tenants.list.return_value = [] + mock_subscription_client.subscriptions.list.return_value = [] + create_subscription_client_mock.return_value = mock_subscription_client + + storage_mock = {'subscriptions': None} + profile = Profile(cli_ctx=cli, storage=storage_mock) + subs = profile.login(True, None, None, False, None, use_device_code=False, allow_no_subscriptions=True) + + assert subs == [] + + def test_normalize(self): + cli = DummyCli() + storage_mock = {'subscriptions': None} + profile = Profile(cli_ctx=cli, storage=storage_mock) + consolidated = profile._normalize_properties(self.user1, [self.subscription1], False) + expected = self.subscription1_normalized + self.assertEqual(expected, consolidated[0]) + # verify serialization works + self.assertIsNotNone(json.dumps(consolidated[0])) + + def test_normalize_v2016_06_01(self): + cli = DummyCli() + storage_mock = {'subscriptions': None} + profile = Profile(cli_ctx=cli, storage=storage_mock) + from azure.mgmt.resource.subscriptions.v2016_06_01.models import Subscription \ + as Subscription_v2016_06_01 + subscription = Subscription_v2016_06_01() + subscription.id = self.id1 + subscription.display_name = self.display_name1 + subscription.state = self.state1 + subscription.tenant_id = self.tenant_id + + consolidated = profile._normalize_properties(self.user1, [subscription], False) + + # The subscription shouldn't have managed_by_tenants and home_tenant_id + expected = { + 'id': '1', + 'name': self.display_name1, + 'state': 'Enabled', + 'user': { + 'name': 'foo@foo.com', + 'type': 'user' + }, + 'isDefault': False, + 'tenantId': self.tenant_id, + 'environmentName': 'AzureCloud' + } + self.assertEqual(expected, consolidated[0]) + # verify serialization works + self.assertIsNotNone(json.dumps(consolidated[0])) + + def test_update_add_two_different_subscriptions(self): cli = DummyCli() - finder = mock.MagicMock() - finder.find_through_interactive_flow.return_value = [] storage_mock = {'subscriptions': []} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) + profile = Profile(cli_ctx=cli, storage=storage_mock) - # action - result = profile.find_subscriptions_on_login(True, - '1234', - 'my-secret', - False, - None, - use_device_code=False, - allow_no_subscriptions=True, - subscription_finder=finder) + # add the first and verify + consolidated = profile._normalize_properties(self.user1, + [self.subscription1], + False) + profile._set_subscriptions(consolidated) - # assert - self.assertTrue(0 == len(result)) + self.assertEqual(len(storage_mock['subscriptions']), 1) + subscription1 = storage_mock['subscriptions'][0] + subscription1_is_default = deepcopy(self.subscription1_normalized) + subscription1_is_default['isDefault'] = True + self.assertEqual(subscription1, subscription1_is_default) + + # add the second and verify + consolidated = profile._normalize_properties(self.user2, + [self.subscription2], + False) + profile._set_subscriptions(consolidated) + + self.assertEqual(len(storage_mock['subscriptions']), 2) + subscription2 = storage_mock['subscriptions'][1] + subscription2_is_default = deepcopy(self.subscription2_normalized) + subscription2_is_default['isDefault'] = True + self.assertEqual(subscription2, subscription2_is_default) + + # verify the old one stays, but no longer active + self.assertEqual(storage_mock['subscriptions'][0]['name'], + subscription1['name']) + self.assertFalse(storage_mock['subscriptions'][0]['isDefault']) - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - def test_get_current_account_user(self, mock_read_cred_file): + def test_update_with_same_subscription_added_twice(self): + cli = DummyCli() + storage_mock = {'subscriptions': []} + profile = Profile(cli_ctx=cli, storage=storage_mock) + + # add one twice and verify we will have one but with new token + consolidated = profile._normalize_properties(self.user1, + [self.subscription1], + False) + profile._set_subscriptions(consolidated) + + new_subscription1 = SubscriptionStub(self.id1, + self.display_name1, + self.state1, + self.tenant_id) + consolidated = profile._normalize_properties(self.user1, + [new_subscription1], + False) + profile._set_subscriptions(consolidated) + + self.assertEqual(len(storage_mock['subscriptions']), 1) + self.assertTrue(storage_mock['subscriptions'][0]['isDefault']) + + def test_set_active_subscription(self): + cli = DummyCli() + storage_mock = {'subscriptions': []} + profile = Profile(cli_ctx=cli, storage=storage_mock) + + consolidated = profile._normalize_properties(self.user1, + [self.subscription1], + False) + profile._set_subscriptions(consolidated) + + consolidated = profile._normalize_properties(self.user2, + [self.subscription2], + False) + profile._set_subscriptions(consolidated) + + self.assertTrue(storage_mock['subscriptions'][1]['isDefault']) + + profile.set_active_subscription(storage_mock['subscriptions'][0]['id']) + self.assertFalse(storage_mock['subscriptions'][1]['isDefault']) + self.assertTrue(storage_mock['subscriptions'][0]['isDefault']) + + def test_default_active_subscription_to_non_disabled_one(self): + cli = DummyCli() + storage_mock = {'subscriptions': []} + profile = Profile(cli_ctx=cli, storage=storage_mock) + + subscriptions = profile._normalize_properties( + self.user2, [self.subscription2, self.subscription1], False) + + profile._set_subscriptions(subscriptions) + + # verify we skip the overdued subscription and default to the 2nd one in the list + self.assertEqual(storage_mock['subscriptions'][1]['name'], self.subscription1.display_name) + self.assertTrue(storage_mock['subscriptions'][1]['isDefault']) + + def test_get_subscription(self): + cli = DummyCli() + storage_mock = {'subscriptions': []} + profile = Profile(cli_ctx=cli, storage=storage_mock) + + consolidated = profile._normalize_properties(self.user1, + [self.subscription1], + False) + profile._set_subscriptions(consolidated) + + self.assertEqual(self.display_name1, profile.get_subscription()['name']) + self.assertEqual(self.display_name1, + profile.get_subscription(subscription=self.display_name1)['name']) + + sub_id = self.id1.split('/')[-1] + self.assertEqual(sub_id, profile.get_subscription()['id']) + self.assertEqual(sub_id, profile.get_subscription(subscription=sub_id)['id']) + self.assertRaises(CLIError, profile.get_subscription, "random_id") + + @mock.patch('azure.cli.core.profiles.get_api_version', autospec=True) + def test_subscription_finder_constructor(self, get_api_mock): + cli = DummyCli() + get_api_mock.return_value = '2019-11-01' + cli.cloud.endpoints.resource_manager = 'http://foo_arm' + finder = SubscriptionFinder(cli) + result = finder._create_subscription_client(mock.MagicMock()) + self.assertEqual(result._client._base_url, 'http://foo_arm') + + def test_get_current_account_user(self): cli = DummyCli() - # setup - mock_read_cred_file.return_value = [TestProfile.token_entry1] storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) + profile = Profile(cli_ctx=cli, storage=storage_mock) consolidated = profile._normalize_properties(self.user1, [self.subscription1], False) profile._set_subscriptions(consolidated) - # action user = profile.get_current_account_user() - # verify self.assertEqual(user, self.user1) - @mock.patch('azure.cli.core._profile._load_tokens_from_file', return_value=None) - def test_create_token_cache(self, mock_read_file): - cli = DummyCli() - mock_read_file.return_value = [] - profile = Profile(cli_ctx=cli, use_global_creds_cache=False, async_persist=False) - cache = profile._creds_cache.adal_token_cache - self.assertFalse(cache.read_items()) - self.assertTrue(mock_read_file.called) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - def test_load_cached_tokens(self, mock_read_file): - cli = DummyCli() - mock_read_file.return_value = [TestProfile.token_entry1] - profile = Profile(cli_ctx=cli, use_global_creds_cache=False, async_persist=False) - cache = profile._creds_cache.adal_token_cache - matched = cache.find({ - "_authority": "https://login.microsoftonline.com/common", - "_clientId": "04b07795-8ddb-461a-bbee-02f9e1bf7b46", - "userId": self.user1 - }) - self.assertEqual(len(matched), 1) - self.assertEqual(matched[0]['accessToken'], self.raw_token1) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core._profile.CredsCache.retrieve_token_for_user', autospec=True) - def test_get_login_credentials(self, mock_get_token, mock_read_cred_file): + @mock.patch('azure.cli.core.auth.identity.UserCredential', MockCredential) + def test_get_login_credentials(self): cli = DummyCli() - some_token_type = 'Bearer' - mock_read_cred_file.return_value = [TestProfile.token_entry1] - mock_get_token.return_value = (some_token_type, TestProfile.raw_token1) # setup storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) + profile = Profile(cli_ctx=cli, storage=storage_mock) test_subscription_id = '12345678-1bf0-4dda-aec3-cb9272f09590' - test_tenant_id = '12345678-38d6-4fb2-bad9-b7b93a3e1234' test_subscription = SubscriptionStub('/subscriptions/{}'.format(test_subscription_id), 'MSI-DEV-INC', self.state1, '12345678-38d6-4fb2-bad9-b7b93a3e1234') consolidated = profile._normalize_properties(self.user1, [test_subscription], - False) + False, None, None) profile._set_subscriptions(consolidated) # action cred, subscription_id, _ = profile.get_login_credentials() @@ -551,123 +873,80 @@ def test_get_login_credentials(self, mock_get_token, mock_read_cred_file): # verify self.assertEqual(subscription_id, test_subscription_id) - # verify the cred._tokenRetriever is a working lambda - token_type, token = cred._token_retriever(self.arm_resource) - self.assertEqual(token, self.raw_token1) - self.assertEqual(some_token_type, token_type) - mock_get_token.assert_called_once_with(mock.ANY, self.user1, test_tenant_id, - 'https://management.core.windows.net/') - self.assertEqual(mock_get_token.call_count, 1) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core._profile.CredsCache.retrieve_token_for_user', autospec=True) - def test_get_login_credentials_aux_subscriptions(self, mock_get_token, mock_read_cred_file): + # verify the cred.get_token() + token = cred.get_token() + self.assertEqual(token.token, MOCK_ACCESS_TOKEN) + + @mock.patch('azure.cli.core.auth.identity.UserCredential', MockCredential) + def test_get_login_credentials_aux_subscriptions(self): cli = DummyCli() - raw_token2 = 'some...secrets2' - token_entry2 = { - "resource": "https://management.core.windows.net/", - "tokenType": "Bearer", - "_authority": "https://login.microsoftonline.com/common", - "accessToken": raw_token2, - } - some_token_type = 'Bearer' - mock_read_cred_file.return_value = [TestProfile.token_entry1, token_entry2] - mock_get_token.side_effect = [(some_token_type, TestProfile.raw_token1), (some_token_type, raw_token2)] - # setup + storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - test_subscription_id = '12345678-1bf0-4dda-aec3-cb9272f09590' + profile = Profile(cli_ctx=cli, storage=storage_mock) + test_subscription_id1 = '12345678-1bf0-4dda-aec3-cb9272f09590' test_subscription_id2 = '12345678-1bf0-4dda-aec3-cb9272f09591' - test_tenant_id = '12345678-38d6-4fb2-bad9-b7b93a3e1234' + test_tenant_id1 = '12345678-38d6-4fb2-bad9-b7b93a3e1234' test_tenant_id2 = '12345678-38d6-4fb2-bad9-b7b93a3e4321' - test_subscription = SubscriptionStub('/subscriptions/{}'.format(test_subscription_id), - 'MSI-DEV-INC', self.state1, test_tenant_id) + test_subscription1 = SubscriptionStub('/subscriptions/{}'.format(test_subscription_id1), + 'MSI-DEV-INC', self.state1, test_tenant_id1) test_subscription2 = SubscriptionStub('/subscriptions/{}'.format(test_subscription_id2), 'MSI-DEV-INC2', self.state1, test_tenant_id2) consolidated = profile._normalize_properties(self.user1, - [test_subscription, test_subscription2], - False) + [test_subscription1, test_subscription2], + False, None, None) profile._set_subscriptions(consolidated) - # action - cred, subscription_id, _ = profile.get_login_credentials(subscription_id=test_subscription_id, - aux_subscriptions=[test_subscription_id2]) - - # verify - self.assertEqual(subscription_id, test_subscription_id) - # verify the cred._tokenRetriever is a working lambda - token_type, token = cred._token_retriever(self.arm_resource) - self.assertEqual(token, self.raw_token1) - self.assertEqual(some_token_type, token_type) + cred, subscription_id, _ = profile.get_login_credentials(subscription_id=test_subscription_id1, + aux_subscriptions=[test_subscription_id2]) - token2 = cred._external_tenant_token_retriever(self.arm_resource) - self.assertEqual(len(token2), 1) - self.assertEqual(token2[0][1], raw_token2) + self.assertEqual(subscription_id, test_subscription_id1) - self.assertEqual(mock_get_token.call_count, 2) + token = cred.get_token() + aux_tokens = cred.get_auxiliary_tokens() + self.assertEqual(token.token, MOCK_ACCESS_TOKEN) + self.assertEqual(aux_tokens[0].token, MOCK_ACCESS_TOKEN) - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core._profile.CredsCache.retrieve_token_for_user', autospec=True) - def test_get_login_credentials_aux_tenants(self, mock_get_token, mock_read_cred_file): + @mock.patch('azure.cli.core.auth.identity.UserCredential', MockCredential) + def test_get_login_credentials_aux_tenants(self): cli = DummyCli() - raw_token2 = 'some...secrets2' - token_entry2 = { - "resource": "https://management.core.windows.net/", - "tokenType": "Bearer", - "_authority": "https://login.microsoftonline.com/common", - "accessToken": raw_token2, - } - some_token_type = 'Bearer' - mock_read_cred_file.return_value = [TestProfile.token_entry1, token_entry2] - mock_get_token.side_effect = [(some_token_type, TestProfile.raw_token1), (some_token_type, raw_token2)] - # setup + storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - test_subscription_id = '12345678-1bf0-4dda-aec3-cb9272f09590' + profile = Profile(cli_ctx=cli, storage=storage_mock) + test_subscription_id1 = '12345678-1bf0-4dda-aec3-cb9272f09590' test_subscription_id2 = '12345678-1bf0-4dda-aec3-cb9272f09591' - test_tenant_id = '12345678-38d6-4fb2-bad9-b7b93a3e1234' + test_tenant_id1 = '12345678-38d6-4fb2-bad9-b7b93a3e1234' test_tenant_id2 = '12345678-38d6-4fb2-bad9-b7b93a3e4321' - test_subscription = SubscriptionStub('/subscriptions/{}'.format(test_subscription_id), - 'MSI-DEV-INC', self.state1, test_tenant_id) + test_subscription = SubscriptionStub('/subscriptions/{}'.format(test_subscription_id1), + 'MSI-DEV-INC', self.state1, test_tenant_id1) test_subscription2 = SubscriptionStub('/subscriptions/{}'.format(test_subscription_id2), 'MSI-DEV-INC2', self.state1, test_tenant_id2) consolidated = profile._normalize_properties(self.user1, [test_subscription, test_subscription2], - False) + False, None, None) profile._set_subscriptions(consolidated) # test only input aux_tenants - cred, subscription_id, _ = profile.get_login_credentials(subscription_id=test_subscription_id, + cred, subscription_id, _ = profile.get_login_credentials(subscription_id=test_subscription_id1, aux_tenants=[test_tenant_id2]) - # verify - self.assertEqual(subscription_id, test_subscription_id) - - # verify the cred._tokenRetriever is a working lambda - token_type, token = cred._token_retriever(self.arm_resource) - self.assertEqual(token, self.raw_token1) - self.assertEqual(some_token_type, token_type) + self.assertEqual(subscription_id, test_subscription_id1) - token2 = cred._external_tenant_token_retriever(self.arm_resource) - self.assertEqual(len(token2), 1) - self.assertEqual(token2[0][1], raw_token2) - - self.assertEqual(mock_get_token.call_count, 2) + token = cred.get_token() + aux_tokens = cred.get_auxiliary_tokens() + self.assertEqual(token.token, MOCK_ACCESS_TOKEN) + self.assertEqual(aux_tokens[0].token, MOCK_ACCESS_TOKEN) # test input aux_tenants and aux_subscriptions with self.assertRaisesRegexp(CLIError, "Please specify only one of aux_subscriptions and aux_tenants, not both"): - cred, subscription_id, _ = profile.get_login_credentials(subscription_id=test_subscription_id, + cred, subscription_id, _ = profile.get_login_credentials(subscription_id=test_subscription_id1, aux_subscriptions=[test_subscription_id2], aux_tenants=[test_tenant_id2]) - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core.adal_authentication.MSIAuthenticationWrapper', autospec=True) - def test_get_login_credentials_msi_system_assigned(self, mock_msi_auth, mock_read_cred_file): - mock_read_cred_file.return_value = [] + @mock.patch('azure.cli.core.auth.adal_authentication.MSIAuthenticationWrapper', MSRestAzureAuthStub) + def test_get_login_credentials_msi_system_assigned(self): # setup an existing msi subscription - profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}, use_global_creds_cache=False, - async_persist=False) + profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}) test_subscription_id = '12345678-1bf0-4dda-aec3-cb9272f09590' test_tenant_id = '12345678-38d6-4fb2-bad9-b7b93a3e1234' test_user = 'systemAssignedIdentity' @@ -677,12 +956,8 @@ def test_get_login_credentials_msi_system_assigned(self, mock_msi_auth, mock_rea True) profile._set_subscriptions(consolidated) - mock_msi_auth.side_effect = MSRestAzureAuthStub - - # action cred, subscription_id, _ = profile.get_login_credentials() - # assert self.assertEqual(subscription_id, test_subscription_id) # sniff test the msi_auth object @@ -691,14 +966,10 @@ def test_get_login_credentials_msi_system_assigned(self, mock_msi_auth, mock_rea self.assertTrue(cred.set_token_invoked_count) self.assertTrue(cred.token_read_count) - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core.adal_authentication.MSIAuthenticationWrapper', autospec=True) - def test_get_login_credentials_msi_user_assigned_with_client_id(self, mock_msi_auth, mock_read_cred_file): - mock_read_cred_file.return_value = [] - + @mock.patch('azure.cli.core.auth.adal_authentication.MSIAuthenticationWrapper', MSRestAzureAuthStub) + def test_get_login_credentials_msi_user_assigned_with_client_id(self): # setup an existing msi subscription - profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}, use_global_creds_cache=False, - async_persist=False) + profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}) test_subscription_id = '12345678-1bf0-4dda-aec3-cb9272f09590' test_tenant_id = '12345678-38d6-4fb2-bad9-b7b93a3e1234' test_user = 'userAssignedIdentity' @@ -707,12 +978,8 @@ def test_get_login_credentials_msi_user_assigned_with_client_id(self, mock_msi_a consolidated = profile._normalize_properties(test_user, [msi_subscription], True) profile._set_subscriptions(consolidated, secondary_key_name='name') - mock_msi_auth.side_effect = MSRestAzureAuthStub - - # action cred, subscription_id, _ = profile.get_login_credentials() - # assert self.assertEqual(subscription_id, test_subscription_id) # sniff test the msi_auth object @@ -722,14 +989,11 @@ def test_get_login_credentials_msi_user_assigned_with_client_id(self, mock_msi_a self.assertTrue(cred.token_read_count) self.assertTrue(cred.client_id, test_client_id) - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core.adal_authentication.MSIAuthenticationWrapper', autospec=True) - def test_get_login_credentials_msi_user_assigned_with_object_id(self, mock_msi_auth, mock_read_cred_file): - mock_read_cred_file.return_value = [] + @mock.patch('azure.cli.core.auth.adal_authentication.MSIAuthenticationWrapper', MSRestAzureAuthStub) + def test_get_login_credentials_msi_user_assigned_with_object_id(self): # setup an existing msi subscription - profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}, use_global_creds_cache=False, - async_persist=False) + profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}) test_subscription_id = '12345678-1bf0-4dda-aec3-cb9272f09590' test_object_id = '12345678-38d6-4fb2-bad9-b7b93a3e9999' msi_subscription = SubscriptionStub('/subscriptions/12345678-1bf0-4dda-aec3-cb9272f09590', @@ -738,12 +1002,8 @@ def test_get_login_credentials_msi_user_assigned_with_object_id(self, mock_msi_a consolidated = profile._normalize_properties('userAssignedIdentity', [msi_subscription], True) profile._set_subscriptions(consolidated, secondary_key_name='name') - mock_msi_auth.side_effect = MSRestAzureAuthStub - - # action cred, subscription_id, _ = profile.get_login_credentials() - # assert self.assertEqual(subscription_id, test_subscription_id) # sniff test the msi_auth object @@ -753,14 +1013,10 @@ def test_get_login_credentials_msi_user_assigned_with_object_id(self, mock_msi_a self.assertTrue(cred.token_read_count) self.assertTrue(cred.object_id, test_object_id) - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core.adal_authentication.MSIAuthenticationWrapper', autospec=True) - def test_get_login_credentials_msi_user_assigned_with_res_id(self, mock_msi_auth, mock_read_cred_file): - mock_read_cred_file.return_value = [] - + @mock.patch('azure.cli.core.auth.adal_authentication.MSIAuthenticationWrapper', MSRestAzureAuthStub) + def test_get_login_credentials_msi_user_assigned_with_res_id(self): # setup an existing msi subscription - profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}, use_global_creds_cache=False, - async_persist=False) + profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}) test_subscription_id = '12345678-1bf0-4dda-aec3-cb9272f09590' test_res_id = ('/subscriptions/{}/resourceGroups/r1/providers/Microsoft.ManagedIdentity/' 'userAssignedIdentities/id1').format(test_subscription_id) @@ -770,12 +1026,8 @@ def test_get_login_credentials_msi_user_assigned_with_res_id(self, mock_msi_auth consolidated = profile._normalize_properties('userAssignedIdentity', [msi_subscription], True) profile._set_subscriptions(consolidated, secondary_key_name='name') - mock_msi_auth.side_effect = MSRestAzureAuthStub - - # action cred, subscription_id, _ = profile.get_login_credentials() - # assert self.assertEqual(subscription_id, test_subscription_id) # sniff test the msi_auth object @@ -785,57 +1037,53 @@ def test_get_login_credentials_msi_user_assigned_with_res_id(self, mock_msi_auth self.assertTrue(cred.token_read_count) self.assertTrue(cred.msi_res_id, test_res_id) - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core._profile.CredsCache.retrieve_token_for_user', autospec=True) - def test_get_raw_token(self, mock_get_token, mock_read_cred_file): + @mock.patch('azure.cli.core.auth.identity.UserCredential', MockCredential) + def test_get_raw_token(self): cli = DummyCli() - some_token_type = 'Bearer' - mock_read_cred_file.return_value = [TestProfile.token_entry1] - mock_get_token.return_value = (some_token_type, TestProfile.raw_token1, - TestProfile.token_entry1) # setup storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) + profile = Profile(cli_ctx=cli, storage=storage_mock) consolidated = profile._normalize_properties(self.user1, [self.subscription1], - False) + False, None, None) profile._set_subscriptions(consolidated) + # action - creds, sub, tenant = profile.get_raw_token(resource='https://foo') + # Get token with ADAL-style resource + resource_result = profile.get_raw_token(resource='https://foo') + # Get token with MSAL-style scopes + scopes_result = profile.get_raw_token(scopes=self.msal_scopes) # verify - self.assertEqual(creds[0], self.token_entry1['tokenType']) - self.assertEqual(creds[1], self.raw_token1) - # the last in the tuple is the whole token entry which has several fields - self.assertEqual(creds[2]['expiresOn'], self.token_entry1['expiresOn']) - mock_get_token.assert_called_once_with(mock.ANY, self.user1, self.tenant_id, - 'https://foo') - self.assertEqual(mock_get_token.call_count, 1) - self.assertEqual(sub, '1') + self.assertEqual(resource_result, scopes_result) + creds, sub, tenant = scopes_result + + self.assertEqual(creds[0], 'Bearer') + self.assertEqual(creds[1], MOCK_ACCESS_TOKEN) + self.assertEqual(creds[2]['expires_on'], MOCK_EXPIRES_ON) + + # subscription should be set + self.assertEqual(sub, self.subscription1.subscription_id) self.assertEqual(tenant, self.tenant_id) # Test get_raw_token with tenant creds, sub, tenant = profile.get_raw_token(resource='https://foo', tenant=self.tenant_id) - self.assertEqual(creds[0], self.token_entry1['tokenType']) - self.assertEqual(creds[1], self.raw_token1) - self.assertEqual(creds[2]['expiresOn'], self.token_entry1['expiresOn']) - mock_get_token.assert_called_with(mock.ANY, self.user1, self.tenant_id, 'https://foo') - self.assertEqual(mock_get_token.call_count, 2) + self.assertEqual(creds[0], 'Bearer') + self.assertEqual(creds[1], MOCK_ACCESS_TOKEN) + self.assertEqual(creds[2]['expires_on'], MOCK_EXPIRES_ON) + + # subscription shouldn't be set self.assertIsNone(sub) self.assertEqual(tenant, self.tenant_id) - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core._profile.CredsCache.retrieve_token_for_service_principal', autospec=True) - def test_get_raw_token_for_sp(self, mock_get_token, mock_read_cred_file): + @mock.patch('azure.cli.core.auth.identity.Identity.get_service_principal_credential') + def test_get_raw_token_for_sp(self, get_service_principal_credential_mock): + get_service_principal_credential_mock.return_value = MockCredential() cli = DummyCli() - some_token_type = 'Bearer' - mock_read_cred_file.return_value = [TestProfile.token_entry1] - mock_get_token.return_value = (some_token_type, TestProfile.raw_token1, - TestProfile.token_entry1) # setup storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) + profile = Profile(cli_ctx=cli, storage=storage_mock) consolidated = profile._normalize_properties('sp1', [self.subscription1], True) @@ -844,34 +1092,30 @@ def test_get_raw_token_for_sp(self, mock_get_token, mock_read_cred_file): creds, sub, tenant = profile.get_raw_token(resource='https://foo') # verify - self.assertEqual(creds[0], self.token_entry1['tokenType']) - self.assertEqual(creds[1], self.raw_token1) + self.assertEqual(creds[0], BEARER) + self.assertEqual(creds[1], MOCK_ACCESS_TOKEN) # the last in the tuple is the whole token entry which has several fields - self.assertEqual(creds[2]['expiresOn'], self.token_entry1['expiresOn']) - mock_get_token.assert_called_once_with(mock.ANY, 'sp1', 'https://foo', self.tenant_id, False) - self.assertEqual(mock_get_token.call_count, 1) - self.assertEqual(sub, '1') + self.assertEqual(creds[2]['expires_on'], MOCK_EXPIRES_ON) + + # subscription should be set + self.assertEqual(sub, self.subscription1.subscription_id) self.assertEqual(tenant, self.tenant_id) # Test get_raw_token with tenant creds, sub, tenant = profile.get_raw_token(resource='https://foo', tenant=self.tenant_id) - self.assertEqual(creds[0], self.token_entry1['tokenType']) - self.assertEqual(creds[1], self.raw_token1) - self.assertEqual(creds[2]['expiresOn'], self.token_entry1['expiresOn']) - mock_get_token.assert_called_with(mock.ANY, 'sp1', 'https://foo', self.tenant_id, False) - self.assertEqual(mock_get_token.call_count, 2) + self.assertEqual(creds[0], BEARER) + self.assertEqual(creds[1], MOCK_ACCESS_TOKEN) + self.assertEqual(creds[2]['expires_on'], MOCK_EXPIRES_ON) + + # subscription shouldn't be set self.assertIsNone(sub) self.assertEqual(tenant, self.tenant_id) - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core.adal_authentication.MSIAuthenticationWrapper', autospec=True) - def test_get_raw_token_msi_system_assigned(self, mock_msi_auth, mock_read_cred_file): - mock_read_cred_file.return_value = [] - + @mock.patch('azure.cli.core.auth.adal_authentication.MSIAuthenticationWrapper', autospec=True) + def test_get_raw_token_msi_system_assigned(self, mock_msi_auth): # setup an existing msi subscription - profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}, use_global_creds_cache=False, - async_persist=False) + profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}) test_subscription_id = '12345678-1bf0-4dda-aec3-cb9272f09590' test_tenant_id = '12345678-38d6-4fb2-bad9-b7b93a3e1234' test_user = 'systemAssignedIdentity' @@ -899,15 +1143,12 @@ def test_get_raw_token_msi_system_assigned(self, mock_msi_auth, mock_read_cred_f cred, subscription_id, _ = profile.get_raw_token(resource='http://test_resource', tenant=self.tenant_id) @mock.patch('azure.cli.core._profile.in_cloud_console', autospec=True) - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core.adal_authentication.MSIAuthenticationWrapper', autospec=True) - def test_get_raw_token_in_cloud_console(self, mock_msi_auth, mock_read_cred_file, mock_in_cloud_console): - mock_read_cred_file.return_value = [] + @mock.patch('azure.cli.core.auth.adal_authentication.MSIAuthenticationWrapper', autospec=True) + def test_get_raw_token_in_cloud_console(self, mock_msi_auth, mock_in_cloud_console): mock_in_cloud_console.return_value = True # setup an existing msi subscription - profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}, use_global_creds_cache=False, - async_persist=False) + profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}) test_subscription_id = '12345678-1bf0-4dda-aec3-cb9272f09590' test_tenant_id = '12345678-38d6-4fb2-bad9-b7b93a3e1234' msi_subscription = SubscriptionStub('/subscriptions/' + test_subscription_id, @@ -915,594 +1156,118 @@ def test_get_raw_token_in_cloud_console(self, mock_msi_auth, mock_read_cred_file consolidated = profile._normalize_properties(self.user1, [msi_subscription], True) - consolidated[0]['user']['cloudShellID'] = True - profile._set_subscriptions(consolidated) - - mock_msi_auth.side_effect = MSRestAzureAuthStub - - # action - cred, subscription_id, tenant_id = profile.get_raw_token(resource='http://test_resource') - - # assert - self.assertEqual(subscription_id, test_subscription_id) - self.assertEqual(cred[0], 'Bearer') - self.assertEqual(cred[1], TestProfile.test_msi_access_token) - self.assertEqual(subscription_id, test_subscription_id) - self.assertEqual(tenant_id, test_tenant_id) - - # verify tenant shouldn't be specified for Cloud Shell account - with self.assertRaisesRegexp(CLIError, 'Cloud Shell'): - cred, subscription_id, _ = profile.get_raw_token(resource='http://test_resource', tenant=self.tenant_id) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core._profile.CredsCache.retrieve_token_for_user', autospec=True) - def test_get_login_credentials_for_graph_client(self, mock_get_token, mock_read_cred_file): - cli = DummyCli() - some_token_type = 'Bearer' - mock_read_cred_file.return_value = [TestProfile.token_entry1] - mock_get_token.return_value = (some_token_type, TestProfile.raw_token1) - # setup - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, [self.subscription1], - False) - profile._set_subscriptions(consolidated) - # action - cred, _, tenant_id = profile.get_login_credentials( - resource=cli.cloud.endpoints.active_directory_graph_resource_id) - _, _ = cred._token_retriever('https://graph.windows.net/') - # verify - mock_get_token.assert_called_once_with(mock.ANY, self.user1, self.tenant_id, - 'https://graph.windows.net/') - self.assertEqual(tenant_id, self.tenant_id) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core._profile.CredsCache.retrieve_token_for_user', autospec=True) - def test_get_login_credentials_for_data_lake_client(self, mock_get_token, mock_read_cred_file): - cli = DummyCli() - some_token_type = 'Bearer' - mock_read_cred_file.return_value = [TestProfile.token_entry1] - mock_get_token.return_value = (some_token_type, TestProfile.raw_token1) - # setup - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, [self.subscription1], - False) - profile._set_subscriptions(consolidated) - # action - cred, _, tenant_id = profile.get_login_credentials( - resource=cli.cloud.endpoints.active_directory_data_lake_resource_id) - _, _ = cred._token_retriever('https://datalake.azure.net/') - # verify - mock_get_token.assert_called_once_with(mock.ANY, self.user1, self.tenant_id, - 'https://datalake.azure.net/') - self.assertEqual(tenant_id, self.tenant_id) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core._profile.CredsCache.persist_cached_creds', autospec=True) - def test_logout(self, mock_persist_creds, mock_read_cred_file): - cli = DummyCli() - # setup - mock_read_cred_file.return_value = [TestProfile.token_entry1] - - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, - [self.subscription1], - False) - profile._set_subscriptions(consolidated) - self.assertEqual(1, len(storage_mock['subscriptions'])) - # action - profile.logout(self.user1) - - # verify - self.assertEqual(0, len(storage_mock['subscriptions'])) - self.assertEqual(mock_read_cred_file.call_count, 1) - self.assertEqual(mock_persist_creds.call_count, 1) - - @mock.patch('azure.cli.core._profile._delete_file', autospec=True) - def test_logout_all(self, mock_delete_cred_file): - cli = DummyCli() - # setup - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, - [self.subscription1], - False) - consolidated2 = profile._normalize_properties(self.user2, - [self.subscription2], - False) - profile._set_subscriptions(consolidated + consolidated2) - - self.assertEqual(2, len(storage_mock['subscriptions'])) - # action - profile.logout_all() - - # verify - self.assertEqual([], storage_mock['subscriptions']) - self.assertEqual(mock_delete_cred_file.call_count, 1) - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_find_subscriptions_thru_username_password(self, mock_auth_context): - cli = DummyCli() - mock_auth_context.acquire_token_with_username_password.return_value = self.token_entry1 - mock_auth_context.acquire_token.return_value = self.token_entry1 - mock_arm_client = mock.MagicMock() - mock_arm_client.tenants.list.return_value = [TenantStub(self.tenant_id)] - mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - mgmt_resource = 'https://management.core.windows.net/' - # action - subs = finder.find_from_user_account(self.user1, 'bar', None, mgmt_resource) - - # assert - self.assertEqual([self.subscription1], subs) - mock_auth_context.acquire_token_with_username_password.assert_called_once_with( - mgmt_resource, self.user1, 'bar', mock.ANY) - mock_auth_context.acquire_token.assert_called_once_with( - mgmt_resource, self.user1, mock.ANY) - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_find_subscriptions_thru_username_non_password(self, mock_auth_context): - cli = DummyCli() - mock_auth_context.acquire_token_with_username_password.return_value = None - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: None) - # action - subs = finder.find_from_user_account(self.user1, 'bar', None, 'http://goo-resource') - - # assert - self.assertEqual([], subs) - - @mock.patch('azure.cli.core.adal_authentication.MSIAuthenticationWrapper', autospec=True) - @mock.patch('azure.cli.core.profiles._shared.get_client_class', autospec=True) - @mock.patch('azure.cli.core._profile._get_cloud_console_token_endpoint', autospec=True) - @mock.patch('azure.cli.core._profile.SubscriptionFinder', autospec=True) - def test_find_subscriptions_in_cloud_console(self, mock_subscription_finder, mock_get_token_endpoint, - mock_get_client_class, mock_msi_auth): - - class SubscriptionFinderStub: - def find_from_raw_token(self, tenant, token): - # make sure the tenant and token args match 'TestProfile.test_msi_access_token' - if token != TestProfile.test_msi_access_token or tenant != '54826b22-38d6-4fb2-bad9-b7b93a3e9c5a': - raise AssertionError('find_from_raw_token was not invoked with expected tenant or token') - return [TestProfile.subscription1] - - mock_subscription_finder.return_value = SubscriptionFinderStub() - - mock_get_token_endpoint.return_value = "http://great_endpoint" - mock_msi_auth.return_value = MSRestAzureAuthStub() - - profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}, use_global_creds_cache=False, - async_persist=False) - - # action - subscriptions = profile.find_subscriptions_in_cloud_console() - - # assert - self.assertEqual(len(subscriptions), 1) - s = subscriptions[0] - self.assertEqual(s['user']['name'], 'admin3@AzureSDKTeam.onmicrosoft.com') - self.assertEqual(s['user']['cloudShellID'], True) - self.assertEqual(s['user']['type'], 'user') - self.assertEqual(s['name'], self.display_name1) - self.assertEqual(s['id'], self.id1.split('/')[-1]) - - @mock.patch('requests.get', autospec=True) - @mock.patch('azure.cli.core._profile.SubscriptionFinder._get_subscription_client_class', autospec=True) - def test_find_subscriptions_in_vm_with_msi_system_assigned(self, mock_get_client_class, mock_get): - - class ClientStub: - def __init__(self, *args, **kwargs): - self.subscriptions = mock.MagicMock() - self.subscriptions.list.return_value = [deepcopy(TestProfile.subscription1_raw)] - self.config = mock.MagicMock() - self._client = mock.MagicMock() - - mock_get_client_class.return_value = ClientStub - cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - - test_token_entry = { - 'token_type': 'Bearer', - 'access_token': TestProfile.test_msi_access_token - } - encoded_test_token = json.dumps(test_token_entry).encode() - good_response = mock.MagicMock() - good_response.status_code = 200 - good_response.content = encoded_test_token - mock_get.return_value = good_response - - subscriptions = profile.find_subscriptions_in_vm_with_msi() - - # assert - self.assertEqual(len(subscriptions), 1) - s = subscriptions[0] - self.assertEqual(s['user']['name'], 'systemAssignedIdentity') - self.assertEqual(s['user']['type'], 'servicePrincipal') - self.assertEqual(s['user']['assignedIdentityInfo'], 'MSI') - self.assertEqual(s['name'], self.display_name1) - self.assertEqual(s['id'], self.id1.split('/')[-1]) - self.assertEqual(s['tenantId'], '54826b22-38d6-4fb2-bad9-b7b93a3e9c5a') - - @mock.patch('requests.get', autospec=True) - @mock.patch('azure.cli.core._profile.SubscriptionFinder._get_subscription_client_class', autospec=True) - def test_find_subscriptions_in_vm_with_msi_no_subscriptions(self, mock_get_client_class, mock_get): - - class ClientStub: - def __init__(self, *args, **kwargs): - self.subscriptions = mock.MagicMock() - self.subscriptions.list.return_value = [] - self.config = mock.MagicMock() - self._client = mock.MagicMock() - - mock_get_client_class.return_value = ClientStub - cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - - test_token_entry = { - 'token_type': 'Bearer', - 'access_token': TestProfile.test_msi_access_token - } - encoded_test_token = json.dumps(test_token_entry).encode() - good_response = mock.MagicMock() - good_response.status_code = 200 - good_response.content = encoded_test_token - mock_get.return_value = good_response - - subscriptions = profile.find_subscriptions_in_vm_with_msi(allow_no_subscriptions=True) - - # assert - self.assertEqual(len(subscriptions), 1) - s = subscriptions[0] - self.assertEqual(s['user']['name'], 'systemAssignedIdentity') - self.assertEqual(s['user']['type'], 'servicePrincipal') - self.assertEqual(s['user']['assignedIdentityInfo'], 'MSI') - self.assertEqual(s['name'], 'N/A(tenant level account)') - self.assertEqual(s['id'], self.test_msi_tenant) - self.assertEqual(s['tenantId'], self.test_msi_tenant) - - @mock.patch('requests.get', autospec=True) - @mock.patch('azure.cli.core._profile.SubscriptionFinder._get_subscription_client_class', autospec=True) - def test_find_subscriptions_in_vm_with_msi_user_assigned_with_client_id(self, mock_get_client_class, mock_get): - - class ClientStub: - def __init__(self, *args, **kwargs): - self.subscriptions = mock.MagicMock() - self.subscriptions.list.return_value = [deepcopy(TestProfile.subscription1_raw)] - self.config = mock.MagicMock() - self._client = mock.MagicMock() - - mock_get_client_class.return_value = ClientStub - cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - - test_token_entry = { - 'token_type': 'Bearer', - 'access_token': TestProfile.test_msi_access_token - } - test_client_id = '54826b22-38d6-4fb2-bad9-b7b93a3e9999' - encoded_test_token = json.dumps(test_token_entry).encode() - good_response = mock.MagicMock() - good_response.status_code = 200 - good_response.content = encoded_test_token - mock_get.return_value = good_response - - subscriptions = profile.find_subscriptions_in_vm_with_msi(identity_id=test_client_id) - - # assert - self.assertEqual(len(subscriptions), 1) - s = subscriptions[0] - self.assertEqual(s['user']['name'], 'userAssignedIdentity') - self.assertEqual(s['user']['type'], 'servicePrincipal') - self.assertEqual(s['name'], self.display_name1) - self.assertEqual(s['user']['assignedIdentityInfo'], 'MSIClient-{}'.format(test_client_id)) - self.assertEqual(s['id'], self.id1.split('/')[-1]) - self.assertEqual(s['tenantId'], '54826b22-38d6-4fb2-bad9-b7b93a3e9c5a') - - @mock.patch('azure.cli.core.adal_authentication.MSIAuthenticationWrapper', autospec=True) - @mock.patch('azure.cli.core.profiles._shared.get_client_class', autospec=True) - @mock.patch('azure.cli.core._profile.SubscriptionFinder', autospec=True) - def test_find_subscriptions_in_vm_with_msi_user_assigned_with_object_id(self, mock_subscription_finder, mock_get_client_class, - mock_msi_auth): - from azure.cli.core.azclierror import AzureResponseError - - class SubscriptionFinderStub: - def find_from_raw_token(self, tenant, token): - # make sure the tenant and token args match 'TestProfile.test_msi_access_token' - if token != TestProfile.test_msi_access_token or tenant != '54826b22-38d6-4fb2-bad9-b7b93a3e9c5a': - raise AssertionError('find_from_raw_token was not invoked with expected tenant or token') - return [TestProfile.subscription1] - - class AuthStub: - def __init__(self, **kwargs): - self.token = None - self.client_id = kwargs.get('client_id') - self.object_id = kwargs.get('object_id') - # since msrestazure 0.4.34, set_token in init - self.set_token() - - def set_token(self): - # here we will reject the 1st sniffing of trying with client_id and then acccept the 2nd - if self.object_id: - self.token = { - 'token_type': 'Bearer', - 'access_token': TestProfile.test_msi_access_token - } - else: - raise AzureResponseError('Failed to connect to MSI. Please make sure MSI is configured correctly.\n' - 'Get Token request returned http error: 400, reason: Bad Request') - - profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}, use_global_creds_cache=False, - async_persist=False) - - mock_subscription_finder.return_value = SubscriptionFinderStub() - - mock_msi_auth.side_effect = AuthStub - test_object_id = '54826b22-38d6-4fb2-bad9-b7b93a3e9999' - - # action - subscriptions = profile.find_subscriptions_in_vm_with_msi(identity_id=test_object_id) - - # assert - self.assertEqual(subscriptions[0]['user']['assignedIdentityInfo'], 'MSIObject-{}'.format(test_object_id)) - - @mock.patch('requests.get', autospec=True) - @mock.patch('azure.cli.core._profile.SubscriptionFinder._get_subscription_client_class', autospec=True) - def test_find_subscriptions_in_vm_with_msi_user_assigned_with_res_id(self, mock_get_client_class, mock_get): - - class ClientStub: - def __init__(self, *args, **kwargs): - self.subscriptions = mock.MagicMock() - self.subscriptions.list.return_value = [deepcopy(TestProfile.subscription1_raw)] - self.config = mock.MagicMock() - self._client = mock.MagicMock() - - mock_get_client_class.return_value = ClientStub - cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - - test_token_entry = { - 'token_type': 'Bearer', - 'access_token': TestProfile.test_msi_access_token - } - test_res_id = ('/subscriptions/0b1f6471-1bf0-4dda-aec3-cb9272f09590/resourcegroups/g1/' - 'providers/Microsoft.ManagedIdentity/userAssignedIdentities/id1') - - encoded_test_token = json.dumps(test_token_entry).encode() - good_response = mock.MagicMock() - good_response.status_code = 200 - good_response.content = encoded_test_token - mock_get.return_value = good_response - - subscriptions = profile.find_subscriptions_in_vm_with_msi(identity_id=test_res_id) - - # assert - self.assertEqual(subscriptions[0]['user']['assignedIdentityInfo'], 'MSIResource-{}'.format(test_res_id)) - - @mock.patch('adal.AuthenticationContext.acquire_token_with_username_password', autospec=True) - @mock.patch('adal.AuthenticationContext.acquire_token', autospec=True) - def test_find_subscriptions_thru_username_password_adfs(self, mock_acquire_token, - mock_acquire_token_username_password): - cli = DummyCli() - TEST_ADFS_AUTH_URL = 'https://adfs.local.azurestack.external/adfs' - - def test_acquire_token(self, resource, username, password, client_id): - global acquire_token_invoked - acquire_token_invoked = True - if (self.authority.url == TEST_ADFS_AUTH_URL and self.authority.is_adfs_authority): - return TestProfile.token_entry1 - else: - raise ValueError('AuthContext was not initialized correctly for ADFS') - - mock_acquire_token_username_password.side_effect = test_acquire_token - mock_acquire_token.return_value = self.token_entry1 - mock_arm_client = mock.MagicMock() - mock_arm_client.tenants.list.return_value = [TenantStub(self.tenant_id)] - mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - cli.cloud.endpoints.active_directory = TEST_ADFS_AUTH_URL - finder = SubscriptionFinder(cli, _AUTH_CTX_FACTORY, None, lambda _: mock_arm_client) - mgmt_resource = 'https://management.core.windows.net/' - # action - subs = finder.find_from_user_account(self.user1, 'bar', None, mgmt_resource) - - # assert - self.assertEqual([self.subscription1], subs) - self.assertTrue(acquire_token_invoked) - - @mock.patch('adal.AuthenticationContext', autospec=True) - @mock.patch('azure.cli.core._profile.logger', autospec=True) - def test_find_subscriptions_thru_username_password_with_account_disabled(self, mock_logger, mock_auth_context): - cli = DummyCli() - mock_auth_context.acquire_token_with_username_password.return_value = self.token_entry1 - mock_auth_context.acquire_token.side_effect = AdalError('Account is disabled') - mock_arm_client = mock.MagicMock() - mock_arm_client.tenants.list.return_value = [TenantStub(self.tenant_id)] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - mgmt_resource = 'https://management.core.windows.net/' - # action - subs = finder.find_from_user_account(self.user1, 'bar', None, mgmt_resource) - - # assert - self.assertEqual([], subs) - mock_logger.warning.assert_called_once_with(mock.ANY, mock.ANY, mock.ANY) + consolidated[0]['user']['cloudShellID'] = True + profile._set_subscriptions(consolidated) - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_find_subscriptions_from_particular_tenent(self, mock_auth_context): - def just_raise(ex): - raise ex + mock_msi_auth.side_effect = MSRestAzureAuthStub - cli = DummyCli() - mock_arm_client = mock.MagicMock() - mock_arm_client.tenants.list.side_effect = lambda: just_raise( - ValueError("'tenants.list' should not occur")) - mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) # action - subs = finder.find_from_user_account(self.user1, 'bar', self.tenant_id, 'http://someresource') + cred, subscription_id, tenant_id = profile.get_raw_token(resource='http://test_resource') # assert - self.assertEqual([self.subscription1], subs) + self.assertEqual(subscription_id, test_subscription_id) + self.assertEqual(cred[0], 'Bearer') + self.assertEqual(cred[1], TestProfile.test_msi_access_token) + self.assertEqual(subscription_id, test_subscription_id) + self.assertEqual(tenant_id, test_tenant_id) - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_find_subscriptions_through_device_code_flow(self, mock_auth_context): - cli = DummyCli() - test_nonsense_code = {'message': 'magic code for you'} - mock_auth_context.acquire_user_code.return_value = test_nonsense_code - mock_auth_context.acquire_token_with_device_code.return_value = self.token_entry1 - mock_arm_client = mock.MagicMock() - mock_arm_client.tenants.list.return_value = [TenantStub(self.tenant_id)] - mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - mgmt_resource = 'https://management.core.windows.net/' - # action - subs = finder.find_through_interactive_flow(None, mgmt_resource) + # verify tenant shouldn't be specified for Cloud Shell account + with self.assertRaisesRegexp(CLIError, 'Cloud Shell'): + cred, subscription_id, _ = profile.get_raw_token(resource='http://test_resource', tenant=self.tenant_id) - # assert - self.assertEqual([self.subscription1], subs) - mock_auth_context.acquire_user_code.assert_called_once_with( - mgmt_resource, mock.ANY) - mock_auth_context.acquire_token_with_device_code.assert_called_once_with( - mgmt_resource, test_nonsense_code, mock.ANY) - mock_auth_context.acquire_token.assert_called_once_with( - mgmt_resource, self.user1, mock.ANY) - - @mock.patch('adal.AuthenticationContext', autospec=True) - @mock.patch('azure.cli.core._profile._get_authorization_code', autospec=True) - def test_find_subscriptions_through_authorization_code_flow(self, _get_authorization_code_mock, mock_auth_context): - import adal + @mock.patch('azure.cli.core.auth.identity.Identity.logout_user') + def test_logout(self, logout_user_mock): cli = DummyCli() - mock_arm_client = mock.MagicMock() - mock_arm_client.tenants.list.return_value = [TenantStub(self.tenant_id)] - mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - token_cache = adal.TokenCache() - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, token_cache, lambda _: mock_arm_client) - _get_authorization_code_mock.return_value = { - 'code': 'code1', - 'reply_url': 'http://localhost:8888' - } - mgmt_resource = 'https://management.core.windows.net/' - temp_token_cache = mock.MagicMock() - type(mock_auth_context).cache = temp_token_cache - temp_token_cache.read_items.return_value = [] - mock_auth_context.acquire_token_with_authorization_code.return_value = self.token_entry1 + storage_mock = {'subscriptions': []} + profile = Profile(cli_ctx=cli, storage=storage_mock) + consolidated = profile._normalize_properties(self.user1, + [self.subscription1], + False) + profile._set_subscriptions(consolidated) + self.assertEqual(1, len(storage_mock['subscriptions'])) # action - subs = finder.find_through_authorization_code_flow(None, mgmt_resource, 'https:/some_aad_point/common') + profile.logout(self.user1) - # assert - self.assertEqual([self.subscription1], subs) - mock_auth_context.acquire_token.assert_called_once_with(mgmt_resource, self.user1, mock.ANY) - mock_auth_context.acquire_token_with_authorization_code.assert_called_once_with('code1', - 'http://localhost:8888', - mgmt_resource, mock.ANY, - None) - _get_authorization_code_mock.assert_called_once_with(mgmt_resource, 'https:/some_aad_point/common') - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_find_subscriptions_interactive_from_particular_tenent(self, mock_auth_context): - def just_raise(ex): - raise ex + # verify + self.assertEqual(0, len(storage_mock['subscriptions'])) + logout_user_mock.assert_called_with(self.user1) + @mock.patch('azure.cli.core.auth.identity.Identity.logout_all_users') + def test_logout_all(self, logout_all_users_mock): cli = DummyCli() - mock_arm_client = mock.MagicMock() - mock_arm_client.tenants.list.side_effect = lambda: just_raise( - ValueError("'tenants.list' should not occur")) - mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) + # setup + storage_mock = {'subscriptions': []} + profile = Profile(cli_ctx=cli, storage=storage_mock) + consolidated = profile._normalize_properties(self.user1, + [self.subscription1], + False) + consolidated2 = profile._normalize_properties(self.user2, + [self.subscription2], + False) + profile._set_subscriptions(consolidated + consolidated2) + + self.assertEqual(2, len(storage_mock['subscriptions'])) # action - subs = finder.find_through_interactive_flow(self.tenant_id, 'http://someresource') + profile.logout_all() - # assert - self.assertEqual([self.subscription1], subs) + # verify + self.assertEqual([], storage_mock['subscriptions']) + logout_all_users_mock.assert_called_once() - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_find_subscriptions_from_service_principal_id(self, mock_auth_context): + @unittest.skip("todo: wait for identity support") + @mock.patch('azure.identity.UsernamePasswordCredential.get_token', autospec=True) + def test_find_subscriptions_thru_username_password_adfs(self, get_token_mock): cli = DummyCli() - mock_auth_context.acquire_token_with_client_credentials.return_value = self.token_entry1 - mock_arm_client = mock.MagicMock() - mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - mgmt_resource = 'https://management.core.windows.net/' - # action - subs = finder.find_from_service_principal_id('my app', ServicePrincipalAuth('my secret'), - self.tenant_id, mgmt_resource) + TEST_ADFS_AUTH_URL = 'https://adfs.local.azurestack.external/adfs' + get_token_mock.return_value = self.access_token - # assert - self.assertEqual([self.subscription1], subs) - mock_arm_client.tenants.list.assert_not_called() - mock_auth_context.acquire_token.assert_not_called() - mock_auth_context.acquire_token_with_client_credentials.assert_called_once_with( - mgmt_resource, 'my app', 'my secret') + # todo: adfs test should be covered in azure.identity + def test_acquire_token(self, resource, username, password, client_id): + global acquire_token_invoked + acquire_token_invoked = True + if (self.authority.url == TEST_ADFS_AUTH_URL and self.authority.is_adfs_authority): + return TestProfile.token_entry1 + else: + raise ValueError('AuthContext was not initialized correctly for ADFS') - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_find_subscriptions_from_service_principal_using_cert(self, mock_auth_context): - cli = DummyCli() - mock_auth_context.acquire_token_with_client_certificate.return_value = self.token_entry1 + get_token_mock.return_value = self.access_token mock_arm_client = mock.MagicMock() + mock_arm_client.tenants.list.return_value = [TenantStub(self.tenant_id)] mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) + cli.cloud.endpoints.active_directory = TEST_ADFS_AUTH_URL + finder = SubscriptionFinder(cli) + finder._arm_client_factory = mock_arm_client mgmt_resource = 'https://management.core.windows.net/' - - curr_dir = os.path.dirname(os.path.realpath(__file__)) - test_cert_file = os.path.join(curr_dir, 'sp_cert.pem') + storage_mock = {'subscriptions': None} + profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) + profile.login(False, '1234', 'my-secret', True, self.tenant_id, use_device_code=False, + allow_no_subscriptions=False, subscription_finder=finder) # action - subs = finder.find_from_service_principal_id('my app', ServicePrincipalAuth(test_cert_file), - self.tenant_id, mgmt_resource) + subs = finder.find_from_user_account(self.user1, 'bar', None, mgmt_resource) # assert self.assertEqual([self.subscription1], subs) - mock_arm_client.tenants.list.assert_not_called() - mock_auth_context.acquire_token.assert_not_called() - mock_auth_context.acquire_token_with_client_certificate.assert_called_once_with( - mgmt_resource, 'my app', mock.ANY, mock.ANY, None) - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_find_subscriptions_from_service_principal_using_cert_sn_issuer(self, mock_auth_context): - cli = DummyCli() - mock_auth_context.acquire_token_with_client_certificate.return_value = self.token_entry1 + @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) + @mock.patch('azure.cli.core.auth.identity.Identity.get_user_credential', autospec=True) + def test_refresh_accounts_one_user_account(self, get_user_credential_mock, create_subscription_client_mock): mock_arm_client = mock.MagicMock() + mock_arm_client.tenants.list.return_value = [TenantStub(self.tenant_id)] mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - mgmt_resource = 'https://management.core.windows.net/' - - curr_dir = os.path.dirname(os.path.realpath(__file__)) - test_cert_file = os.path.join(curr_dir, 'sp_cert.pem') - with open(test_cert_file) as cert_file: - cert_file_string = cert_file.read() - match = re.search(r'\-+BEGIN CERTIFICATE.+\-+(?P[^-]+)\-+END CERTIFICATE.+\-+', - cert_file_string, re.I) - public_certificate = match.group('public').strip() - # action - subs = finder.find_from_service_principal_id('my app', ServicePrincipalAuth(test_cert_file, use_cert_sn_issuer=True), - self.tenant_id, mgmt_resource) + create_subscription_client_mock.return_value = mock_arm_client - # assert - self.assertEqual([self.subscription1], subs) - mock_arm_client.tenants.list.assert_not_called() - mock_auth_context.acquire_token.assert_not_called() - mock_auth_context.acquire_token_with_client_certificate.assert_called_once_with( - mgmt_resource, 'my app', mock.ANY, mock.ANY, public_certificate) - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_refresh_accounts_one_user_account(self, mock_auth_context): cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, deepcopy([self.subscription1]), False) + storage_mock = {'subscriptions': []} + profile = Profile(cli_ctx=cli, storage=storage_mock) + consolidated = profile._normalize_properties(self.user1, deepcopy([self.subscription1]), False, None, None) profile._set_subscriptions(consolidated) - mock_auth_context.acquire_token_with_username_password.return_value = self.token_entry1 - mock_auth_context.acquire_token.return_value = self.token_entry1 - mock_arm_client = mock.MagicMock() + mock_arm_client.tenants.list.return_value = [TenantStub(self.tenant_id)] mock_arm_client.subscriptions.list.return_value = deepcopy([self.subscription1_raw, self.subscription2_raw]) - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - # action - profile.refresh_accounts(finder) + + profile.refresh_accounts() # assert result = storage_mock['subscriptions'] @@ -1511,28 +1276,28 @@ def test_refresh_accounts_one_user_account(self, mock_auth_context): self.assertEqual(self.id2.split('/')[-1], result[1]['id']) self.assertTrue(result[0]['isDefault']) - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_refresh_accounts_one_user_account_one_sp_account(self, mock_auth_context): + @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) + @mock.patch('azure.cli.core.auth.identity.Identity.get_service_principal_credential', autospec=True) + @mock.patch('azure.cli.core.auth.identity.Identity.get_user_credential', autospec=True) + def test_refresh_accounts_one_user_account_one_sp_account(self, get_user_credential_mock, + get_service_principal_credential_mock, + create_subscription_client_mock): cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - sp_subscription1 = SubscriptionStub('sp-sub/3', 'foo-subname', self.state1, 'foo_tenant.onmicrosoft.com') - consolidated = profile._normalize_properties(self.user1, deepcopy([self.subscription1]), False) + storage_mock = {'subscriptions': []} + profile = Profile(cli_ctx=cli, storage=storage_mock) + sp_subscription1 = SubscriptionStub('sp-sub/3', 'foo-subname', self.state1, 'footenant.onmicrosoft.com') + consolidated = profile._normalize_properties(self.user1, deepcopy([self.subscription1]), False, None, None) consolidated += profile._normalize_properties('http://foo', [sp_subscription1], True) profile._set_subscriptions(consolidated) - mock_auth_context.acquire_token_with_username_password.return_value = self.token_entry1 - mock_auth_context.acquire_token.return_value = self.token_entry1 - mock_auth_context.acquire_token_with_client_credentials.return_value = self.token_entry1 + mock_arm_client = mock.MagicMock() mock_arm_client.tenants.list.return_value = [TenantStub(self.tenant_id)] - mock_arm_client.subscriptions.list.side_effect = deepcopy([[self.subscription1], [self.subscription2, sp_subscription1]]) - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - profile._creds_cache.retrieve_cred_for_service_principal = lambda _: 'verySecret' - profile._creds_cache.flush_to_disk = lambda _: '' - # action - profile.refresh_accounts(finder) + mock_arm_client.subscriptions.list.side_effect = deepcopy( + [[self.subscription1], [self.subscription2, sp_subscription1]]) + create_subscription_client_mock.return_value = mock_arm_client + + profile.refresh_accounts() - # assert result = storage_mock['subscriptions'] self.assertEqual(3, len(result)) self.assertEqual(self.id1.split('/')[-1], result[0]['id']) @@ -1540,468 +1305,69 @@ def test_refresh_accounts_one_user_account_one_sp_account(self, mock_auth_contex self.assertEqual('3', result[2]['id']) self.assertTrue(result[0]['isDefault']) - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_refresh_accounts_with_nothing(self, mock_auth_context): + @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) + @mock.patch('azure.cli.core.auth.identity.Identity.get_user_credential', autospec=True) + def test_refresh_accounts_with_nothing(self, get_user_credential_mock, create_subscription_client_mock): cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, deepcopy([self.subscription1]), False) + storage_mock = {'subscriptions': []} + profile = Profile(cli_ctx=cli, storage=storage_mock) + consolidated = profile._normalize_properties(self.user1, deepcopy([self.subscription1]), False, None, None) profile._set_subscriptions(consolidated) - mock_auth_context.acquire_token_with_username_password.return_value = self.token_entry1 - mock_auth_context.acquire_token.return_value = self.token_entry1 + mock_arm_client = mock.MagicMock() mock_arm_client.tenants.list.return_value = [TenantStub(self.tenant_id)] mock_arm_client.subscriptions.list.return_value = [] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - # action - profile.refresh_accounts(finder) + create_subscription_client_mock.return_value = mock_arm_client + + profile.refresh_accounts() # assert result = storage_mock['subscriptions'] self.assertEqual(0, len(result)) - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - def test_credscache_load_tokens_and_sp_creds_with_secret(self, mock_read_file): - cli = DummyCli() - test_sp = { - "servicePrincipalId": "myapp", - "servicePrincipalTenant": "mytenant", - "accessToken": "Secret" - } - mock_read_file.return_value = [self.token_entry1, test_sp] - - # action - creds_cache = CredsCache(cli, async_persist=False) - - # assert - token_entries = [entry for _, entry in creds_cache.load_adal_token_cache().read_items()] - self.assertEqual(token_entries, [self.token_entry1]) - self.assertEqual(creds_cache._service_principal_creds, [test_sp]) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - def test_credscache_load_tokens_and_sp_creds_with_cert(self, mock_read_file): - cli = DummyCli() - test_sp = { - "servicePrincipalId": "myapp", - "servicePrincipalTenant": "mytenant", - "certificateFile": 'junkcert.pem' - } - mock_read_file.return_value = [test_sp] - - # action - creds_cache = CredsCache(cli, async_persist=False) - creds_cache.load_adal_token_cache() - - # assert - self.assertEqual(creds_cache._service_principal_creds, [test_sp]) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - def test_credscache_retrieve_sp_cred(self, mock_read_file): - cli = DummyCli() - test_cache = [ - { - "servicePrincipalId": "myapp", - "servicePrincipalTenant": "mytenant", - "accessToken": "Secret" - }, - { - "servicePrincipalId": "myapp2", - "servicePrincipalTenant": "mytenant", - "certificateFile": 'junkcert.pem' - } - ] - mock_read_file.return_value = test_cache - - # action - creds_cache = CredsCache(cli, async_persist=False) - creds_cache.load_adal_token_cache() - - # assert - self.assertEqual(creds_cache.retrieve_cred_for_service_principal('myapp'), 'Secret') - self.assertEqual(creds_cache.retrieve_cred_for_service_principal('myapp2'), 'junkcert.pem') - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('os.fdopen', autospec=True) - @mock.patch('os.open', autospec=True) - def test_credscache_add_new_sp_creds(self, _, mock_open_for_write, mock_read_file): - cli = DummyCli() - test_sp = { - "servicePrincipalId": "myapp", - "servicePrincipalTenant": "mytenant", - "accessToken": "Secret" - } - test_sp2 = { - "servicePrincipalId": "myapp2", - "servicePrincipalTenant": "mytenant2", - "accessToken": "Secret2" - } - mock_open_for_write.return_value = FileHandleStub() - mock_read_file.return_value = [self.token_entry1, test_sp] - creds_cache = CredsCache(cli, async_persist=False) - - # action - creds_cache.save_service_principal_cred(test_sp2) - - # assert - token_entries = [e for _, e in creds_cache.adal_token_cache.read_items()] # noqa: F812 - self.assertEqual(token_entries, [self.token_entry1]) - self.assertEqual(creds_cache._service_principal_creds, [test_sp, test_sp2]) - mock_open_for_write.assert_called_with(mock.ANY, 'w+') - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('os.fdopen', autospec=True) - @mock.patch('os.open', autospec=True) - def test_credscache_add_preexisting_sp_creds(self, _, mock_open_for_write, mock_read_file): - cli = DummyCli() - test_sp = { - "servicePrincipalId": "myapp", - "servicePrincipalTenant": "mytenant", - "accessToken": "Secret" - } - mock_open_for_write.return_value = FileHandleStub() - mock_read_file.return_value = [test_sp] - creds_cache = CredsCache(cli, async_persist=False) - - # action - creds_cache.save_service_principal_cred(test_sp) - - # assert - self.assertEqual(creds_cache._service_principal_creds, [test_sp]) - self.assertFalse(mock_open_for_write.called) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('os.fdopen', autospec=True) - @mock.patch('os.open', autospec=True) - def test_credscache_add_preexisting_sp_new_secret(self, _, mock_open_for_write, mock_read_file): - cli = DummyCli() - test_sp = { - "servicePrincipalId": "myapp", - "servicePrincipalTenant": "mytenant", - "accessToken": "Secret" - } - mock_open_for_write.return_value = FileHandleStub() - mock_read_file.return_value = [test_sp] - creds_cache = CredsCache(cli, async_persist=False) - - new_creds = test_sp.copy() - new_creds['accessToken'] = 'Secret2' - # action - creds_cache.save_service_principal_cred(new_creds) - - # assert - self.assertEqual(creds_cache._service_principal_creds, [new_creds]) - self.assertTrue(mock_open_for_write.called) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('os.fdopen', autospec=True) - @mock.patch('os.open', autospec=True) - def test_credscache_match_service_principal_correctly(self, _, mock_open_for_write, mock_read_file): - cli = DummyCli() - test_sp = { - "servicePrincipalId": "myapp", - "servicePrincipalTenant": "mytenant", - "accessToken": "Secret" - } - mock_open_for_write.return_value = FileHandleStub() - mock_read_file.return_value = [test_sp] - factory = mock.MagicMock() - factory.side_effect = ValueError('SP was found') - creds_cache = CredsCache(cli, factory, async_persist=False) - - # action and verify(we plant an exception to throw after the SP was found; so if the exception is thrown, - # we know the matching did go through) - self.assertRaises(ValueError, creds_cache.retrieve_token_for_service_principal, - 'myapp', 'resource1', 'mytenant', False) - - # tenant doesn't exactly match, but it still succeeds - # before fully migrating to pytest and utilizing capsys fixture, use `pytest -o log_cli=True` to manually - # verify the warning log - self.assertRaises(ValueError, creds_cache.retrieve_token_for_service_principal, - 'myapp', 'resource1', 'mytenant2', False) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('os.fdopen', autospec=True) - @mock.patch('os.open', autospec=True) - def test_credscache_remove_creds(self, _, mock_open_for_write, mock_read_file): - cli = DummyCli() - test_sp = { - "servicePrincipalId": "myapp", - "servicePrincipalTenant": "mytenant", - "accessToken": "Secret" - } - mock_open_for_write.return_value = FileHandleStub() - mock_read_file.return_value = [self.token_entry1, test_sp] - creds_cache = CredsCache(cli, async_persist=False) - - # action #1, logout a user - creds_cache.remove_cached_creds(self.user1) - - # assert #1 - token_entries = [e for _, e in creds_cache.adal_token_cache.read_items()] # noqa: F812 - self.assertEqual(token_entries, []) - - # action #2 logout a service principal - creds_cache.remove_cached_creds('myapp') - - # assert #2 - self.assertEqual(creds_cache._service_principal_creds, []) - - mock_open_for_write.assert_called_with(mock.ANY, 'w+') - self.assertEqual(mock_open_for_write.call_count, 2) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('os.fdopen', autospec=True) - @mock.patch('os.open', autospec=True) - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_credscache_new_token_added_by_adal(self, mock_adal_auth_context, _, mock_open_for_write, mock_read_file): # pylint: disable=line-too-long - cli = DummyCli() - token_entry2 = { - "accessToken": "new token", - "tokenType": "Bearer", - "userId": self.user1 - } - - def acquire_token_side_effect(*args): # pylint: disable=unused-argument - creds_cache.adal_token_cache.has_state_changed = True - return token_entry2 - - def get_auth_context(_, authority, **kwargs): # pylint: disable=unused-argument - mock_adal_auth_context.cache = kwargs['cache'] - return mock_adal_auth_context - - mock_adal_auth_context.acquire_token.side_effect = acquire_token_side_effect - mock_open_for_write.return_value = FileHandleStub() - mock_read_file.return_value = [self.token_entry1] - creds_cache = CredsCache(cli, auth_ctx_factory=get_auth_context, async_persist=False) - - # action - mgmt_resource = 'https://management.core.windows.net/' - token_type, token, _ = creds_cache.retrieve_token_for_user(self.user1, self.tenant_id, - mgmt_resource) - mock_adal_auth_context.acquire_token.assert_called_once_with( - 'https://management.core.windows.net/', - self.user1, - mock.ANY) - - # assert - mock_open_for_write.assert_called_with(mock.ANY, 'w+') - self.assertEqual(token, 'new token') - self.assertEqual(token_type, token_entry2['tokenType']) - - @mock.patch('azure.cli.core._profile.get_file_json', autospec=True) - @mock.patch('os.path.isfile', autospec=True, return_value=True) - def test_credscache_good_error_on_file_corruption(self, isfile_mock, get_file_json_mock): - get_file_json_mock.side_effect = ValueError('a bad error for you') - cli = DummyCli() - - # action - creds_cache = CredsCache(cli, async_persist=False) - - # assert - with self.assertRaises(CLIError) as context: - creds_cache.load_adal_token_cache() - - self.assertTrue(re.findall(r'bad error for you', str(context.exception))) - - def test_service_principal_auth_client_secret(self): - sp_auth = ServicePrincipalAuth('verySecret!') - result = sp_auth.get_entry_to_persist('sp_id1', 'tenant1') - self.assertEqual(result, { - 'servicePrincipalId': 'sp_id1', - 'servicePrincipalTenant': 'tenant1', - 'accessToken': 'verySecret!' - }) - - def test_service_principal_auth_client_cert(self): - curr_dir = os.path.dirname(os.path.realpath(__file__)) - test_cert_file = os.path.join(curr_dir, 'sp_cert.pem') - sp_auth = ServicePrincipalAuth(test_cert_file) - - result = sp_auth.get_entry_to_persist('sp_id1', 'tenant1') - self.assertEqual(result, { - 'servicePrincipalId': 'sp_id1', - 'servicePrincipalTenant': 'tenant1', - 'certificateFile': test_cert_file, - 'thumbprint': 'F0:6A:53:84:8B:BE:71:4A:42:90:D6:9D:33:52:79:C1:D0:10:73:FD' - }) - - def test_service_principal_auth_client_cert_err(self): - curr_dir = os.path.dirname(os.path.realpath(__file__)) - test_cert_file = os.path.join(curr_dir, 'err_sp_cert.pem') - with self.assertRaisesRegexp(CLIError, 'Invalid certificate'): - ServicePrincipalAuth(test_cert_file) - - def test_detect_adfs_authority_url(self): - cli = DummyCli() - adfs_url_1 = 'https://adfs.redmond.ext-u15f2402.masd.stbtest.microsoft.com/adfs/' - cli.cloud.endpoints.active_directory = adfs_url_1 - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - - # test w/ trailing slash - r = profile.auth_ctx_factory(cli, 'common', None) - self.assertEqual(r.authority.url, adfs_url_1.rstrip('/')) - - # test w/o trailing slash - adfs_url_2 = 'https://adfs.redmond.ext-u15f2402.masd.stbtest.microsoft.com/adfs' - cli.cloud.endpoints.active_directory = adfs_url_2 - r = profile.auth_ctx_factory(cli, 'common', None) - self.assertEqual(r.authority.url, adfs_url_2) - - # test w/ regular aad - aad_url = 'https://login.microsoftonline.com' - cli.cloud.endpoints.active_directory = aad_url - r = profile.auth_ctx_factory(cli, 'common', None) - self.assertEqual(r.authority.url, aad_url + '/common') - - @mock.patch('adal.AuthenticationContext', autospec=True) - @mock.patch('azure.cli.core._profile._get_authorization_code', autospec=True) - def test_find_using_common_tenant(self, _get_authorization_code_mock, mock_auth_context): - """When a subscription can be listed by multiple tenants, only the first appearance is retained - """ - import adal - cli = DummyCli() - mock_arm_client = mock.MagicMock() - tenant2 = "00000002-0000-0000-0000-000000000000" - mock_arm_client.tenants.list.return_value = [TenantStub(self.tenant_id), TenantStub(tenant2)] - - # same subscription but listed from another tenant - subscription2_raw = SubscriptionStub(self.id1, self.display_name1, self.state1, self.tenant_id) - mock_arm_client.subscriptions.list.side_effect = [[deepcopy(self.subscription1_raw)], [subscription2_raw]] - - mgmt_resource = 'https://management.core.windows.net/' - token_cache = adal.TokenCache() - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, token_cache, lambda _: mock_arm_client) - all_subscriptions = finder._find_using_common_tenant(access_token="token1", resource=mgmt_resource) - - self.assertEqual(len(all_subscriptions), 1) - self.assertEqual(all_subscriptions[0].tenant_id, self.tenant_id) - - @mock.patch('adal.AuthenticationContext', autospec=True) - @mock.patch('azure.cli.core._profile._get_authorization_code', autospec=True) - def test_find_using_common_tenant_mfa_warning(self, _get_authorization_code_mock, mock_auth_context): + @mock.patch('azure.cli.core._profile.SubscriptionFinder._create_subscription_client', autospec=True) + @mock.patch('azure.cli.core.auth.identity.Identity.get_user_credential', autospec=True) + def test_login_common_tenant_mfa_warning(self, get_user_credential_mock, create_subscription_client_mock): # Assume 2 tenants. Home tenant tenant1 doesn't require MFA, but tenant2 does - import adal cli = DummyCli() mock_arm_client = mock.MagicMock() tenant2_mfa_id = 'tenant2-0000-0000-0000-000000000000' mock_arm_client.tenants.list.return_value = [TenantStub(self.tenant_id), TenantStub(tenant2_mfa_id)] - mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - token_cache = adal.TokenCache() - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, token_cache, lambda _: mock_arm_client) + create_subscription_client_mock.return_value = mock_arm_client - adal_error_mfa = adal.AdalError(error_msg="", error_response={ + finder = SubscriptionFinder(cli) + + from azure.cli.core.azclierror import AuthenticationError + error_description = ("AADSTS50076: Due to a configuration change made by your administrator, " + "or because you moved to a new location, you must use multi-factor " + "authentication to access '797f4846-ba00-4fd7-ba43-dac1f8f63013'.\n" + "Trace ID: 00000000-0000-0000-0000-000000000000\n" + "Correlation ID: 00000000-0000-0000-0000-000000000000\n" + "Timestamp: 2020-03-10 04:42:59Z") + msal_result = { 'error': 'interaction_required', - 'error_description': "AADSTS50076: Due to a configuration change made by your administrator, " - "or because you moved to a new location, you must use multi-factor " - "authentication to access '797f4846-ba00-4fd7-ba43-dac1f8f63013'.\n" - "Trace ID: 00000000-0000-0000-0000-000000000000\n" - "Correlation ID: 00000000-0000-0000-0000-000000000000\n" - "Timestamp: 2020-03-10 04:42:59Z", - 'error_codes': [50076], + 'error_description': error_description, + 'error_codes':[50076], 'timestamp': '2020-03-10 04:42:59Z', 'trace_id': '00000000-0000-0000-0000-000000000000', 'correlation_id': '00000000-0000-0000-0000-000000000000', 'error_uri': 'https://login.microsoftonline.com/error?code=50076', - 'suberror': 'basic_action'}) + 'suberror': 'basic_action' + } - # adal_error_mfa are raised on the second call - mock_auth_context.acquire_token.side_effect = [self.token_entry1, adal_error_mfa] + err = AuthenticationError(error_description, recommendation=None) - # action - all_subscriptions = finder._find_using_common_tenant(access_token="token1", - resource='https://management.core.windows.net/') + # MFA error raised on the second call + mock_arm_client.subscriptions.list.side_effect = [[deepcopy(self.subscription1_raw)], err] + + credential = mock.MagicMock() + all_subscriptions = finder.find_using_common_tenant(self.user1, credential) - # assert # subscriptions are correctly returned self.assertEqual(all_subscriptions, [self.subscription1]) - self.assertEqual(mock_auth_context.acquire_token.call_count, 2) # With pytest, use -o log_cli=True to manually check the log - @mock.patch('adal.AuthenticationContext', autospec=True) - @mock.patch('azure.cli.core._profile._get_authorization_code', autospec=True) - def test_find_using_specific_tenant(self, _get_authorization_code_mock, mock_auth_context): - """ Test tenant_id -> home_tenant_id mapping and token tenant attachment - """ - import adal - cli = DummyCli() - mock_arm_client = mock.MagicMock() - token_tenant = "00000001-0000-0000-0000-000000000000" - home_tenant = "00000002-0000-0000-0000-000000000000" - - subscription_raw = SubscriptionStub(self.id1, self.display_name1, self.state1, tenant_id=home_tenant) - mock_arm_client.subscriptions.list.return_value = [subscription_raw] - - token_cache = adal.TokenCache() - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, token_cache, lambda _: mock_arm_client) - all_subscriptions = finder._find_using_specific_tenant(tenant=token_tenant, access_token="token1") - - self.assertEqual(len(all_subscriptions), 1) - self.assertEqual(all_subscriptions[0].tenant_id, token_tenant) - self.assertEqual(all_subscriptions[0].home_tenant_id, home_tenant) - - @mock.patch('msal.ConfidentialClientApplication.acquire_token_for_client', autospec=True) - @mock.patch('azure.cli.core._profile.CredsCache.retrieve_cred_for_service_principal', autospec=True) - @mock.patch('msal.ClientApplication.acquire_token_by_refresh_token', autospec=True) - @mock.patch('azure.cli.core._profile.CredsCache.retrieve_token_for_user', autospec=True) - @mock.patch('azure.cli.core._profile.Profile.get_subscription', autospec=True) - def test_get_msal_token(self, get_subscription_mock, retrieve_token_for_user_mock, - acquire_token_by_refresh_token_mock, retrieve_cred_for_service_principal_mock, - acquire_token_for_client_mock): - """ - This is added only for vmssh feature. - It is a temporary solution and will deprecate after MSAL adopted completely. - """ - cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - - msal_result = { - 'token_type': 'ssh-cert', - 'scope': 'https://pas.windows.net/CheckMyAccess/Linux/user_impersonation https://pas.windows.net/CheckMyAccess/Linux/.default', - 'expires_in': 3599, - 'ext_expires_in': 3599, - 'access_token': 'fake_cert' - } - - # User - get_subscription_mock.return_value = { - 'tenantId': self.tenant_id, - 'user': { - 'name': self.user1, - 'type': 'user' - }, - } - - retrieve_token_for_user_mock.return_value = ('Bearer', self.raw_token1, self.token_entry1) - acquire_token_by_refresh_token_mock.return_value = msal_result - - scopes = ["https://pas.windows.net/CheckMyAccess/Linux/.default"] - data = { - "token_type": "ssh-cert", - "req_cnf": "fake_jwk", - "key_id": "fake_id" - } - username, access_token = profile.get_msal_token(scopes, data) - self.assertEqual(username, self.user1) - self.assertEqual(access_token, 'fake_cert') - - # Service Principal - sp_id = '610a3200-0000-0000-0000-000000000000' - get_subscription_mock.return_value = { - 'tenantId': self.tenant_id, - 'user': { - 'name': sp_id, - 'type': 'servicePrincipal' - }, - } - retrieve_cred_for_service_principal_mock.return_value = "some_secret" - acquire_token_for_client_mock.return_value = msal_result - username, access_token = profile.get_msal_token(scopes, data) - self.assertEqual(username, sp_id) - self.assertEqual(access_token, 'fake_cert') - class FileHandleStub(object): # pylint: disable=too-few-public-methods @@ -2021,7 +1387,8 @@ def __init__(self, id, display_name, state, tenant_id, managed_by_tenants=[], ho policies = SubscriptionPolicies() policies.spending_limit = SpendingLimit.current_period_off policies.quota_id = 'some quota' - super(SubscriptionStub, self).__init__(subscription_policies=policies, authorization_source='some_authorization_source') + super(SubscriptionStub, self).__init__(subscription_policies=policies, + authorization_source='some_authorization_source') self.id = id self.subscription_id = id.split('/')[1] self.display_name = display_name @@ -2049,32 +1416,46 @@ def __init__(self, tenant_id, display_name="DISPLAY_NAME"): self.additional_properties = {'displayName': display_name} -class MSRestAzureAuthStub: - def __init__(self, *args, **kwargs): - self._token = { - 'token_type': 'Bearer', - 'access_token': TestProfile.test_msi_access_token - } - self.set_token_invoked_count = 0 - self.token_read_count = 0 - self.client_id = kwargs.get('client_id') - self.object_id = kwargs.get('object_id') - self.msi_res_id = kwargs.get('msi_res_id') - - def set_token(self): - self.set_token_invoked_count += 1 - - @property - def token(self): - self.token_read_count += 1 - return self._token - - @token.setter - def token(self, value): - self._token = value - - -class TestUtil(unittest.TestCase): +class TestUtils(unittest.TestCase): + def test_detect_adfs_authority(self): + # Public cloud + # Default tenant + self.assertEqual(_detect_adfs_authority('https://login.microsoftonline.com', None), + ('https://login.microsoftonline.com', None)) + # Trailing slash is stripped + self.assertEqual(_detect_adfs_authority('https://login.microsoftonline.com/', None), + ('https://login.microsoftonline.com', None)) + # Custom tenant + self.assertEqual(_detect_adfs_authority('https://login.microsoftonline.com', '601d729d-0000-0000-0000-000000000000'), + ('https://login.microsoftonline.com', '601d729d-0000-0000-0000-000000000000')) + + # ADFS + # Default tenant + self.assertEqual(_detect_adfs_authority('https://adfs.redmond.azurestack.corp.microsoft.com/adfs', None), + ('https://adfs.redmond.azurestack.corp.microsoft.com', 'adfs')) + # Trailing slash is stripped + self.assertEqual(_detect_adfs_authority('https://adfs.redmond.azurestack.corp.microsoft.com/adfs/', None), + ('https://adfs.redmond.azurestack.corp.microsoft.com', 'adfs')) + # Tenant ID is discarded + self.assertEqual(_detect_adfs_authority('https://adfs.redmond.azurestack.corp.microsoft.com/adfs', '601d729d-0000-0000-0000-000000000000'), + ('https://adfs.redmond.azurestack.corp.microsoft.com', 'adfs')) + + def test_attach_token_tenant(self): + from azure.mgmt.resource.subscriptions.v2016_06_01.models import Subscription \ + as Subscription_v2016_06_01 + subscription = Subscription_v2016_06_01() + _attach_token_tenant(subscription, "token_tenant_1") + self.assertEqual(subscription.tenant_id, "token_tenant_1") + self.assertFalse(hasattr(subscription, "home_tenant_id")) + + def test_attach_token_tenant_v2016_06_01(self): + from azure.mgmt.resource.subscriptions.v2019_11_01.models import Subscription \ + as Subscription_v2019_11_01 + subscription = Subscription_v2019_11_01() + subscription.tenant_id = "home_tenant_1" + _attach_token_tenant(subscription, "token_tenant_1") + self.assertEqual(subscription.tenant_id, "token_tenant_1") + self.assertEqual(subscription.home_tenant_id, "home_tenant_1") def test_transform_subscription_for_multiapi(self): diff --git a/src/azure-cli-core/azure/cli/core/tests/test_profile_v2016_06_01.py b/src/azure-cli-core/azure/cli/core/tests/test_profile_v2016_06_01.py deleted file mode 100644 index 03337c19cc7..00000000000 --- a/src/azure-cli-core/azure/cli/core/tests/test_profile_v2016_06_01.py +++ /dev/null @@ -1,1758 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. -# -------------------------------------------------------------------------------------------- - -# pylint: disable=protected-access -import json -import os -import sys -import unittest -from unittest import mock -import re - -from copy import deepcopy - -from adal import AdalError - -from azure.cli.core._profile import (Profile, CredsCache, SubscriptionFinder, - ServicePrincipalAuth, _AUTH_CTX_FACTORY, _USE_VENDORED_SUBSCRIPTION_SDK) - -if _USE_VENDORED_SUBSCRIPTION_SDK: - from azure.cli.core.vendored_sdks.subscriptions.v2016_06_01.models import \ - (SubscriptionState, Subscription, SubscriptionPolicies, SpendingLimit) -else: - from azure.mgmt.resource.subscriptions.v2016_06_01.models import \ - (SubscriptionState, Subscription, SubscriptionPolicies, SpendingLimit) - -from azure.cli.core.mock import DummyCli - -from knack.util import CLIError - - -@unittest.skip("Out of maintenance") -class TestProfile(unittest.TestCase): - - @classmethod - def setUpClass(cls): - cls.tenant_id = 'microsoft.com' - cls.user1 = 'foo@foo.com' - cls.id1 = 'subscriptions/1' - cls.display_name1 = 'foo account' - cls.state1 = SubscriptionState.enabled - # Dummy Subscription from SDK azure.mgmt.resource.subscriptions.v2016_06_01.operations._subscriptions_operations.SubscriptionsOperations.list - # tenant_id shouldn't be set as tenantId isn't returned by REST API - # Must be deepcopied before used as mock_arm_client.subscriptions.list.return_value - cls.subscription1_raw = SubscriptionStub(cls.id1, - cls.display_name1, - cls.state1) - # Dummy result of azure.cli.core._profile.SubscriptionFinder._find_using_specific_tenant - # tenant_id denotes token tenant - cls.subscription1 = SubscriptionStub(cls.id1, - cls.display_name1, - cls.state1, - cls.tenant_id) - # Dummy result of azure.cli.core._profile.Profile._normalize_properties - cls.subscription1_normalized = { - 'environmentName': 'AzureCloud', - 'id': '1', - 'name': cls.display_name1, - 'state': cls.state1.value, - 'user': { - 'name': cls.user1, - 'type': 'user' - }, - 'isDefault': False, - 'tenantId': cls.tenant_id - } - - cls.raw_token1 = 'some...secrets' - cls.token_entry1 = { - "_clientId": "04b07795-8ddb-461a-bbee-02f9e1bf7b46", - "resource": "https://management.core.windows.net/", - "tokenType": "Bearer", - "expiresOn": "2016-03-31T04:26:56.610Z", - "expiresIn": 3599, - "identityProvider": "live.com", - "_authority": "https://login.microsoftonline.com/common", - "isMRRT": True, - "refreshToken": "faked123", - "accessToken": cls.raw_token1, - "userId": cls.user1 - } - - cls.user2 = 'bar@bar.com' - cls.id2 = 'subscriptions/2' - cls.display_name2 = 'bar account' - cls.state2 = SubscriptionState.past_due - cls.subscription2_raw = SubscriptionStub(cls.id2, - cls.display_name2, - cls.state2) - cls.subscription2 = SubscriptionStub(cls.id2, - cls.display_name2, - cls.state2, - cls.tenant_id) - cls.subscription2_normalized = { - 'environmentName': 'AzureCloud', - 'id': '2', - 'name': cls.display_name2, - 'state': cls.state2.value, - 'user': { - 'name': cls.user2, - 'type': 'user' - }, - 'isDefault': False, - 'tenantId': cls.tenant_id - } - cls.test_msi_tenant = '54826b22-38d6-4fb2-bad9-b7b93a3e9c5a' - cls.test_msi_access_token = ('eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsIng1dCI6IlZXVkljMVdEMVRrc2JiMzAxc2FzTTVrT3E1' - 'USIsImtpZCI6IlZXVkljMVdEMVRrc2JiMzAxc2FzTTVrT3E1USJ9.eyJhdWQiOiJodHRwczovL21hbmF' - 'nZW1lbnQuY29yZS53aW5kb3dzLm5ldC8iLCJpc3MiOiJodHRwczovL3N0cy53aW5kb3dzLm5ldC81NDg' - 'yNmIyMi0zOGQ2LTRmYjItYmFkOS1iN2I5M2EzZTljNWEvIiwiaWF0IjoxNTAzMzU0ODc2LCJuYmYiOjE' - '1MDMzNTQ4NzYsImV4cCI6MTUwMzM1ODc3NiwiYWNyIjoiMSIsImFpbyI6IkFTUUEyLzhFQUFBQTFGL1k' - '0VVR3bFI1Y091QXJxc1J0OU5UVVc2MGlsUHZna0daUC8xczVtdzg9IiwiYW1yIjpbInB3ZCJdLCJhcHB' - 'pZCI6IjA0YjA3Nzk1LThkZGItNDYxYS1iYmVlLTAyZjllMWJmN2I0NiIsImFwcGlkYWNyIjoiMCIsImV' - 'fZXhwIjoyNjI4MDAsImZhbWlseV9uYW1lIjoic2RrIiwiZ2l2ZW5fbmFtZSI6ImFkbWluMyIsImdyb3V' - 'wcyI6WyJlNGJiMGI1Ni0xMDE0LTQwZjgtODhhYi0zZDhhOGNiMGUwODYiLCI4YTliMTYxNy1mYzhkLTR' - 'hYTktYTQyZi05OTg2OGQzMTQ2OTkiLCI1NDgwMzkxNy00YzcxLTRkNmMtOGJkZi1iYmQ5MzEwMTBmOGM' - 'iXSwiaXBhZGRyIjoiMTY3LjIyMC4xLjIzNCIsIm5hbWUiOiJhZG1pbjMiLCJvaWQiOiJlN2UxNThkMy0' - '3Y2RjLTQ3Y2QtODgyNS01ODU5ZDdhYjJiNTUiLCJwdWlkIjoiMTAwMzNGRkY5NUQ0NEU4NCIsInNjcCI' - '6InVzZXJfaW1wZXJzb25hdGlvbiIsInN1YiI6ImhRenl3b3FTLUEtRzAySTl6ZE5TRmtGd3R2MGVwZ2l' - 'WY1Vsdm1PZEZHaFEiLCJ0aWQiOiI1NDgyNmIyMi0zOGQ2LTRmYjItYmFkOS1iN2I5M2EzZTljNWEiLCJ' - '1bmlxdWVfbmFtZSI6ImFkbWluM0BBenVyZVNES1RlYW0ub25taWNyb3NvZnQuY29tIiwidXBuIjoiYWR' - 'taW4zQEF6dXJlU0RLVGVhbS5vbm1pY3Jvc29mdC5jb20iLCJ1dGkiOiJuUEROYm04UFkwYUdELWhNeWx' - 'rVEFBIiwidmVyIjoiMS4wIiwid2lkcyI6WyI2MmU5MDM5NC02OWY1LTQyMzctOTE5MC0wMTIxNzcxNDV' - 'lMTAiXX0.Pg4cq0MuP1uGhY_h51ZZdyUYjGDUFgTW2EfIV4DaWT9RU7GIK_Fq9VGBTTbFZA0pZrrmP-z' - '7DlN9-U0A0nEYDoXzXvo-ACTkm9_TakfADd36YlYB5aLna-yO0B7rk5W9ANelkzUQgRfidSHtCmV6i4V' - 'e-lOym1sH5iOcxfIjXF0Tp2y0f3zM7qCq8Cp1ZxEwz6xYIgByoxjErNXrOME5Ld1WizcsaWxTXpwxJn_' - 'Q8U2g9kXHrbYFeY2gJxF_hnfLvNKxUKUBnftmyYxZwKi0GDS0BvdJnJnsqSRSpxUx__Ra9QJkG1IaDzj' - 'ZcSZPHK45T6ohK9Hk9ktZo0crVl7Tmw') - - def test_normalize(self): - cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, - [self.subscription1], - False) - expected = self.subscription1_normalized - self.assertEqual(expected, consolidated[0]) - # verify serialization works - self.assertIsNotNone(json.dumps(consolidated[0])) - - def test_normalize_with_unicode_in_subscription_name(self): - cli = DummyCli() - storage_mock = {'subscriptions': None} - test_display_name = 'sub' + chr(255) - polished_display_name = 'sub?' - test_subscription = SubscriptionStub('subscriptions/sub1', - test_display_name, - SubscriptionState.enabled, - 'tenant1') - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, - [test_subscription], - False) - self.assertTrue(consolidated[0]['name'] in [polished_display_name, test_display_name]) - - def test_normalize_with_none_subscription_name(self): - cli = DummyCli() - storage_mock = {'subscriptions': None} - test_display_name = None - polished_display_name = '' - test_subscription = SubscriptionStub('subscriptions/sub1', - test_display_name, - SubscriptionState.enabled, - 'tenant1') - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, - [test_subscription], - False) - self.assertTrue(consolidated[0]['name'] == polished_display_name) - - def test_update_add_two_different_subscriptions(self): - cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - - # add the first and verify - consolidated = profile._normalize_properties(self.user1, - [self.subscription1], - False) - profile._set_subscriptions(consolidated) - - self.assertEqual(len(storage_mock['subscriptions']), 1) - subscription1 = storage_mock['subscriptions'][0] - subscription1_is_default = deepcopy(self.subscription1_normalized) - subscription1_is_default['isDefault'] = True - self.assertEqual(subscription1, subscription1_is_default) - - # add the second and verify - consolidated = profile._normalize_properties(self.user2, - [self.subscription2], - False) - profile._set_subscriptions(consolidated) - - self.assertEqual(len(storage_mock['subscriptions']), 2) - subscription2 = storage_mock['subscriptions'][1] - subscription2_is_default = deepcopy(self.subscription2_normalized) - subscription2_is_default['isDefault'] = True - self.assertEqual(subscription2, subscription2_is_default) - - # verify the old one stays, but no longer active - self.assertEqual(storage_mock['subscriptions'][0]['name'], - subscription1['name']) - self.assertFalse(storage_mock['subscriptions'][0]['isDefault']) - - def test_update_with_same_subscription_added_twice(self): - cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - - # add one twice and verify we will have one but with new token - consolidated = profile._normalize_properties(self.user1, - [self.subscription1], - False) - profile._set_subscriptions(consolidated) - - new_subscription1 = SubscriptionStub(self.id1, - self.display_name1, - self.state1, - self.tenant_id) - consolidated = profile._normalize_properties(self.user1, - [new_subscription1], - False) - profile._set_subscriptions(consolidated) - - self.assertEqual(len(storage_mock['subscriptions']), 1) - self.assertTrue(storage_mock['subscriptions'][0]['isDefault']) - - def test_set_active_subscription(self): - cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - - consolidated = profile._normalize_properties(self.user1, - [self.subscription1], - False) - profile._set_subscriptions(consolidated) - - consolidated = profile._normalize_properties(self.user2, - [self.subscription2], - False) - profile._set_subscriptions(consolidated) - - self.assertTrue(storage_mock['subscriptions'][1]['isDefault']) - - profile.set_active_subscription(storage_mock['subscriptions'][0]['id']) - self.assertFalse(storage_mock['subscriptions'][1]['isDefault']) - self.assertTrue(storage_mock['subscriptions'][0]['isDefault']) - - def test_default_active_subscription_to_non_disabled_one(self): - cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - - subscriptions = profile._normalize_properties( - self.user2, [self.subscription2, self.subscription1], False) - - profile._set_subscriptions(subscriptions) - - # verify we skip the overdued subscription and default to the 2nd one in the list - self.assertEqual(storage_mock['subscriptions'][1]['name'], self.subscription1.display_name) - self.assertTrue(storage_mock['subscriptions'][1]['isDefault']) - - def test_get_subscription(self): - cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - - consolidated = profile._normalize_properties(self.user1, - [self.subscription1], - False) - profile._set_subscriptions(consolidated) - - self.assertEqual(self.display_name1, profile.get_subscription()['name']) - self.assertEqual(self.display_name1, - profile.get_subscription(subscription=self.display_name1)['name']) - - sub_id = self.id1.split('/')[-1] - self.assertEqual(sub_id, profile.get_subscription()['id']) - self.assertEqual(sub_id, profile.get_subscription(subscription=sub_id)['id']) - self.assertRaises(CLIError, profile.get_subscription, "random_id") - - def test_get_auth_info_fail_on_user_account(self): - cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - - consolidated = profile._normalize_properties(self.user1, - [self.subscription1], - False) - profile._set_subscriptions(consolidated) - - # testing dump of existing logged in account - self.assertRaises(CLIError, profile.get_sp_auth_info) - - @mock.patch('azure.cli.core.profiles.get_api_version', autospec=True) - def test_subscription_finder_constructor(self, get_api_mock): - cli = DummyCli() - get_api_mock.return_value = '2016-06-01' - cli.cloud.endpoints.resource_manager = 'http://foo_arm' - finder = SubscriptionFinder(cli, None, None, arm_client_factory=None) - result = finder._arm_client_factory(mock.MagicMock()) - self.assertEqual(result._client._base_url, 'http://foo_arm') - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_get_auth_info_for_logged_in_service_principal(self, mock_auth_context): - cli = DummyCli() - mock_auth_context.acquire_token_with_client_credentials.return_value = self.token_entry1 - mock_arm_client = mock.MagicMock() - mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - - storage_mock = {'subscriptions': []} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - profile._management_resource_uri = 'https://management.core.windows.net/' - profile.find_subscriptions_on_login(False, '1234', 'my-secret', True, self.tenant_id, use_device_code=False, - allow_no_subscriptions=False, subscription_finder=finder) - # action - extended_info = profile.get_sp_auth_info() - # assert - self.assertEqual(self.id1.split('/')[-1], extended_info['subscriptionId']) - self.assertEqual('1234', extended_info['clientId']) - self.assertEqual('my-secret', extended_info['clientSecret']) - self.assertEqual('https://login.microsoftonline.com', extended_info['activeDirectoryEndpointUrl']) - self.assertEqual('https://management.azure.com/', extended_info['resourceManagerEndpointUrl']) - - def test_get_auth_info_for_newly_created_service_principal(self): - cli = DummyCli() - storage_mock = {'subscriptions': []} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, [self.subscription1], False) - profile._set_subscriptions(consolidated) - # action - extended_info = profile.get_sp_auth_info(name='1234', cert_file='/tmp/123.pem') - # assert - self.assertEqual(self.id1.split('/')[-1], extended_info['subscriptionId']) - self.assertEqual(self.tenant_id, extended_info['tenantId']) - self.assertEqual('1234', extended_info['clientId']) - self.assertEqual('/tmp/123.pem', extended_info['clientCertificate']) - self.assertIsNone(extended_info.get('clientSecret', None)) - self.assertEqual('https://login.microsoftonline.com', extended_info['activeDirectoryEndpointUrl']) - self.assertEqual('https://management.azure.com/', extended_info['resourceManagerEndpointUrl']) - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_create_account_without_subscriptions_thru_service_principal(self, mock_auth_context): - mock_auth_context.acquire_token_with_client_credentials.return_value = self.token_entry1 - cli = DummyCli() - mock_arm_client = mock.MagicMock() - mock_arm_client.subscriptions.list.return_value = [] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - - storage_mock = {'subscriptions': []} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - profile._management_resource_uri = 'https://management.core.windows.net/' - - # action - result = profile.find_subscriptions_on_login(False, - '1234', - 'my-secret', - True, - self.tenant_id, - use_device_code=False, - allow_no_subscriptions=True, - subscription_finder=finder) - # assert - self.assertEqual(1, len(result)) - self.assertEqual(result[0]['id'], self.tenant_id) - self.assertEqual(result[0]['state'], 'Enabled') - self.assertEqual(result[0]['tenantId'], self.tenant_id) - self.assertEqual(result[0]['name'], 'N/A(tenant level account)') - self.assertTrue(profile.is_tenant_level_account()) - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_create_account_with_subscriptions_allow_no_subscriptions_thru_service_principal(self, mock_auth_context): - """test subscription is returned even with --allow-no-subscriptions. """ - mock_auth_context.acquire_token_with_client_credentials.return_value = self.token_entry1 - cli = DummyCli() - mock_arm_client = mock.MagicMock() - mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - - storage_mock = {'subscriptions': []} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - profile._management_resource_uri = 'https://management.core.windows.net/' - - # action - result = profile.find_subscriptions_on_login(False, - '1234', - 'my-secret', - True, - self.tenant_id, - use_device_code=False, - allow_no_subscriptions=True, - subscription_finder=finder) - # assert - self.assertEqual(1, len(result)) - self.assertEqual(result[0]['id'], self.id1.split('/')[-1]) - self.assertEqual(result[0]['state'], 'Enabled') - self.assertEqual(result[0]['tenantId'], self.tenant_id) - self.assertEqual(result[0]['name'], self.display_name1) - self.assertFalse(profile.is_tenant_level_account()) - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_create_account_without_subscriptions_thru_common_tenant(self, mock_auth_context): - mock_auth_context.acquire_token.return_value = self.token_entry1 - mock_auth_context.acquire_token_with_username_password.return_value = self.token_entry1 - cli = DummyCli() - tenant_object = mock.MagicMock() - tenant_object.id = "foo-bar" - tenant_object.tenant_id = self.tenant_id - mock_arm_client = mock.MagicMock() - mock_arm_client.subscriptions.list.return_value = [] - mock_arm_client.tenants.list.return_value = (x for x in [tenant_object]) - - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - - storage_mock = {'subscriptions': []} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - profile._management_resource_uri = 'https://management.core.windows.net/' - - # action - result = profile.find_subscriptions_on_login(False, - '1234', - 'my-secret', - False, - None, - use_device_code=False, - allow_no_subscriptions=True, - subscription_finder=finder) - - # assert - self.assertEqual(1, len(result)) - self.assertEqual(result[0]['id'], self.tenant_id) - self.assertEqual(result[0]['state'], 'Enabled') - self.assertEqual(result[0]['tenantId'], self.tenant_id) - self.assertEqual(result[0]['name'], 'N/A(tenant level account)') - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_create_account_without_subscriptions_without_tenant(self, mock_auth_context): - cli = DummyCli() - finder = mock.MagicMock() - finder.find_through_interactive_flow.return_value = [] - storage_mock = {'subscriptions': []} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - - # action - result = profile.find_subscriptions_on_login(True, - '1234', - 'my-secret', - False, - None, - use_device_code=False, - allow_no_subscriptions=True, - subscription_finder=finder) - - # assert - self.assertTrue(0 == len(result)) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - def test_get_current_account_user(self, mock_read_cred_file): - cli = DummyCli() - # setup - mock_read_cred_file.return_value = [TestProfile.token_entry1] - - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, - [self.subscription1], - False) - profile._set_subscriptions(consolidated) - # action - user = profile.get_current_account_user() - - # verify - self.assertEqual(user, self.user1) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', return_value=None) - def test_create_token_cache(self, mock_read_file): - cli = DummyCli() - mock_read_file.return_value = [] - profile = Profile(cli_ctx=cli, use_global_creds_cache=False, async_persist=False) - cache = profile._creds_cache.adal_token_cache - self.assertFalse(cache.read_items()) - self.assertTrue(mock_read_file.called) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - def test_load_cached_tokens(self, mock_read_file): - cli = DummyCli() - mock_read_file.return_value = [TestProfile.token_entry1] - profile = Profile(cli_ctx=cli, use_global_creds_cache=False, async_persist=False) - cache = profile._creds_cache.adal_token_cache - matched = cache.find({ - "_authority": "https://login.microsoftonline.com/common", - "_clientId": "04b07795-8ddb-461a-bbee-02f9e1bf7b46", - "userId": self.user1 - }) - self.assertEqual(len(matched), 1) - self.assertEqual(matched[0]['accessToken'], self.raw_token1) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core._profile.CredsCache.retrieve_token_for_user', autospec=True) - def test_get_login_credentials(self, mock_get_token, mock_read_cred_file): - cli = DummyCli() - some_token_type = 'Bearer' - mock_read_cred_file.return_value = [TestProfile.token_entry1] - mock_get_token.return_value = (some_token_type, TestProfile.raw_token1) - # setup - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - test_subscription_id = '12345678-1bf0-4dda-aec3-cb9272f09590' - test_tenant_id = '12345678-38d6-4fb2-bad9-b7b93a3e1234' - test_subscription = SubscriptionStub('/subscriptions/{}'.format(test_subscription_id), - 'MSI-DEV-INC', self.state1, '12345678-38d6-4fb2-bad9-b7b93a3e1234') - consolidated = profile._normalize_properties(self.user1, - [test_subscription], - False) - profile._set_subscriptions(consolidated) - # action - cred, subscription_id, _ = profile.get_login_credentials() - - # verify - self.assertEqual(subscription_id, test_subscription_id) - - # verify the cred._tokenRetriever is a working lambda - token_type, token = cred._token_retriever() - self.assertEqual(token, self.raw_token1) - self.assertEqual(some_token_type, token_type) - mock_get_token.assert_called_once_with(mock.ANY, self.user1, test_tenant_id, - 'https://management.core.windows.net/') - self.assertEqual(mock_get_token.call_count, 1) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core._profile.CredsCache.retrieve_token_for_user', autospec=True) - def test_get_login_credentials_aux_subscriptions(self, mock_get_token, mock_read_cred_file): - cli = DummyCli() - raw_token2 = 'some...secrets2' - token_entry2 = { - "resource": "https://management.core.windows.net/", - "tokenType": "Bearer", - "_authority": "https://login.microsoftonline.com/common", - "accessToken": raw_token2, - } - some_token_type = 'Bearer' - mock_read_cred_file.return_value = [TestProfile.token_entry1, token_entry2] - mock_get_token.side_effect = [(some_token_type, TestProfile.raw_token1), (some_token_type, raw_token2)] - # setup - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - test_subscription_id = '12345678-1bf0-4dda-aec3-cb9272f09590' - test_subscription_id2 = '12345678-1bf0-4dda-aec3-cb9272f09591' - test_tenant_id = '12345678-38d6-4fb2-bad9-b7b93a3e1234' - test_tenant_id2 = '12345678-38d6-4fb2-bad9-b7b93a3e4321' - test_subscription = SubscriptionStub('/subscriptions/{}'.format(test_subscription_id), - 'MSI-DEV-INC', self.state1, test_tenant_id) - test_subscription2 = SubscriptionStub('/subscriptions/{}'.format(test_subscription_id2), - 'MSI-DEV-INC2', self.state1, test_tenant_id2) - consolidated = profile._normalize_properties(self.user1, - [test_subscription, test_subscription2], - False) - profile._set_subscriptions(consolidated) - # action - cred, subscription_id, _ = profile.get_login_credentials(subscription_id=test_subscription_id, - aux_subscriptions=[test_subscription_id2]) - - # verify - self.assertEqual(subscription_id, test_subscription_id) - - # verify the cred._tokenRetriever is a working lambda - token_type, token = cred._token_retriever() - self.assertEqual(token, self.raw_token1) - self.assertEqual(some_token_type, token_type) - - token2 = cred._external_tenant_token_retriever() - self.assertEqual(len(token2), 1) - self.assertEqual(token2[0][1], raw_token2) - - self.assertEqual(mock_get_token.call_count, 2) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core.adal_authentication.MSIAuthenticationWrapper', autospec=True) - def test_get_login_credentials_msi_system_assigned(self, mock_msi_auth, mock_read_cred_file): - mock_read_cred_file.return_value = [] - - # setup an existing msi subscription - profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}, use_global_creds_cache=False, - async_persist=False) - test_subscription_id = '12345678-1bf0-4dda-aec3-cb9272f09590' - test_tenant_id = '12345678-38d6-4fb2-bad9-b7b93a3e1234' - test_user = 'systemAssignedIdentity' - msi_subscription = SubscriptionStub('/subscriptions/' + test_subscription_id, 'MSI', self.state1, test_tenant_id) - consolidated = profile._normalize_properties(test_user, - [msi_subscription], - True) - profile._set_subscriptions(consolidated) - - mock_msi_auth.side_effect = MSRestAzureAuthStub - - # action - cred, subscription_id, _ = profile.get_login_credentials() - - # assert - self.assertEqual(subscription_id, test_subscription_id) - - # sniff test the msi_auth object - cred.set_token() - cred.token - self.assertTrue(cred.set_token_invoked_count) - self.assertTrue(cred.token_read_count) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core.adal_authentication.MSIAuthenticationWrapper', autospec=True) - def test_get_login_credentials_msi_user_assigned_with_client_id(self, mock_msi_auth, mock_read_cred_file): - mock_read_cred_file.return_value = [] - - # setup an existing msi subscription - profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}, use_global_creds_cache=False, - async_persist=False) - test_subscription_id = '12345678-1bf0-4dda-aec3-cb9272f09590' - test_tenant_id = '12345678-38d6-4fb2-bad9-b7b93a3e1234' - test_user = 'userAssignedIdentity' - test_client_id = '12345678-38d6-4fb2-bad9-b7b93a3e8888' - msi_subscription = SubscriptionStub('/subscriptions/' + test_subscription_id, 'MSIClient-{}'.format(test_client_id), self.state1, test_tenant_id) - consolidated = profile._normalize_properties(test_user, [msi_subscription], True) - profile._set_subscriptions(consolidated, secondary_key_name='name') - - mock_msi_auth.side_effect = MSRestAzureAuthStub - - # action - cred, subscription_id, _ = profile.get_login_credentials() - - # assert - self.assertEqual(subscription_id, test_subscription_id) - - # sniff test the msi_auth object - cred.set_token() - cred.token - self.assertTrue(cred.set_token_invoked_count) - self.assertTrue(cred.token_read_count) - self.assertTrue(cred.client_id, test_client_id) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core.adal_authentication.MSIAuthenticationWrapper', autospec=True) - def test_get_login_credentials_msi_user_assigned_with_object_id(self, mock_msi_auth, mock_read_cred_file): - mock_read_cred_file.return_value = [] - - # setup an existing msi subscription - profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}, use_global_creds_cache=False, - async_persist=False) - test_subscription_id = '12345678-1bf0-4dda-aec3-cb9272f09590' - test_object_id = '12345678-38d6-4fb2-bad9-b7b93a3e9999' - msi_subscription = SubscriptionStub('/subscriptions/12345678-1bf0-4dda-aec3-cb9272f09590', - 'MSIObject-{}'.format(test_object_id), - self.state1, '12345678-38d6-4fb2-bad9-b7b93a3e1234') - consolidated = profile._normalize_properties('userAssignedIdentity', [msi_subscription], True) - profile._set_subscriptions(consolidated, secondary_key_name='name') - - mock_msi_auth.side_effect = MSRestAzureAuthStub - - # action - cred, subscription_id, _ = profile.get_login_credentials() - - # assert - self.assertEqual(subscription_id, test_subscription_id) - - # sniff test the msi_auth object - cred.set_token() - cred.token - self.assertTrue(cred.set_token_invoked_count) - self.assertTrue(cred.token_read_count) - self.assertTrue(cred.object_id, test_object_id) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core.adal_authentication.MSIAuthenticationWrapper', autospec=True) - def test_get_login_credentials_msi_user_assigned_with_res_id(self, mock_msi_auth, mock_read_cred_file): - mock_read_cred_file.return_value = [] - - # setup an existing msi subscription - profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}, use_global_creds_cache=False, - async_persist=False) - test_subscription_id = '12345678-1bf0-4dda-aec3-cb9272f09590' - test_res_id = ('/subscriptions/{}/resourceGroups/r1/providers/Microsoft.ManagedIdentity/' - 'userAssignedIdentities/id1').format(test_subscription_id) - msi_subscription = SubscriptionStub('/subscriptions/{}'.format(test_subscription_id), - 'MSIResource-{}'.format(test_res_id), - self.state1, '12345678-38d6-4fb2-bad9-b7b93a3e1234') - consolidated = profile._normalize_properties('userAssignedIdentity', [msi_subscription], True) - profile._set_subscriptions(consolidated, secondary_key_name='name') - - mock_msi_auth.side_effect = MSRestAzureAuthStub - - # action - cred, subscription_id, _ = profile.get_login_credentials() - - # assert - self.assertEqual(subscription_id, test_subscription_id) - - # sniff test the msi_auth object - cred.set_token() - cred.token - self.assertTrue(cred.set_token_invoked_count) - self.assertTrue(cred.token_read_count) - self.assertTrue(cred.msi_res_id, test_res_id) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core._profile.CredsCache.retrieve_token_for_user', autospec=True) - def test_get_raw_token(self, mock_get_token, mock_read_cred_file): - cli = DummyCli() - some_token_type = 'Bearer' - mock_read_cred_file.return_value = [TestProfile.token_entry1] - mock_get_token.return_value = (some_token_type, TestProfile.raw_token1, - TestProfile.token_entry1) - # setup - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, - [self.subscription1], - False) - profile._set_subscriptions(consolidated) - # action - creds, sub, tenant = profile.get_raw_token(resource='https://foo') - - # verify - self.assertEqual(creds[0], self.token_entry1['tokenType']) - self.assertEqual(creds[1], self.raw_token1) - # the last in the tuple is the whole token entry which has several fields - self.assertEqual(creds[2]['expiresOn'], self.token_entry1['expiresOn']) - mock_get_token.assert_called_once_with(mock.ANY, self.user1, self.tenant_id, - 'https://foo') - self.assertEqual(mock_get_token.call_count, 1) - self.assertEqual(sub, '1') - self.assertEqual(tenant, self.tenant_id) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core._profile.CredsCache.retrieve_token_for_service_principal', autospec=True) - def test_get_raw_token_for_sp(self, mock_get_token, mock_read_cred_file): - cli = DummyCli() - some_token_type = 'Bearer' - mock_read_cred_file.return_value = [TestProfile.token_entry1] - mock_get_token.return_value = (some_token_type, TestProfile.raw_token1, - TestProfile.token_entry1) - # setup - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties('sp1', - [self.subscription1], - True) - profile._set_subscriptions(consolidated) - # action - creds, sub, tenant = profile.get_raw_token(resource='https://foo') - - # verify - self.assertEqual(creds[0], self.token_entry1['tokenType']) - self.assertEqual(creds[1], self.raw_token1) - # the last in the tuple is the whole token entry which has several fields - self.assertEqual(creds[2]['expiresOn'], self.token_entry1['expiresOn']) - mock_get_token.assert_called_once_with(mock.ANY, 'sp1', 'https://foo', self.tenant_id, False) - self.assertEqual(mock_get_token.call_count, 1) - self.assertEqual(sub, '1') - self.assertEqual(tenant, self.tenant_id) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core.adal_authentication.MSIAuthenticationWrapper', autospec=True) - def test_get_raw_token_msi_system_assigned(self, mock_msi_auth, mock_read_cred_file): - mock_read_cred_file.return_value = [] - - # setup an existing msi subscription - profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}, use_global_creds_cache=False, - async_persist=False) - test_subscription_id = '12345678-1bf0-4dda-aec3-cb9272f09590' - test_tenant_id = '12345678-38d6-4fb2-bad9-b7b93a3e1234' - test_user = 'systemAssignedIdentity' - msi_subscription = SubscriptionStub('/subscriptions/' + test_subscription_id, - 'MSI', self.state1, test_tenant_id) - consolidated = profile._normalize_properties(test_user, - [msi_subscription], - True) - profile._set_subscriptions(consolidated) - - mock_msi_auth.side_effect = MSRestAzureAuthStub - - # action - cred, subscription_id, _ = profile.get_raw_token(resource='http://test_resource') - - # assert - self.assertEqual(subscription_id, test_subscription_id) - self.assertEqual(cred[0], 'Bearer') - self.assertEqual(cred[1], TestProfile.test_msi_access_token) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core._profile.CredsCache.retrieve_token_for_user', autospec=True) - def test_get_login_credentials_for_graph_client(self, mock_get_token, mock_read_cred_file): - cli = DummyCli() - some_token_type = 'Bearer' - mock_read_cred_file.return_value = [TestProfile.token_entry1] - mock_get_token.return_value = (some_token_type, TestProfile.raw_token1) - # setup - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, [self.subscription1], - False) - profile._set_subscriptions(consolidated) - # action - cred, _, tenant_id = profile.get_login_credentials( - resource=cli.cloud.endpoints.active_directory_graph_resource_id) - _, _ = cred._token_retriever() - # verify - mock_get_token.assert_called_once_with(mock.ANY, self.user1, self.tenant_id, - 'https://graph.windows.net/') - self.assertEqual(tenant_id, self.tenant_id) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core._profile.CredsCache.retrieve_token_for_user', autospec=True) - def test_get_login_credentials_for_data_lake_client(self, mock_get_token, mock_read_cred_file): - cli = DummyCli() - some_token_type = 'Bearer' - mock_read_cred_file.return_value = [TestProfile.token_entry1] - mock_get_token.return_value = (some_token_type, TestProfile.raw_token1) - # setup - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, [self.subscription1], - False) - profile._set_subscriptions(consolidated) - # action - cred, _, tenant_id = profile.get_login_credentials( - resource=cli.cloud.endpoints.active_directory_data_lake_resource_id) - _, _ = cred._token_retriever() - # verify - mock_get_token.assert_called_once_with(mock.ANY, self.user1, self.tenant_id, - 'https://datalake.azure.net/') - self.assertEqual(tenant_id, self.tenant_id) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('azure.cli.core._profile.CredsCache.persist_cached_creds', autospec=True) - def test_logout(self, mock_persist_creds, mock_read_cred_file): - cli = DummyCli() - # setup - mock_read_cred_file.return_value = [TestProfile.token_entry1] - - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, - [self.subscription1], - False) - profile._set_subscriptions(consolidated) - self.assertEqual(1, len(storage_mock['subscriptions'])) - # action - profile.logout(self.user1) - - # verify - self.assertEqual(0, len(storage_mock['subscriptions'])) - self.assertEqual(mock_read_cred_file.call_count, 1) - self.assertEqual(mock_persist_creds.call_count, 1) - - @mock.patch('azure.cli.core._profile._delete_file', autospec=True) - def test_logout_all(self, mock_delete_cred_file): - cli = DummyCli() - # setup - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, - [self.subscription1], - False) - consolidated2 = profile._normalize_properties(self.user2, - [self.subscription2], - False) - profile._set_subscriptions(consolidated + consolidated2) - - self.assertEqual(2, len(storage_mock['subscriptions'])) - # action - profile.logout_all() - - # verify - self.assertEqual([], storage_mock['subscriptions']) - self.assertEqual(mock_delete_cred_file.call_count, 1) - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_find_subscriptions_thru_username_password(self, mock_auth_context): - cli = DummyCli() - mock_auth_context.acquire_token_with_username_password.return_value = self.token_entry1 - mock_auth_context.acquire_token.return_value = self.token_entry1 - mock_arm_client = mock.MagicMock() - mock_arm_client.tenants.list.return_value = [TenantStub(self.tenant_id)] - mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - mgmt_resource = 'https://management.core.windows.net/' - # action - subs = finder.find_from_user_account(self.user1, 'bar', None, mgmt_resource) - - # assert - self.assertEqual([self.subscription1], subs) - mock_auth_context.acquire_token_with_username_password.assert_called_once_with( - mgmt_resource, self.user1, 'bar', mock.ANY) - mock_auth_context.acquire_token.assert_called_once_with( - mgmt_resource, self.user1, mock.ANY) - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_find_subscriptions_thru_username_non_password(self, mock_auth_context): - cli = DummyCli() - mock_auth_context.acquire_token_with_username_password.return_value = None - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: None) - # action - subs = finder.find_from_user_account(self.user1, 'bar', None, 'http://goo-resource') - - # assert - self.assertEqual([], subs) - - @mock.patch('azure.cli.core.adal_authentication.MSIAuthenticationWrapper', autospec=True) - @mock.patch('azure.cli.core.profiles._shared.get_client_class', autospec=True) - @mock.patch('azure.cli.core._profile._get_cloud_console_token_endpoint', autospec=True) - @mock.patch('azure.cli.core._profile.SubscriptionFinder', autospec=True) - def test_find_subscriptions_in_cloud_console(self, mock_subscription_finder, mock_get_token_endpoint, - mock_get_client_class, mock_msi_auth): - - class SubscriptionFinderStub: - def find_from_raw_token(self, tenant, token): - # make sure the tenant and token args match 'TestProfile.test_msi_access_token' - if token != TestProfile.test_msi_access_token or tenant != '54826b22-38d6-4fb2-bad9-b7b93a3e9c5a': - raise AssertionError('find_from_raw_token was not invoked with expected tenant or token') - return [TestProfile.subscription1] - - mock_subscription_finder.return_value = SubscriptionFinderStub() - - mock_get_token_endpoint.return_value = "http://great_endpoint" - mock_msi_auth.return_value = MSRestAzureAuthStub() - - profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}, use_global_creds_cache=False, - async_persist=False) - - # action - subscriptions = profile.find_subscriptions_in_cloud_console() - - # assert - self.assertEqual(len(subscriptions), 1) - s = subscriptions[0] - self.assertEqual(s['user']['name'], 'admin3@AzureSDKTeam.onmicrosoft.com') - self.assertEqual(s['user']['cloudShellID'], True) - self.assertEqual(s['user']['type'], 'user') - self.assertEqual(s['name'], self.display_name1) - self.assertEqual(s['id'], self.id1.split('/')[-1]) - - @mock.patch('requests.get', autospec=True) - @mock.patch('azure.cli.core._profile.SubscriptionFinder._get_subscription_client_class', autospec=True) - def test_find_subscriptions_in_vm_with_msi_system_assigned(self, mock_get_client_class, mock_get): - - class ClientStub: - def __init__(self, *args, **kwargs): - self.subscriptions = mock.MagicMock() - self.subscriptions.list.return_value = [deepcopy(TestProfile.subscription1_raw)] - self.config = mock.MagicMock() - self._client = mock.MagicMock() - - mock_get_client_class.return_value = ClientStub - cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - - test_token_entry = { - 'token_type': 'Bearer', - 'access_token': TestProfile.test_msi_access_token - } - encoded_test_token = json.dumps(test_token_entry).encode() - good_response = mock.MagicMock() - good_response.status_code = 200 - good_response.content = encoded_test_token - mock_get.return_value = good_response - - subscriptions = profile.find_subscriptions_in_vm_with_msi() - - # assert - self.assertEqual(len(subscriptions), 1) - s = subscriptions[0] - self.assertEqual(s['user']['name'], 'systemAssignedIdentity') - self.assertEqual(s['user']['type'], 'servicePrincipal') - self.assertEqual(s['user']['assignedIdentityInfo'], 'MSI') - self.assertEqual(s['name'], self.display_name1) - self.assertEqual(s['id'], self.id1.split('/')[-1]) - self.assertEqual(s['tenantId'], '54826b22-38d6-4fb2-bad9-b7b93a3e9c5a') - - @mock.patch('requests.get', autospec=True) - @mock.patch('azure.cli.core._profile.SubscriptionFinder._get_subscription_client_class', autospec=True) - def test_find_subscriptions_in_vm_with_msi_no_subscriptions(self, mock_get_client_class, mock_get): - - class ClientStub: - def __init__(self, *args, **kwargs): - self.subscriptions = mock.MagicMock() - self.subscriptions.list.return_value = [] - self.config = mock.MagicMock() - self._client = mock.MagicMock() - - mock_get_client_class.return_value = ClientStub - cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - - test_token_entry = { - 'token_type': 'Bearer', - 'access_token': TestProfile.test_msi_access_token - } - encoded_test_token = json.dumps(test_token_entry).encode() - good_response = mock.MagicMock() - good_response.status_code = 200 - good_response.content = encoded_test_token - mock_get.return_value = good_response - - subscriptions = profile.find_subscriptions_in_vm_with_msi(allow_no_subscriptions=True) - - # assert - self.assertEqual(len(subscriptions), 1) - s = subscriptions[0] - self.assertEqual(s['user']['name'], 'systemAssignedIdentity') - self.assertEqual(s['user']['type'], 'servicePrincipal') - self.assertEqual(s['user']['assignedIdentityInfo'], 'MSI') - self.assertEqual(s['name'], 'N/A(tenant level account)') - self.assertEqual(s['id'], self.test_msi_tenant) - self.assertEqual(s['tenantId'], self.test_msi_tenant) - - @mock.patch('requests.get', autospec=True) - @mock.patch('azure.cli.core._profile.SubscriptionFinder._get_subscription_client_class', autospec=True) - def test_find_subscriptions_in_vm_with_msi_user_assigned_with_client_id(self, mock_get_client_class, mock_get): - - class ClientStub: - def __init__(self, *args, **kwargs): - self.subscriptions = mock.MagicMock() - self.subscriptions.list.return_value = [deepcopy(TestProfile.subscription1_raw)] - self.config = mock.MagicMock() - self._client = mock.MagicMock() - - mock_get_client_class.return_value = ClientStub - cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - - test_token_entry = { - 'token_type': 'Bearer', - 'access_token': TestProfile.test_msi_access_token - } - test_client_id = '54826b22-38d6-4fb2-bad9-b7b93a3e9999' - encoded_test_token = json.dumps(test_token_entry).encode() - good_response = mock.MagicMock() - good_response.status_code = 200 - good_response.content = encoded_test_token - mock_get.return_value = good_response - - subscriptions = profile.find_subscriptions_in_vm_with_msi(identity_id=test_client_id) - - # assert - self.assertEqual(len(subscriptions), 1) - s = subscriptions[0] - self.assertEqual(s['user']['name'], 'userAssignedIdentity') - self.assertEqual(s['user']['type'], 'servicePrincipal') - self.assertEqual(s['name'], self.display_name1) - self.assertEqual(s['user']['assignedIdentityInfo'], 'MSIClient-{}'.format(test_client_id)) - self.assertEqual(s['id'], self.id1.split('/')[-1]) - self.assertEqual(s['tenantId'], '54826b22-38d6-4fb2-bad9-b7b93a3e9c5a') - - @mock.patch('azure.cli.core.adal_authentication.MSIAuthenticationWrapper', autospec=True) - @mock.patch('azure.cli.core.profiles._shared.get_client_class', autospec=True) - @mock.patch('azure.cli.core._profile.SubscriptionFinder', autospec=True) - def test_find_subscriptions_in_vm_with_msi_user_assigned_with_object_id(self, mock_subscription_finder, mock_get_client_class, - mock_msi_auth): - from azure.cli.core.azclierror import AzureResponseError - - class SubscriptionFinderStub: - def find_from_raw_token(self, tenant, token): - # make sure the tenant and token args match 'TestProfile.test_msi_access_token' - if token != TestProfile.test_msi_access_token or tenant != '54826b22-38d6-4fb2-bad9-b7b93a3e9c5a': - raise AssertionError('find_from_raw_token was not invoked with expected tenant or token') - return [TestProfile.subscription1] - - class AuthStub: - def __init__(self, **kwargs): - self.token = None - self.client_id = kwargs.get('client_id') - self.object_id = kwargs.get('object_id') - # since msrestazure 0.4.34, set_token in init - self.set_token() - - def set_token(self): - # here we will reject the 1st sniffing of trying with client_id and then acccept the 2nd - if self.object_id: - self.token = { - 'token_type': 'Bearer', - 'access_token': TestProfile.test_msi_access_token - } - else: - raise AzureResponseError('Failed to connect to MSI. Please make sure MSI is configured correctly.\n' - 'Get Token request returned http error: 400, reason: Bad Request') - - profile = Profile(cli_ctx=DummyCli(), storage={'subscriptions': None}, use_global_creds_cache=False, - async_persist=False) - - mock_subscription_finder.return_value = SubscriptionFinderStub() - - mock_msi_auth.side_effect = AuthStub - test_object_id = '54826b22-38d6-4fb2-bad9-b7b93a3e9999' - - # action - subscriptions = profile.find_subscriptions_in_vm_with_msi(identity_id=test_object_id) - - # assert - self.assertEqual(subscriptions[0]['user']['assignedIdentityInfo'], 'MSIObject-{}'.format(test_object_id)) - - @mock.patch('requests.get', autospec=True) - @mock.patch('azure.cli.core._profile.SubscriptionFinder._get_subscription_client_class', autospec=True) - def test_find_subscriptions_in_vm_with_msi_user_assigned_with_res_id(self, mock_get_client_class, mock_get): - - class ClientStub: - def __init__(self, *args, **kwargs): - self.subscriptions = mock.MagicMock() - self.subscriptions.list.return_value = [deepcopy(TestProfile.subscription1_raw)] - self.config = mock.MagicMock() - self._client = mock.MagicMock() - - mock_get_client_class.return_value = ClientStub - cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - - test_token_entry = { - 'token_type': 'Bearer', - 'access_token': TestProfile.test_msi_access_token - } - test_res_id = ('/subscriptions/0b1f6471-1bf0-4dda-aec3-cb9272f09590/resourcegroups/g1/' - 'providers/Microsoft.ManagedIdentity/userAssignedIdentities/id1') - - encoded_test_token = json.dumps(test_token_entry).encode() - good_response = mock.MagicMock() - good_response.status_code = 200 - good_response.content = encoded_test_token - mock_get.return_value = good_response - - subscriptions = profile.find_subscriptions_in_vm_with_msi(identity_id=test_res_id) - - # assert - self.assertEqual(subscriptions[0]['user']['assignedIdentityInfo'], 'MSIResource-{}'.format(test_res_id)) - - @mock.patch('adal.AuthenticationContext.acquire_token_with_username_password', autospec=True) - @mock.patch('adal.AuthenticationContext.acquire_token', autospec=True) - def test_find_subscriptions_thru_username_password_adfs(self, mock_acquire_token, - mock_acquire_token_username_password): - cli = DummyCli() - TEST_ADFS_AUTH_URL = 'https://adfs.local.azurestack.external/adfs' - - def test_acquire_token(self, resource, username, password, client_id): - global acquire_token_invoked - acquire_token_invoked = True - if (self.authority.url == TEST_ADFS_AUTH_URL and self.authority.is_adfs_authority): - return TestProfile.token_entry1 - else: - raise ValueError('AuthContext was not initialized correctly for ADFS') - - mock_acquire_token_username_password.side_effect = test_acquire_token - mock_acquire_token.return_value = self.token_entry1 - mock_arm_client = mock.MagicMock() - mock_arm_client.tenants.list.return_value = [TenantStub(self.tenant_id)] - mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - cli.cloud.endpoints.active_directory = TEST_ADFS_AUTH_URL - finder = SubscriptionFinder(cli, _AUTH_CTX_FACTORY, None, lambda _: mock_arm_client) - mgmt_resource = 'https://management.core.windows.net/' - # action - subs = finder.find_from_user_account(self.user1, 'bar', None, mgmt_resource) - - # assert - self.assertEqual([self.subscription1], subs) - self.assertTrue(acquire_token_invoked) - - @mock.patch('adal.AuthenticationContext', autospec=True) - @mock.patch('azure.cli.core._profile.logger', autospec=True) - def test_find_subscriptions_thru_username_password_with_account_disabled(self, mock_logger, mock_auth_context): - cli = DummyCli() - mock_auth_context.acquire_token_with_username_password.return_value = self.token_entry1 - mock_auth_context.acquire_token.side_effect = AdalError('Account is disabled') - mock_arm_client = mock.MagicMock() - mock_arm_client.tenants.list.return_value = [TenantStub(self.tenant_id)] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - mgmt_resource = 'https://management.core.windows.net/' - # action - subs = finder.find_from_user_account(self.user1, 'bar', None, mgmt_resource) - - # assert - self.assertEqual([], subs) - mock_logger.warning.assert_called_once_with(mock.ANY, mock.ANY, mock.ANY) - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_find_subscriptions_from_particular_tenent(self, mock_auth_context): - def just_raise(ex): - raise ex - - cli = DummyCli() - mock_arm_client = mock.MagicMock() - mock_arm_client.tenants.list.side_effect = lambda: just_raise( - ValueError("'tenants.list' should not occur")) - mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - # action - subs = finder.find_from_user_account(self.user1, 'bar', self.tenant_id, 'http://someresource') - - # assert - self.assertEqual([self.subscription1], subs) - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_find_subscriptions_through_device_code_flow(self, mock_auth_context): - cli = DummyCli() - test_nonsense_code = {'message': 'magic code for you'} - mock_auth_context.acquire_user_code.return_value = test_nonsense_code - mock_auth_context.acquire_token_with_device_code.return_value = self.token_entry1 - mock_arm_client = mock.MagicMock() - mock_arm_client.tenants.list.return_value = [TenantStub(self.tenant_id)] - mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - mgmt_resource = 'https://management.core.windows.net/' - # action - subs = finder.find_through_interactive_flow(None, mgmt_resource) - - # assert - self.assertEqual([self.subscription1], subs) - mock_auth_context.acquire_user_code.assert_called_once_with( - mgmt_resource, mock.ANY) - mock_auth_context.acquire_token_with_device_code.assert_called_once_with( - mgmt_resource, test_nonsense_code, mock.ANY) - mock_auth_context.acquire_token.assert_called_once_with( - mgmt_resource, self.user1, mock.ANY) - - @mock.patch('adal.AuthenticationContext', autospec=True) - @mock.patch('azure.cli.core._profile._get_authorization_code', autospec=True) - def test_find_subscriptions_through_authorization_code_flow(self, _get_authorization_code_mock, mock_auth_context): - import adal - cli = DummyCli() - mock_arm_client = mock.MagicMock() - mock_arm_client.tenants.list.return_value = [TenantStub(self.tenant_id)] - mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - token_cache = adal.TokenCache() - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, token_cache, lambda _: mock_arm_client) - _get_authorization_code_mock.return_value = { - 'code': 'code1', - 'reply_url': 'http://localhost:8888' - } - mgmt_resource = 'https://management.core.windows.net/' - temp_token_cache = mock.MagicMock() - type(mock_auth_context).cache = temp_token_cache - temp_token_cache.read_items.return_value = [] - mock_auth_context.acquire_token_with_authorization_code.return_value = self.token_entry1 - - # action - subs = finder.find_through_authorization_code_flow(None, mgmt_resource, 'https:/some_aad_point/common') - - # assert - self.assertEqual([self.subscription1], subs) - mock_auth_context.acquire_token.assert_called_once_with(mgmt_resource, self.user1, mock.ANY) - mock_auth_context.acquire_token_with_authorization_code.assert_called_once_with('code1', - 'http://localhost:8888', - mgmt_resource, mock.ANY, - None) - _get_authorization_code_mock.assert_called_once_with(mgmt_resource, 'https:/some_aad_point/common') - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_find_subscriptions_interactive_from_particular_tenent(self, mock_auth_context): - def just_raise(ex): - raise ex - - cli = DummyCli() - mock_arm_client = mock.MagicMock() - mock_arm_client.tenants.list.side_effect = lambda: just_raise( - ValueError("'tenants.list' should not occur")) - mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - # action - subs = finder.find_through_interactive_flow(self.tenant_id, 'http://someresource') - - # assert - self.assertEqual([self.subscription1], subs) - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_find_subscriptions_from_service_principal_id(self, mock_auth_context): - cli = DummyCli() - mock_auth_context.acquire_token_with_client_credentials.return_value = self.token_entry1 - mock_arm_client = mock.MagicMock() - mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - mgmt_resource = 'https://management.core.windows.net/' - # action - subs = finder.find_from_service_principal_id('my app', ServicePrincipalAuth('my secret'), - self.tenant_id, mgmt_resource) - - # assert - self.assertEqual([self.subscription1], subs) - mock_arm_client.tenants.list.assert_not_called() - mock_auth_context.acquire_token.assert_not_called() - mock_auth_context.acquire_token_with_client_credentials.assert_called_once_with( - mgmt_resource, 'my app', 'my secret') - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_find_subscriptions_from_service_principal_using_cert(self, mock_auth_context): - cli = DummyCli() - mock_auth_context.acquire_token_with_client_certificate.return_value = self.token_entry1 - mock_arm_client = mock.MagicMock() - mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - mgmt_resource = 'https://management.core.windows.net/' - - curr_dir = os.path.dirname(os.path.realpath(__file__)) - test_cert_file = os.path.join(curr_dir, 'sp_cert.pem') - - # action - subs = finder.find_from_service_principal_id('my app', ServicePrincipalAuth(test_cert_file), - self.tenant_id, mgmt_resource) - - # assert - self.assertEqual([self.subscription1], subs) - mock_arm_client.tenants.list.assert_not_called() - mock_auth_context.acquire_token.assert_not_called() - mock_auth_context.acquire_token_with_client_certificate.assert_called_once_with( - mgmt_resource, 'my app', mock.ANY, mock.ANY, None) - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_find_subscriptions_from_service_principal_using_cert_sn_issuer(self, mock_auth_context): - cli = DummyCli() - mock_auth_context.acquire_token_with_client_certificate.return_value = self.token_entry1 - mock_arm_client = mock.MagicMock() - mock_arm_client.subscriptions.list.return_value = [deepcopy(self.subscription1_raw)] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - mgmt_resource = 'https://management.core.windows.net/' - - curr_dir = os.path.dirname(os.path.realpath(__file__)) - test_cert_file = os.path.join(curr_dir, 'sp_cert.pem') - with open(test_cert_file) as cert_file: - cert_file_string = cert_file.read() - match = re.search(r'\-+BEGIN CERTIFICATE.+\-+(?P[^-]+)\-+END CERTIFICATE.+\-+', - cert_file_string, re.I) - public_certificate = match.group('public').strip() - # action - subs = finder.find_from_service_principal_id('my app', ServicePrincipalAuth(test_cert_file, use_cert_sn_issuer=True), - self.tenant_id, mgmt_resource) - - # assert - self.assertEqual([self.subscription1], subs) - mock_arm_client.tenants.list.assert_not_called() - mock_auth_context.acquire_token.assert_not_called() - mock_auth_context.acquire_token_with_client_certificate.assert_called_once_with( - mgmt_resource, 'my app', mock.ANY, mock.ANY, public_certificate) - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_refresh_accounts_one_user_account(self, mock_auth_context): - cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, deepcopy([self.subscription1]), False) - profile._set_subscriptions(consolidated) - mock_auth_context.acquire_token_with_username_password.return_value = self.token_entry1 - mock_auth_context.acquire_token.return_value = self.token_entry1 - mock_arm_client = mock.MagicMock() - mock_arm_client.tenants.list.return_value = [TenantStub(self.tenant_id)] - mock_arm_client.subscriptions.list.return_value = deepcopy([self.subscription1_raw, self.subscription2_raw]) - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - # action - profile.refresh_accounts(finder) - - # assert - result = storage_mock['subscriptions'] - self.assertEqual(2, len(result)) - self.assertEqual(self.id1.split('/')[-1], result[0]['id']) - self.assertEqual(self.id2.split('/')[-1], result[1]['id']) - self.assertTrue(result[0]['isDefault']) - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_refresh_accounts_one_user_account_one_sp_account(self, mock_auth_context): - cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - sp_subscription1 = SubscriptionStub('sp-sub/3', 'foo-subname', self.state1, 'foo_tenant.onmicrosoft.com') - consolidated = profile._normalize_properties(self.user1, deepcopy([self.subscription1]), False) - consolidated += profile._normalize_properties('http://foo', [sp_subscription1], True) - profile._set_subscriptions(consolidated) - mock_auth_context.acquire_token_with_username_password.return_value = self.token_entry1 - mock_auth_context.acquire_token.return_value = self.token_entry1 - mock_auth_context.acquire_token_with_client_credentials.return_value = self.token_entry1 - mock_arm_client = mock.MagicMock() - mock_arm_client.tenants.list.return_value = [TenantStub(self.tenant_id)] - mock_arm_client.subscriptions.list.side_effect = deepcopy([[self.subscription1], [self.subscription2, sp_subscription1]]) - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - profile._creds_cache.retrieve_cred_for_service_principal = lambda _: 'verySecret' - profile._creds_cache.flush_to_disk = lambda _: '' - # action - profile.refresh_accounts(finder) - - # assert - result = storage_mock['subscriptions'] - self.assertEqual(3, len(result)) - self.assertEqual(self.id1.split('/')[-1], result[0]['id']) - self.assertEqual(self.id2.split('/')[-1], result[1]['id']) - self.assertEqual('3', result[2]['id']) - self.assertTrue(result[0]['isDefault']) - - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_refresh_accounts_with_nothing(self, mock_auth_context): - cli = DummyCli() - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - consolidated = profile._normalize_properties(self.user1, deepcopy([self.subscription1]), False) - profile._set_subscriptions(consolidated) - mock_auth_context.acquire_token_with_username_password.return_value = self.token_entry1 - mock_auth_context.acquire_token.return_value = self.token_entry1 - mock_arm_client = mock.MagicMock() - mock_arm_client.tenants.list.return_value = [TenantStub(self.tenant_id)] - mock_arm_client.subscriptions.list.return_value = [] - finder = SubscriptionFinder(cli, lambda _, _1, _2: mock_auth_context, None, lambda _: mock_arm_client) - # action - profile.refresh_accounts(finder) - - # assert - result = storage_mock['subscriptions'] - self.assertEqual(0, len(result)) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - def test_credscache_load_tokens_and_sp_creds_with_secret(self, mock_read_file): - cli = DummyCli() - test_sp = { - "servicePrincipalId": "myapp", - "servicePrincipalTenant": "mytenant", - "accessToken": "Secret" - } - mock_read_file.return_value = [self.token_entry1, test_sp] - - # action - creds_cache = CredsCache(cli, async_persist=False) - - # assert - token_entries = [entry for _, entry in creds_cache.load_adal_token_cache().read_items()] - self.assertEqual(token_entries, [self.token_entry1]) - self.assertEqual(creds_cache._service_principal_creds, [test_sp]) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - def test_credscache_load_tokens_and_sp_creds_with_cert(self, mock_read_file): - cli = DummyCli() - test_sp = { - "servicePrincipalId": "myapp", - "servicePrincipalTenant": "mytenant", - "certificateFile": 'junkcert.pem' - } - mock_read_file.return_value = [test_sp] - - # action - creds_cache = CredsCache(cli, async_persist=False) - creds_cache.load_adal_token_cache() - - # assert - self.assertEqual(creds_cache._service_principal_creds, [test_sp]) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - def test_credscache_retrieve_sp_cred(self, mock_read_file): - cli = DummyCli() - test_cache = [ - { - "servicePrincipalId": "myapp", - "servicePrincipalTenant": "mytenant", - "accessToken": "Secret" - }, - { - "servicePrincipalId": "myapp2", - "servicePrincipalTenant": "mytenant", - "certificateFile": 'junkcert.pem' - } - ] - mock_read_file.return_value = test_cache - - # action - creds_cache = CredsCache(cli, async_persist=False) - creds_cache.load_adal_token_cache() - - # assert - self.assertEqual(creds_cache.retrieve_cred_for_service_principal('myapp'), 'Secret') - self.assertEqual(creds_cache.retrieve_cred_for_service_principal('myapp2'), 'junkcert.pem') - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('os.fdopen', autospec=True) - @mock.patch('os.open', autospec=True) - def test_credscache_add_new_sp_creds(self, _, mock_open_for_write, mock_read_file): - cli = DummyCli() - test_sp = { - "servicePrincipalId": "myapp", - "servicePrincipalTenant": "mytenant", - "accessToken": "Secret" - } - test_sp2 = { - "servicePrincipalId": "myapp2", - "servicePrincipalTenant": "mytenant2", - "accessToken": "Secret2" - } - mock_open_for_write.return_value = FileHandleStub() - mock_read_file.return_value = [self.token_entry1, test_sp] - creds_cache = CredsCache(cli, async_persist=False) - - # action - creds_cache.save_service_principal_cred(test_sp2) - - # assert - token_entries = [e for _, e in creds_cache.adal_token_cache.read_items()] # noqa: F812 - self.assertEqual(token_entries, [self.token_entry1]) - self.assertEqual(creds_cache._service_principal_creds, [test_sp, test_sp2]) - mock_open_for_write.assert_called_with(mock.ANY, 'w+') - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('os.fdopen', autospec=True) - @mock.patch('os.open', autospec=True) - def test_credscache_add_preexisting_sp_creds(self, _, mock_open_for_write, mock_read_file): - cli = DummyCli() - test_sp = { - "servicePrincipalId": "myapp", - "servicePrincipalTenant": "mytenant", - "accessToken": "Secret" - } - mock_open_for_write.return_value = FileHandleStub() - mock_read_file.return_value = [test_sp] - creds_cache = CredsCache(cli, async_persist=False) - - # action - creds_cache.save_service_principal_cred(test_sp) - - # assert - self.assertEqual(creds_cache._service_principal_creds, [test_sp]) - self.assertFalse(mock_open_for_write.called) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('os.fdopen', autospec=True) - @mock.patch('os.open', autospec=True) - def test_credscache_add_preexisting_sp_new_secret(self, _, mock_open_for_write, mock_read_file): - cli = DummyCli() - test_sp = { - "servicePrincipalId": "myapp", - "servicePrincipalTenant": "mytenant", - "accessToken": "Secret" - } - mock_open_for_write.return_value = FileHandleStub() - mock_read_file.return_value = [test_sp] - creds_cache = CredsCache(cli, async_persist=False) - - new_creds = test_sp.copy() - new_creds['accessToken'] = 'Secret2' - # action - creds_cache.save_service_principal_cred(new_creds) - - # assert - self.assertEqual(creds_cache._service_principal_creds, [new_creds]) - self.assertTrue(mock_open_for_write.called) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('os.fdopen', autospec=True) - @mock.patch('os.open', autospec=True) - def test_credscache_match_service_principal_correctly(self, _, mock_open_for_write, mock_read_file): - cli = DummyCli() - test_sp = { - "servicePrincipalId": "myapp", - "servicePrincipalTenant": "mytenant", - "accessToken": "Secret" - } - mock_open_for_write.return_value = FileHandleStub() - mock_read_file.return_value = [test_sp] - factory = mock.MagicMock() - factory.side_effect = ValueError('SP was found') - creds_cache = CredsCache(cli, factory, async_persist=False) - - # action and verify(we plant an exception to throw after the SP was found; so if the exception is thrown, - # we know the matching did go through) - self.assertRaises(ValueError, creds_cache.retrieve_token_for_service_principal, 'myapp', 'resource1', 'mytenant', False) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('os.fdopen', autospec=True) - @mock.patch('os.open', autospec=True) - def test_credscache_remove_creds(self, _, mock_open_for_write, mock_read_file): - cli = DummyCli() - test_sp = { - "servicePrincipalId": "myapp", - "servicePrincipalTenant": "mytenant", - "accessToken": "Secret" - } - mock_open_for_write.return_value = FileHandleStub() - mock_read_file.return_value = [self.token_entry1, test_sp] - creds_cache = CredsCache(cli, async_persist=False) - - # action #1, logout a user - creds_cache.remove_cached_creds(self.user1) - - # assert #1 - token_entries = [e for _, e in creds_cache.adal_token_cache.read_items()] # noqa: F812 - self.assertEqual(token_entries, []) - - # action #2 logout a service principal - creds_cache.remove_cached_creds('myapp') - - # assert #2 - self.assertEqual(creds_cache._service_principal_creds, []) - - mock_open_for_write.assert_called_with(mock.ANY, 'w+') - self.assertEqual(mock_open_for_write.call_count, 2) - - @mock.patch('azure.cli.core._profile._load_tokens_from_file', autospec=True) - @mock.patch('os.fdopen', autospec=True) - @mock.patch('os.open', autospec=True) - @mock.patch('adal.AuthenticationContext', autospec=True) - def test_credscache_new_token_added_by_adal(self, mock_adal_auth_context, _, mock_open_for_write, mock_read_file): # pylint: disable=line-too-long - cli = DummyCli() - token_entry2 = { - "accessToken": "new token", - "tokenType": "Bearer", - "userId": self.user1 - } - - def acquire_token_side_effect(*args): # pylint: disable=unused-argument - creds_cache.adal_token_cache.has_state_changed = True - return token_entry2 - - def get_auth_context(_, authority, **kwargs): # pylint: disable=unused-argument - mock_adal_auth_context.cache = kwargs['cache'] - return mock_adal_auth_context - - mock_adal_auth_context.acquire_token.side_effect = acquire_token_side_effect - mock_open_for_write.return_value = FileHandleStub() - mock_read_file.return_value = [self.token_entry1] - creds_cache = CredsCache(cli, auth_ctx_factory=get_auth_context, async_persist=False) - - # action - mgmt_resource = 'https://management.core.windows.net/' - token_type, token, _ = creds_cache.retrieve_token_for_user(self.user1, self.tenant_id, - mgmt_resource) - mock_adal_auth_context.acquire_token.assert_called_once_with( - 'https://management.core.windows.net/', - self.user1, - mock.ANY) - - # assert - mock_open_for_write.assert_called_with(mock.ANY, 'w+') - self.assertEqual(token, 'new token') - self.assertEqual(token_type, token_entry2['tokenType']) - - @mock.patch('azure.cli.core._profile.get_file_json', autospec=True) - def test_credscache_good_error_on_file_corruption(self, mock_read_file): - mock_read_file.side_effect = ValueError('a bad error for you') - cli = DummyCli() - - # action - creds_cache = CredsCache(cli, async_persist=False) - - # assert - with self.assertRaises(CLIError) as context: - creds_cache.load_adal_token_cache() - - self.assertTrue(re.findall(r'bad error for you', str(context.exception))) - - def test_service_principal_auth_client_secret(self): - sp_auth = ServicePrincipalAuth('verySecret!') - result = sp_auth.get_entry_to_persist('sp_id1', 'tenant1') - self.assertEqual(result, { - 'servicePrincipalId': 'sp_id1', - 'servicePrincipalTenant': 'tenant1', - 'accessToken': 'verySecret!' - }) - - def test_service_principal_auth_client_cert(self): - curr_dir = os.path.dirname(os.path.realpath(__file__)) - test_cert_file = os.path.join(curr_dir, 'sp_cert.pem') - sp_auth = ServicePrincipalAuth(test_cert_file) - - result = sp_auth.get_entry_to_persist('sp_id1', 'tenant1') - self.assertEqual(result, { - 'servicePrincipalId': 'sp_id1', - 'servicePrincipalTenant': 'tenant1', - 'certificateFile': test_cert_file, - 'thumbprint': 'F0:6A:53:84:8B:BE:71:4A:42:90:D6:9D:33:52:79:C1:D0:10:73:FD' - }) - - def test_detect_adfs_authority_url(self): - cli = DummyCli() - adfs_url_1 = 'https://adfs.redmond.ext-u15f2402.masd.stbtest.microsoft.com/adfs/' - cli.cloud.endpoints.active_directory = adfs_url_1 - storage_mock = {'subscriptions': None} - profile = Profile(cli_ctx=cli, storage=storage_mock, use_global_creds_cache=False, async_persist=False) - - # test w/ trailing slash - r = profile.auth_ctx_factory(cli, 'common', None) - self.assertEqual(r.authority.url, adfs_url_1.rstrip('/')) - - # test w/o trailing slash - adfs_url_2 = 'https://adfs.redmond.ext-u15f2402.masd.stbtest.microsoft.com/adfs' - cli.cloud.endpoints.active_directory = adfs_url_2 - r = profile.auth_ctx_factory(cli, 'common', None) - self.assertEqual(r.authority.url, adfs_url_2) - - # test w/ regular aad - aad_url = 'https://login.microsoftonline.com' - cli.cloud.endpoints.active_directory = aad_url - r = profile.auth_ctx_factory(cli, 'common', None) - self.assertEqual(r.authority.url, aad_url + '/common') - - -class FileHandleStub(object): # pylint: disable=too-few-public-methods - - def write(self, content): - pass - - def __enter__(self): - return self - - def __exit__(self, _2, _3, _4): - pass - - -class SubscriptionStub(Subscription): # pylint: disable=too-few-public-methods - - def __init__(self, id, display_name, state, tenant_id=None): # pylint: disable=redefined-builtin - policies = SubscriptionPolicies() - policies.spending_limit = SpendingLimit.current_period_off - policies.quota_id = 'some quota' - super(SubscriptionStub, self).__init__(subscription_policies=policies, authorization_source='some_authorization_source') - self.id = id - self.subscription_id = id.split('/')[1] - self.display_name = display_name - self.state = state - # for a SDK Subscription, tenant_id isn't present - # for a _find_using_specific_tenant Subscription, tenant_id means token tenant id - if tenant_id: - self.tenant_id = tenant_id - - -class TenantStub(object): # pylint: disable=too-few-public-methods - - def __init__(self, tenant_id): - self.tenant_id = tenant_id - - -class MSRestAzureAuthStub: - def __init__(self, *args, **kwargs): - self._token = { - 'token_type': 'Bearer', - 'access_token': TestProfile.test_msi_access_token - } - self.set_token_invoked_count = 0 - self.token_read_count = 0 - self.client_id = kwargs.get('client_id') - self.object_id = kwargs.get('object_id') - self.msi_res_id = kwargs.get('msi_res_id') - - def set_token(self): - self.set_token_invoked_count += 1 - - @property - def token(self): - self.token_read_count += 1 - return self._token - - @token.setter - def token(self, value): - self._token = value - - -if __name__ == '__main__': - unittest.main() diff --git a/src/azure-cli-core/azure/cli/core/tests/test_util.py b/src/azure-cli-core/azure/cli/core/tests/test_util.py index 463ded643e0..66aedfe2b8d 100644 --- a/src/azure-cli-core/azure/cli/core/tests/test_util.py +++ b/src/azure-cli-core/azure/cli/core/tests/test_util.py @@ -380,47 +380,6 @@ def test_send_raw_requests(self, send_mock, get_raw_token_mock): request = send_mock.call_args[0][1] self.assertEqual(request.headers['User-Agent'], get_az_rest_user_agent() + ' env-ua ARG-UA') - def test_scopes_to_resource(self): - from azure.cli.core.util import scopes_to_resource - # scopes as a list - self.assertEqual(scopes_to_resource(['https://management.core.windows.net//.default']), - 'https://management.core.windows.net/') - # scopes as a tuple - self.assertEqual(scopes_to_resource(('https://storage.azure.com/.default',)), - 'https://storage.azure.com') - - # resource with trailing slash - self.assertEqual(scopes_to_resource(('https://management.azure.com//.default',)), - 'https://management.azure.com/') - self.assertEqual(scopes_to_resource(['https://datalake.azure.net//.default']), - 'https://datalake.azure.net/') - - # resource without trailing slash - self.assertEqual(scopes_to_resource(('https://managedhsm.azure.com/.default',)), - 'https://managedhsm.azure.com') - - # VM SSH - self.assertEqual(scopes_to_resource(["https://pas.windows.net/CheckMyAccess/Linux/.default"]), - 'https://pas.windows.net/CheckMyAccess/Linux') - self.assertEqual(scopes_to_resource(["https://pas.windows.net/CheckMyAccess/Linux/user_impersonation"]), - 'https://pas.windows.net/CheckMyAccess/Linux') - - def test_resource_to_scopes(self): - from azure.cli.core.util import resource_to_scopes - # resource converted to a scopes list - self.assertEqual(resource_to_scopes('https://management.core.windows.net/'), - ['https://management.core.windows.net//.default']) - - # resource with trailing slash - self.assertEqual(resource_to_scopes('https://management.azure.com/'), - ['https://management.azure.com//.default']) - self.assertEqual(resource_to_scopes('https://datalake.azure.net/'), - ['https://datalake.azure.net//.default']) - - # resource without trailing slash - self.assertEqual(resource_to_scopes('https://managedhsm.azure.com'), - ['https://managedhsm.azure.com/.default']) - @mock.patch("psutil.Process") def test_get_parent_proc_name(self, mock_process_type): process = mock_process_type.return_value diff --git a/src/azure-cli-core/azure/cli/core/util.py b/src/azure-cli-core/azure/cli/core/util.py index 13dccd8645e..524c45cf85e 100644 --- a/src/azure-cli-core/azure/cli/core/util.py +++ b/src/azure-cli-core/azure/cli/core/util.py @@ -1191,43 +1191,6 @@ def handle_version_update(): logger.warning(ex) -def resource_to_scopes(resource): - """Convert the ADAL resource ID to MSAL scopes by appending the /.default suffix and return a list. - For example: - 'https://management.core.windows.net/' -> ['https://management.core.windows.net//.default'] - 'https://managedhsm.azure.com' -> ['https://managedhsm.azure.com/.default'] - - :param resource: The ADAL resource ID - :return: A list of scopes - """ - # https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-permissions-and-consent#trailing-slash-and-default - # We should not trim the trailing slash, like in https://management.azure.com/ - # In other word, the trailing slash should be preserved and scope should be https://management.azure.com//.default - scope = resource + '/.default' - return [scope] - - -def scopes_to_resource(scopes): - """Convert MSAL scopes to ADAL resource by stripping the /.default suffix and return a str. - For example: - ['https://management.core.windows.net//.default'] -> 'https://management.core.windows.net/' - ['https://managedhsm.azure.com/.default'] -> 'https://managedhsm.azure.com' - - :param scopes: The MSAL scopes. It can be a list or tuple of string - :return: The ADAL resource - :rtype: str - """ - scope = scopes[0] - - suffixes = ['/.default', '/user_impersonation'] - - for s in suffixes: - if scope.endswith(s): - return scope[:-len(s)] - - return scope - - def _get_parent_proc_name(): # Un-cached function to get parent process name. try: diff --git a/src/azure-cli-core/setup.py b/src/azure-cli-core/setup.py index 28588e94247..3305ea6bb76 100644 --- a/src/azure-cli-core/setup.py +++ b/src/azure-cli-core/setup.py @@ -43,7 +43,6 @@ ] DEPENDENCIES = [ - 'adal~=1.2.7', 'argcomplete~=1.8', 'azure-cli-telemetry==1.0.6.*', 'azure-mgmt-core>=1.2.0,<2', @@ -51,7 +50,7 @@ 'humanfriendly>=4.7,<10.0', 'jmespath', 'knack~=0.8.2', - 'msal>=1.10.0,<2.0.0', + 'msal>=1.15.0,<2.0.0', 'paramiko>=2.0.8,<3.0.0', 'pkginfo>=1.5.0.1', 'PyJWT>=2.1.0', @@ -82,5 +81,5 @@ packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests", "azure", "azure.cli"]), install_requires=DEPENDENCIES, python_requires='>=3.6.0', - package_data={'azure.cli.core': ['auth_landing_pages/*.html']} + package_data={'azure.cli.core': ['auth/landing_pages/*.html']} ) diff --git a/src/azure-cli-testsdk/azure/cli/testsdk/patches.py b/src/azure-cli-testsdk/azure/cli/testsdk/patches.py index d59833dfa16..0d331179309 100644 --- a/src/azure-cli-testsdk/azure/cli/testsdk/patches.py +++ b/src/azure-cli-testsdk/azure/cli/testsdk/patches.py @@ -40,16 +40,19 @@ def _handle_main_exception(ex, *args, **kwargs): # pylint: disable=unused-argum def patch_load_cached_subscriptions(unit_test): def _handle_load_cached_subscription(*args, **kwargs): # pylint: disable=unused-argument - return [{ - "id": MOCKED_SUBSCRIPTION_ID, - "user": { - "name": MOCKED_USER_NAME, - "type": "user" - }, - "state": "Enabled", - "name": "Example", - "tenantId": MOCKED_TENANT_ID, - "isDefault": True}] + return [ + { + "id": MOCKED_SUBSCRIPTION_ID, + "state": "Enabled", + "name": "Example", + "tenantId": MOCKED_TENANT_ID, + "isDefault": True, + "user": { + "name": MOCKED_USER_NAME, + "type": "user" + } + } + ] mock_in_unit_test(unit_test, 'azure.cli.core._profile.Profile.load_cached_subscriptions', @@ -57,21 +60,20 @@ def _handle_load_cached_subscription(*args, **kwargs): # pylint: disable=unused def patch_retrieve_token_for_user(unit_test): - def _retrieve_token_for_user(*args, **kwargs): # pylint: disable=unused-argument - import datetime - fake_token = 'top-secret-token-for-you' - return 'Bearer', fake_token, { - "tokenType": "Bearer", - "expiresIn": 3600, - "expiresOn": (datetime.datetime.now() + datetime.timedelta(hours=1)).strftime("%Y-%m-%d %H:%M:%S.%f"), - "resource": args[3], - "accessToken": fake_token, - "refreshToken": fake_token - } - mock_in_unit_test(unit_test, - 'azure.cli.core._profile.CredsCache.retrieve_token_for_user', - _retrieve_token_for_user) + class UserCredentialMock: + + def __init__(self, *args, **kwargs): + pass + + def get_token(*args, **kwargs): # pylint: disable=unused-argument + from azure.core.credentials import AccessToken + import time + fake_raw_token = 'top-secret-token-for-you' + now = int(time.time()) + return AccessToken(fake_raw_token, now + 3600) + + mock_in_unit_test(unit_test, 'azure.cli.core.auth.identity.UserCredential', UserCredentialMock) def patch_long_run_operation_delay(unit_test): diff --git a/src/azure-cli/azure/cli/command_modules/acs/custom.py b/src/azure-cli/azure/cli/command_modules/acs/custom.py index 2f01654f651..0cbe0bd5f4c 100644 --- a/src/azure-cli/azure/cli/command_modules/acs/custom.py +++ b/src/azure-cli/azure/cli/command_modules/acs/custom.py @@ -3365,17 +3365,7 @@ def _get_command_context(command_files): def _get_dataplane_aad_token(cli_ctx, serverAppId): # this function is mostly copied from keyvault cli - import adal - try: - return Profile(cli_ctx=cli_ctx).get_raw_token(resource=serverAppId)[0][2].get('accessToken') - except adal.AdalError as err: - # pylint: disable=no-member - if (hasattr(err, 'error_response') and - ('error_description' in err.error_response) and - ('AADSTS70008:' in err.error_response['error_description'])): - raise CLIError( - "Credentials have expired due to inactivity. Please run 'az login'") - raise CLIError(err) + return Profile(cli_ctx=cli_ctx).get_raw_token(resource=serverAppId)[0][2].get('accessToken') DEV_SPACES_EXTENSION_NAME = 'dev-spaces' diff --git a/src/azure-cli/azure/cli/command_modules/appservice/tests/latest/test_app_service_environment_commands_thru_mock.py b/src/azure-cli/azure/cli/command_modules/appservice/tests/latest/test_app_service_environment_commands_thru_mock.py index df8712087c5..98ead035f6b 100644 --- a/src/azure-cli/azure/cli/command_modules/appservice/tests/latest/test_app_service_environment_commands_thru_mock.py +++ b/src/azure-cli/azure/cli/command_modules/appservice/tests/latest/test_app_service_environment_commands_thru_mock.py @@ -14,7 +14,6 @@ from azure.mgmt.web import WebSiteManagementClient from azure.mgmt.web.models import HostingEnvironmentProfile from azure.mgmt.network.models import (Subnet, RouteTable, Route, NetworkSecurityGroup, SecurityRule, Delegation) -from azure.cli.core.adal_authentication import AdalAuthentication from azure.cli.command_modules.appservice.appservice_environment import (show_appserviceenvironment, list_appserviceenvironments, @@ -30,7 +29,7 @@ def setUp(self): self.mock_logger = mock.MagicMock() self.mock_cmd = mock.MagicMock() self.mock_cmd.cli_ctx = mock.MagicMock() - self.client = WebSiteManagementClient(AdalAuthentication(lambda: ('bearer', 'secretToken')), '123455678') + self.client = WebSiteManagementClient(mock.MagicMock(), '123455678') @mock.patch('azure.cli.command_modules.appservice.appservice_environment._get_ase_client_factory', autospec=True) def test_app_service_environment_show(self, ase_client_factory_mock): diff --git a/src/azure-cli/azure/cli/command_modules/appservice/tests/latest/test_functionapp_commands_thru_mock.py b/src/azure-cli/azure/cli/command_modules/appservice/tests/latest/test_functionapp_commands_thru_mock.py index cbb9f3f0cb9..d3895939513 100644 --- a/src/azure-cli/azure/cli/command_modules/appservice/tests/latest/test_functionapp_commands_thru_mock.py +++ b/src/azure-cli/azure/cli/command_modules/appservice/tests/latest/test_functionapp_commands_thru_mock.py @@ -7,7 +7,6 @@ import os from azure.mgmt.web import WebSiteManagementClient -from azure.cli.core.adal_authentication import AdalAuthentication from knack.util import CLIError from azure.cli.command_modules.appservice.custom import ( enable_zip_deploy_functionapp, @@ -34,7 +33,7 @@ def _get_test_cmd(): class TestFunctionappMocked(unittest.TestCase): def setUp(self): - self.client = WebSiteManagementClient(AdalAuthentication(lambda: ('bearer', 'secretToken')), '123455678') + self.client = WebSiteManagementClient(mock.MagicMock(), '123455678') @mock.patch('azure.cli.command_modules.appservice.custom.web_client_factory', autospec=True) @mock.patch('azure.cli.command_modules.appservice.custom.parse_resource_id') diff --git a/src/azure-cli/azure/cli/command_modules/appservice/tests/latest/test_webapp_commands_thru_mock.py b/src/azure-cli/azure/cli/command_modules/appservice/tests/latest/test_webapp_commands_thru_mock.py index a6fa5a4ef60..f0057bcd1d7 100644 --- a/src/azure-cli/azure/cli/command_modules/appservice/tests/latest/test_webapp_commands_thru_mock.py +++ b/src/azure-cli/azure/cli/command_modules/appservice/tests/latest/test_webapp_commands_thru_mock.py @@ -8,7 +8,6 @@ from msrestazure.azure_exceptions import CloudError from azure.mgmt.web import WebSiteManagementClient -from azure.cli.core.adal_authentication import AdalAuthentication from knack.util import CLIError from azure.cli.command_modules.appservice.custom import (set_deployment_user, update_git_token, add_hostname, @@ -46,7 +45,7 @@ def _get_test_cmd(): class TestWebappMocked(unittest.TestCase): def setUp(self): - self.client = WebSiteManagementClient(AdalAuthentication(lambda: ('bearer', 'secretToken')), '123455678') + self.client = WebSiteManagementClient(mock.MagicMock(), '123455678') @mock.patch('azure.cli.command_modules.appservice.custom.web_client_factory', autospec=True) def test_set_deployment_user_creds(self, client_factory_mock): diff --git a/src/azure-cli/azure/cli/command_modules/keyvault/_client_factory.py b/src/azure-cli/azure/cli/command_modules/keyvault/_client_factory.py index 2eb71090187..d1bb4d7fda9 100644 --- a/src/azure-cli/azure/cli/command_modules/keyvault/_client_factory.py +++ b/src/azure-cli/azure/cli/command_modules/keyvault/_client_factory.py @@ -151,18 +151,8 @@ def keyvault_data_plane_factory(cli_ctx, *_): version = str(get_api_version(cli_ctx, ResourceType.DATA_KEYVAULT)) def get_token(server, resource, scope): # pylint: disable=unused-argument - import adal - try: - return Profile(cli_ctx=cli_ctx).get_raw_token(resource=resource, - subscription=cli_ctx.data.get('subscription_id'))[0] - except adal.AdalError as err: - # pylint: disable=no-member - if (hasattr(err, 'error_response') and - ('error_description' in err.error_response) and - ('AADSTS70008:' in err.error_response['error_description'])): - raise CLIError( - "Credentials have expired due to inactivity. Please run 'az login'") - raise CLIError(err) + return Profile(cli_ctx=cli_ctx).get_raw_token(resource=resource, + subscription=cli_ctx.data.get('subscription_id'))[0] client = KeyVaultClient(KeyVaultAuthentication(get_token), api_version=version) @@ -188,19 +178,8 @@ def keyvault_private_data_plane_factory_v7_2_preview(cli_ctx, _): version = str(get_api_version(cli_ctx, ResourceType.DATA_PRIVATE_KEYVAULT)) def get_token(server, resource, scope): # pylint: disable=unused-argument - import adal - try: - return Profile(cli_ctx=cli_ctx).get_raw_token(resource=resource, - subscription=cli_ctx.data.get('subscription_id'))[0] - except adal.AdalError as err: - # pylint: disable=no-member - if (hasattr(err, 'error_response') and - ('error_description' in err.error_response) and - ('AADSTS70008:' in err.error_response['error_description'])): - raise CLIError( - "Credentials have expired due to inactivity. Please run 'az login'") - raise CLIError(err) - + return Profile(cli_ctx=cli_ctx).get_raw_token(resource=resource, + subscription=cli_ctx.data.get('subscription_id'))[0] client = KeyVaultClient(KeyVaultAuthentication(get_token), api_version=version) # HACK, work around the fact that KeyVault library does't take confiuration object on constructor diff --git a/src/azure-cli/azure/cli/command_modules/profile/__init__.py b/src/azure-cli/azure/cli/command_modules/profile/__init__.py index b408737bde2..ff66ec5ee20 100644 --- a/src/azure-cli/azure/cli/command_modules/profile/__init__.py +++ b/src/azure-cli/azure/cli/command_modules/profile/__init__.py @@ -58,6 +58,7 @@ def load_arguments(self, command): help="Use CLI's old authentication flow based on device code. CLI will also use this if it can't launch a browser in your behalf, e.g. in remote SSH or Cloud Shell") c.argument('use_cert_sn_issuer', action='store_true', help='used with a service principal configured with Subject Name and Issuer Authentication in order to support automatic certificate rolls') c.argument('scopes', options_list=['--scope'], nargs='+', help='Used in the /authorize request. It can cover only one static resource.') + c.argument('client_assertion', options_list=['--federated-token'], help='Federated token that can be used for OIDC token exchange.') with self.argument_context('logout') as c: c.argument('username', help='account user, if missing, logout the current active account') @@ -79,6 +80,8 @@ def load_arguments(self, command): with self.argument_context('account get-access-token') as c: c.argument('resource_type', get_enum_type(cloud_resource_types), options_list=['--resource-type'], arg_group='', help='Type of well-known resource.') + c.argument('resource', options_list=['--resource'], help='Azure resource endpoints in AAD v1.0.') + c.argument('scopes', options_list=['--scope'], nargs='*', help='Space-separated AAD scopes in AAD v2.0. Default to Azure Resource Manager.') c.argument('tenant', options_list=['--tenant', '-t'], help='Tenant ID for which the token is acquired. Only available for user and service principal account, not for MSI or Cloud Shell account') diff --git a/src/azure-cli/azure/cli/command_modules/profile/_help.py b/src/azure-cli/azure/cli/command_modules/profile/_help.py index 37e648e1fdb..779fe627797 100644 --- a/src/azure-cli/azure/cli/command_modules/profile/_help.py +++ b/src/azure-cli/azure/cli/command_modules/profile/_help.py @@ -10,25 +10,24 @@ helps['login'] = """ type: command short-summary: Log in to Azure. +long-summary: >- + By default, this command logs in with a user account. CLI will try to launch a web browser to log in interactively. + If a web browser is not available, CLI will fall back to device code login. + + To login with a service principal, specify --service-principal. examples: - name: Log in interactively. - text: > - az login + text: az login - name: Log in with user name and password. This doesn't work with Microsoft accounts or accounts that have two-factor authentication enabled. Use -p=secret if the first character of the password is '-'. - text: > - az login -u johndoe@contoso.com -p VerySecret + text: az login -u johndoe@contoso.com -p VerySecret - name: Log in with a service principal using client secret. Use -p=secret if the first character of the password is '-'. - text: > - az login --service-principal -u http://azure-cli-2016-08-05-14-31-15 -p VerySecret --tenant contoso.onmicrosoft.com + text: az login --service-principal -u http://azure-cli-2016-08-05-14-31-15 -p VerySecret --tenant contoso.onmicrosoft.com - name: Log in with a service principal using client certificate. - text: > - az login --service-principal -u http://azure-cli-2016-08-05-14-31-15 -p ~/mycertfile.pem --tenant contoso.onmicrosoft.com - - name: Log in using a VM's system assigned identity - text: > - az login --identity - - name: Log in using a VM's user assigned identity. Client or object ids of the service identity also work - text: > - az login --identity -u /subscriptions//resourcegroups/myRG/providers/Microsoft.ManagedIdentity/userAssignedIdentities/myID + text: az login --service-principal -u http://azure-cli-2016-08-05-14-31-15 -p ~/mycertfile.pem --tenant contoso.onmicrosoft.com + - name: Log in using a VM's system-assigned managed identity. + text: az login --identity + - name: Log in using a VM's user-assigned managed identity. Client or object ids of the service identity also work. + text: az login --identity -u /subscriptions//resourcegroups/myRG/providers/Microsoft.ManagedIdentity/userAssignedIdentities/myID """ helps['account'] = """ diff --git a/src/azure-cli/azure/cli/command_modules/profile/custom.py b/src/azure-cli/azure/cli/command_modules/profile/custom.py index 9b6f1ce869e..bafbdf24471 100644 --- a/src/azure-cli/azure/cli/command_modules/profile/custom.py +++ b/src/azure-cli/azure/cli/command_modules/profile/custom.py @@ -61,37 +61,29 @@ def show_subscription(cmd, subscription=None, show_auth_for_sdk=None): return profile.get_subscription(subscription) -def get_access_token(cmd, subscription=None, resource=None, resource_type=None, tenant=None): +def get_access_token(cmd, subscription=None, resource=None, scopes=None, resource_type=None, tenant=None): """ - get AAD token to access to a specified resource - :param resource: Azure resource endpoints. Default to Azure Resource Manager - :param resource-type: Name of Azure resource endpoints. Can be used instead of resource. + get AAD token to access to a specified resource. Use 'az cloud show' command for other Azure resources """ - if resource is None and resource_type is not None: + if resource is None and resource_type: endpoints_attr_name = cloud_resource_type_mappings[resource_type] resource = getattr(cmd.cli_ctx.cloud.endpoints, endpoints_attr_name) - else: - resource = (resource or cmd.cli_ctx.cloud.endpoints.active_directory_resource_id) - profile = Profile(cli_ctx=cmd.cli_ctx) - creds, subscription, tenant = profile.get_raw_token(subscription=subscription, resource=resource, tenant=tenant) - token_entry = creds[2] - # MSIAuthentication's token entry has `expires_on`, while ADAL's token entry has `expiresOn` - # Unify to ISO `expiresOn`, like "2020-06-30 06:14:41" - if 'expires_on' in token_entry: - # https://docs.python.org/3.8/library/datetime.html#strftime-and-strptime-format-codes - token_entry['expiresOn'] = _fromtimestamp(int(token_entry['expires_on']))\ - .strftime("%Y-%m-%d %H:%M:%S.%f") + profile = Profile(cli_ctx=cmd.cli_ctx) + creds, subscription, tenant = profile.get_raw_token(subscription=subscription, resource=resource, scopes=scopes, + tenant=tenant) result = { 'tokenType': creds[0], 'accessToken': creds[1], - 'expiresOn': creds[2].get('expiresOn', 'N/A'), + # 'expires_on': creds[2].get('expires_on', None), + 'expiresOn': creds[2].get('expiresOn', None), 'tenant': tenant } if subscription: result['subscription'] = subscription + return result @@ -111,12 +103,10 @@ def account_clear(cmd): profile.logout_all() -# pylint: disable=inconsistent-return-statements +# pylint: disable=inconsistent-return-statements, too-many-branches def login(cmd, username=None, password=None, service_principal=None, tenant=None, allow_no_subscriptions=False, - identity=False, use_device_code=False, use_cert_sn_issuer=None, scopes=None): + identity=False, use_device_code=False, use_cert_sn_issuer=None, scopes=None, client_assertion=None): """Log in to access Azure subscriptions""" - from adal.adal_error import AdalError - import requests # quick argument usage check if any([password, service_principal, tenant]) and identity: @@ -130,17 +120,17 @@ def login(cmd, username=None, password=None, service_principal=None, tenant=None interactive = False - profile = Profile(cli_ctx=cmd.cli_ctx, async_persist=False) + profile = Profile(cli_ctx=cmd.cli_ctx) if identity: if in_cloud_console(): - return profile.find_subscriptions_in_cloud_console() - return profile.find_subscriptions_in_vm_with_msi(username, allow_no_subscriptions) + return profile.login_in_cloud_shell() + return profile.login_with_managed_identity(username, allow_no_subscriptions) if in_cloud_console(): # tell users they might not need login logger.warning(_CLOUD_CONSOLE_LOGIN_WARNING) if username: - if not password: + if not (password or client_assertion): try: password = prompt_pass('Password: ') except NoTTYException: @@ -148,38 +138,20 @@ def login(cmd, username=None, password=None, service_principal=None, tenant=None else: interactive = True - try: - subscriptions = profile.find_subscriptions_on_login( - interactive, - username, - password, - service_principal, - tenant, - scopes=scopes, - use_device_code=use_device_code, - allow_no_subscriptions=allow_no_subscriptions, - use_cert_sn_issuer=use_cert_sn_issuer) - except AdalError as err: - # try polish unfriendly server errors - if username: - msg = str(err) - suggestion = "For cross-check, try 'az login' to authenticate through browser." - if ('ID3242:' in msg) or ('Server returned an unknown AccountType' in msg): - raise CLIError("The user name might be invalid. " + suggestion) - if 'Server returned error in RSTR - ErrorCode' in msg: - raise CLIError("Logging in through command line is not supported. " + suggestion) - if 'wstrust' in msg: - raise CLIError("Authentication failed due to error of '" + msg + "' " - "This typically happens when attempting a Microsoft account, which requires " - "interactive login. Please invoke 'az login' to cross check. " - # pylint: disable=line-too-long - "More details are available at https://github.com/AzureAD/microsoft-authentication-library-for-python/wiki/Username-Password-Authentication") - raise CLIError(err) - except requests.exceptions.SSLError as err: - from azure.cli.core.util import SSLERROR_TEMPLATE - raise CLIError(SSLERROR_TEMPLATE + " Error detail: {}".format(str(err))) - except requests.exceptions.ConnectionError as err: - raise CLIError('Please ensure you have network connection. Error detail: ' + str(err)) + if service_principal: + from azure.cli.core.auth.identity import ServicePrincipalAuth + password = ServicePrincipalAuth.build_credential(password, client_assertion, use_cert_sn_issuer) + + subscriptions = profile.login( + interactive, + username, + password, + service_principal, + tenant, + scopes=scopes, + use_device_code=use_device_code, + allow_no_subscriptions=allow_no_subscriptions, + use_cert_sn_issuer=use_cert_sn_issuer) all_subscriptions = list(subscriptions) for sub in all_subscriptions: sub['cloudName'] = sub.pop('environmentName', None) @@ -232,13 +204,3 @@ def check_cli(cmd): print('CLI self-test completed: OK') else: raise CLIError(exceptions) - - -def _fromtimestamp(t): - # datetime.datetime can't be patched: - # TypeError: can't set attributes of built-in/extension type 'datetime.datetime' - # So we wrap datetime.datetime.fromtimestamp with this function. - # https://docs.python.org/3/library/unittest.mock-examples.html#partial-mocking - # https://williambert.online/2011/07/how-to-unit-testing-in-django-with-mocking-and-patching/ - from datetime import datetime - return datetime.fromtimestamp(t) diff --git a/src/azure-cli/azure/cli/command_modules/profile/tests/latest/test_auth_e2e.py b/src/azure-cli/azure/cli/command_modules/profile/tests/latest/test_auth_e2e.py index 46803fe0236..efb135d7779 100644 --- a/src/azure-cli/azure/cli/command_modules/profile/tests/latest/test_auth_e2e.py +++ b/src/azure-cli/azure/cli/command_modules/profile/tests/latest/test_auth_e2e.py @@ -3,9 +3,9 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +from azure.cli.core.auth.util import decode_access_token from azure.cli.core.azclierror import AuthenticationError from azure.cli.testsdk import LiveScenarioTest -from azure.cli.core.auth.util import decode_access_token ARM_URL = "https://eastus2euap.management.azure.com/" # ARM canary ARM_MAX_RETRY = 30 @@ -25,18 +25,17 @@ def test_conditional_access_mfa(self): - doesn't require MFA for ARM - requires MFA for data-plane resource - The result ATs are checked per https://docs.microsoft.com/en-us/azure/active-directory/develop/access-tokens + The result ATs are checked per + Microsoft identity platform access tokens + https://docs.microsoft.com/en-us/azure/active-directory/develop/access-tokens Following claims are checked: - aud (Audience): https://tools.ietf.org/html/rfc7519#section-4.1.3 - amr (Authentication Method Reference): https://tools.ietf.org/html/rfc8176 """ - resource = 'https://pas.windows.net/CheckMyAccess/Linux' - scope = resource + '/.default' - + scope = 'https://pas.windows.net/CheckMyAccess/Linux/.default' self.kwargs['scope'] = scope - self.kwargs['resource'] = resource # region non-MFA session @@ -52,7 +51,7 @@ def test_conditional_access_mfa(self): # Getting data-plane AT with ARM RT (step-up) fails with self.assertRaises(AuthenticationError) as cm: - self.cmd('az account get-access-token --resource {resource}') + self.cmd('az account get-access-token --scope {scope}') # Check re-login recommendation re_login_command = 'az login --scope {scope}'.format(**self.kwargs) @@ -74,7 +73,7 @@ def test_conditional_access_mfa(self): assert decoded['amr'] == ['pwd'] # Getting data-plane AT and check claims - result = self.cmd('az account get-access-token --resource {resource}').get_output_in_json() + result = self.cmd('az account get-access-token --scope {scope}').get_output_in_json() decoded = decode_access_token(result['accessToken']) assert decoded['aud'] in scope assert decoded['amr'] == ['pwd', 'mfa'] diff --git a/src/azure-cli/azure/cli/command_modules/profile/tests/latest/test_profile_custom.py b/src/azure-cli/azure/cli/command_modules/profile/tests/latest/test_profile_custom.py index aadef8afb0e..168a121474e 100644 --- a/src/azure-cli/azure/cli/command_modules/profile/tests/latest/test_profile_custom.py +++ b/src/azure-cli/azure/cli/command_modules/profile/tests/latest/test_profile_custom.py @@ -35,14 +35,12 @@ def test_get_raw_token(self, get_raw_token_mock): cmd = mock.MagicMock() cmd.cli_ctx = DummyCli() - # arrange get_raw_token_mock.return_value = (['bearer', 'token123', {'expiresOn': '2100-01-01'}], 'sub123', 'tenant123') - # action result = get_access_token(cmd) # assert - get_raw_token_mock.assert_called_with(mock.ANY, 'https://management.core.windows.net/', None, None) + get_raw_token_mock.assert_called_with(mock.ANY, None, None, None, None) expected_result = { 'tokenType': 'bearer', 'accessToken': 'token123', @@ -58,51 +56,25 @@ def test_get_raw_token(self, get_raw_token_mock): get_raw_token_mock.return_value = (['bearer', 'token123', {'expiresOn': '2100-01-01'}], subscription_id, 'tenant123') result = get_access_token(cmd, subscription=subscription_id, resource=resource) - get_raw_token_mock.assert_called_with(mock.ANY, resource, subscription_id, None) - expected_result = { - 'tokenType': 'bearer', - 'accessToken': 'token123', - 'expiresOn': '2100-01-01', - 'subscription': subscription_id, - 'tenant': 'tenant123' - } - self.assertEqual(result, expected_result) + get_raw_token_mock.assert_called_with(mock.ANY, resource, None, subscription_id, None) + + # assert it takes customized scopes + get_access_token(cmd, scopes='https://graph.microsoft.com/.default') + get_raw_token_mock.assert_called_with(mock.ANY, None, scopes='https://graph.microsoft.com/.default', + subscription=None, tenant=None) # test get token with tenant tenant_id = '00000000-0000-0000-0000-000000000000' get_raw_token_mock.return_value = (['bearer', 'token123', {'expiresOn': '2100-01-01'}], None, tenant_id) result = get_access_token(cmd, tenant=tenant_id) - get_raw_token_mock.assert_called_with(mock.ANY, 'https://management.core.windows.net/', None, tenant_id) - expected_result = { - 'tokenType': 'bearer', - 'accessToken': 'token123', - 'expiresOn': '2100-01-01', # subscription shouldn't be present - 'tenant': tenant_id - } - self.assertEqual(result, expected_result) - - @mock.patch('azure.cli.core._profile.Profile.get_raw_token', autospec=True) - def test_get_raw_token_managed_identity(self, get_raw_token_mock): - cmd = mock.MagicMock() - cmd.cli_ctx = DummyCli() - - # test get token with Managed Identity - tenant_id = '00000000-0000-0000-0000-000000000000' - get_raw_token_mock.return_value = (['bearer', 'token123', {'expires_on': '1593497681'}], None, tenant_id) - - import datetime - # Force POSIX timestamp to be converted to datetime in UTC during testing. - with mock.patch('azure.cli.command_modules.profile.custom._fromtimestamp', datetime.datetime.utcfromtimestamp): - result = get_access_token(cmd) - - get_raw_token_mock.assert_called_with(mock.ANY, 'https://management.core.windows.net/', None, None) expected_result = { 'tokenType': 'bearer', 'accessToken': 'token123', - 'expiresOn': '2020-06-30 06:14:41.000000', + 'expiresOn': '2100-01-01', 'tenant': tenant_id } self.assertEqual(result, expected_result) + get_raw_token_mock.assert_called_with(mock.ANY, None, None, None, tenant_id) @mock.patch('azure.cli.command_modules.profile.custom.Profile', autospec=True) def test_get_login(self, profile_mock): @@ -113,7 +85,7 @@ def test_login(msi_port, identity_id=None): # mock the instance profile_instance = mock.MagicMock() - profile_instance.find_subscriptions_in_vm_with_msi = test_login + profile_instance.login_with_managed_identity = test_login # mock the constructor profile_mock.return_value = profile_instance diff --git a/src/azure-cli/requirements.py3.Darwin.txt b/src/azure-cli/requirements.py3.Darwin.txt index e031bbeaa0e..dbc6e98147c 100644 --- a/src/azure-cli/requirements.py3.Darwin.txt +++ b/src/azure-cli/requirements.py3.Darwin.txt @@ -1,4 +1,3 @@ -adal==1.2.7 antlr4-python3-runtime==4.7.2 applicationinsights==0.11.9 argcomplete==1.11.1 @@ -14,7 +13,7 @@ azure-cosmos==3.2.0 azure-datalake-store==0.0.49 azure-functions-devops-build==0.0.22 azure-graphrbac==0.60.0 -azure-identity==1.5.0 +azure-identity==1.6.1 azure-keyvault-administration==4.0.0b3 azure-keyvault-keys==4.4.0 azure-keyvault==1.1.0 @@ -111,7 +110,7 @@ jmespath==0.9.5 jsondiff==1.2.0 knack==0.8.2 MarkupSafe==1.1.1 -msal==1.10.0 +msal==1.15.0 msrest==0.6.21 msrestazure==0.6.3 oauthlib==3.0.1 diff --git a/src/azure-cli/requirements.py3.Linux.txt b/src/azure-cli/requirements.py3.Linux.txt index 2d25f42f481..ebd97a3f6d0 100644 --- a/src/azure-cli/requirements.py3.Linux.txt +++ b/src/azure-cli/requirements.py3.Linux.txt @@ -1,4 +1,3 @@ -adal==1.2.7 antlr4-python3-runtime==4.7.2 applicationinsights==0.11.9 argcomplete==1.11.1 @@ -14,7 +13,7 @@ azure-cosmos==3.2.0 azure-datalake-store==0.0.49 azure-functions-devops-build==0.0.22 azure-graphrbac==0.60.0 -azure-identity==1.5.0 +azure-identity==1.6.1 azure-keyvault-administration==4.0.0b3 azure-keyvault-keys==4.4.0 azure-keyvault==1.1.0 @@ -112,7 +111,7 @@ jmespath==0.9.5 jsondiff==1.2.0 knack==0.8.2 MarkupSafe==1.1.1 -msal==1.10.0 +msal==1.15.0 msrest==0.6.21 msrestazure==0.6.3 oauthlib==3.0.1 diff --git a/src/azure-cli/requirements.py3.windows.txt b/src/azure-cli/requirements.py3.windows.txt index 4e0334e5321..9ae8c3816f0 100644 --- a/src/azure-cli/requirements.py3.windows.txt +++ b/src/azure-cli/requirements.py3.windows.txt @@ -1,4 +1,3 @@ -adal==1.2.7 antlr4-python3-runtime==4.7.2 applicationinsights==0.11.7 argcomplete==1.11.1 @@ -14,7 +13,7 @@ azure-cosmos==3.2.0 azure-datalake-store==0.0.49 azure-functions-devops-build==0.0.22 azure-graphrbac==0.60.0 -azure-identity==1.5.0 +azure-identity==1.6.1 azure-keyvault-administration==4.0.0b3 azure-keyvault-keys==4.4.0 azure-keyvault==1.1.0 @@ -110,7 +109,7 @@ jmespath==0.9.5 jsondiff==1.2.0 knack==0.8.2 MarkupSafe==1.1.1 -msal==1.10.0 +msal==1.15.0 msrest==0.6.21 msrestazure==0.6.3 oauthlib==3.0.1