diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index d4b78bd6f0cb..ec382cd83165 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -59,6 +59,7 @@ ### Other Changes * `NewDefaultAzureCredential()` returns `*DefaultAzureCredential` instead of `*ChainedTokenCredential` +* Added `TenantID` field to `DefaultAzureCredentialOptions` and `AzureCLICredentialOptions` ## 0.11.0 (2021-09-08) ### Breaking Changes diff --git a/sdk/azidentity/azure_cli_credential.go b/sdk/azidentity/azure_cli_credential.go index bfbe9a6ab17d..044ad9017166 100644 --- a/sdk/azidentity/azure_cli_credential.go +++ b/sdk/azidentity/azure_cli_credential.go @@ -20,12 +20,15 @@ import ( ) // used by tests to fake invoking the CLI -type azureCLITokenProvider func(ctx context.Context, resource string) ([]byte, error) +type azureCLITokenProvider func(ctx context.Context, resource string, tenantID string) ([]byte, error) // AzureCLICredentialOptions contains options used to configure the AzureCLICredential // All zero-value fields will be initialized with their default values. type AzureCLICredentialOptions struct { tokenProvider azureCLITokenProvider + // TenantID identifies the tenant the credential should authenticate in. + // Defaults to the CLI's default tenant, which is typically the home tenant of the user logged in to the CLI. + TenantID string } // init returns an instance of AzureCLICredentialOptions initialized with default values. @@ -38,6 +41,7 @@ func (o *AzureCLICredentialOptions) init() { // AzureCLICredential enables authentication to Azure Active Directory using the Azure CLI command "az account get-access-token". type AzureCLICredential struct { tokenProvider azureCLITokenProvider + tenantID string } // NewAzureCLICredential constructs a new AzureCLICredential with the details needed to authenticate against Azure Active Directory @@ -50,6 +54,7 @@ func NewAzureCLICredential(options *AzureCLICredentialOptions) (*AzureCLICredent cp.init() return &AzureCLICredential{ tokenProvider: cp.tokenProvider, + tenantID: cp.TenantID, }, nil } @@ -76,7 +81,7 @@ const timeoutCLIRequest = 10000 * time.Millisecond // ctx: The current request context // scopes: The scopes for which the token has access func (c *AzureCLICredential) authenticate(ctx context.Context, resource string) (*azcore.AccessToken, error) { - output, err := c.tokenProvider(ctx, resource) + output, err := c.tokenProvider(ctx, resource, c.tenantID) if err != nil { return nil, err } @@ -84,8 +89,8 @@ func (c *AzureCLICredential) authenticate(ctx context.Context, resource string) return c.createAccessToken(output) } -func defaultTokenProvider() func(ctx context.Context, resource string) ([]byte, error) { - return func(ctx context.Context, resource string) ([]byte, error) { +func defaultTokenProvider() func(ctx context.Context, resource string, tenantID string) ([]byte, error) { + return func(ctx context.Context, resource string, tenantID string) ([]byte, error) { // This is the path that a developer can set to tell this class what the install path for Azure CLI is. const azureCLIPath = "AZURE_CLI_PATH" @@ -121,7 +126,9 @@ func defaultTokenProvider() func(ctx context.Context, resource string) ([]byte, cliCmd.Env = append(cliCmd.Env, fmt.Sprintf("PATH=%s:%s", os.Getenv(azureCLIPath), azureCLIDefaultPath)) } cliCmd.Args = append(cliCmd.Args, "account", "get-access-token", "-o", "json", "--resource", resource) - + if tenantID != "" { + cliCmd.Args = append(cliCmd.Args, "--tenant", tenantID) + } var stderr bytes.Buffer cliCmd.Stderr = &stderr diff --git a/sdk/azidentity/azure_cli_credential_test.go b/sdk/azidentity/azure_cli_credential_test.go index 9259e9e30fc2..ac908024abe7 100644 --- a/sdk/azidentity/azure_cli_credential_test.go +++ b/sdk/azidentity/azure_cli_credential_test.go @@ -15,14 +15,14 @@ import ( ) var ( - mockCLITokenProviderSuccess = func(ctx context.Context, resource string) ([]byte, error) { + mockCLITokenProviderSuccess = func(ctx context.Context, resource string, tenantID string) ([]byte, error) { return []byte(" {\"accessToken\":\"mocktoken\" , " + "\"expiresOn\": \"2007-01-01 01:01:01.079627\"," + "\"subscription\": \"mocksub\"," + "\"tenant\": \"mocktenant\"," + "\"tokenType\": \"mocktype\"}"), nil } - mockCLITokenProviderFailure = func(ctx context.Context, resource string) ([]byte, error) { + mockCLITokenProviderFailure = func(ctx context.Context, resource string, tenantID string) ([]byte, error) { return nil, errors.New("provider failure message") } ) @@ -59,6 +59,32 @@ func TestAzureCLICredential_GetTokenInvalidToken(t *testing.T) { } } +func TestAzureCLICredential_TenantID(t *testing.T) { + expected := "expected-tenant-id" + called := false + options := AzureCLICredentialOptions{ + TenantID: expected, + tokenProvider: func(ctx context.Context, resource, tenantID string) ([]byte, error) { + called = true + if tenantID != expected { + t.Fatal("Unexpected tenant ID: " + tenantID) + } + return mockCLITokenProviderSuccess(ctx, resource, tenantID) + }, + } + cred, err := NewAzureCLICredential(&options) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + _, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if !called { + t.Fatal("token provider wasn't called") + } +} + func TestBearerPolicy_AzureCLICredential(t *testing.T) { srv, close := mock.NewTLSServer() defer close() diff --git a/sdk/azidentity/default_azure_credential.go b/sdk/azidentity/default_azure_credential.go index b4f14009bea0..29bf79f4966d 100644 --- a/sdk/azidentity/default_azure_credential.go +++ b/sdk/azidentity/default_azure_credential.go @@ -23,6 +23,9 @@ type DefaultAzureCredentialOptions struct { // The host of the Azure Active Directory authority. The default is AzurePublicCloud. // Leave empty to allow overriding the value from the AZURE_AUTHORITY_HOST environment variable. AuthorityHost AuthorityHost + // TenantID identifies the tenant the Azure CLI should authenticate in. + // Defaults to the CLI's default tenant, which is typically the home tenant of the user logged in to the CLI. + TenantID string } // DefaultAzureCredential is a default credential chain for applications that will be deployed to Azure. @@ -62,7 +65,7 @@ func NewDefaultAzureCredential(options *DefaultAzureCredentialOptions) (*Default errMsg += err.Error() } - cliCred, err := NewAzureCLICredential(nil) + cliCred, err := NewAzureCLICredential(&AzureCLICredentialOptions{TenantID: options.TenantID}) if err == nil { creds = append(creds, cliCred) } else {