diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index 6afba175220e..a8908b91261b 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -6,6 +6,12 @@ ### Breaking Changes +* Instances of `ChainedTokenCredential` will now skip looping through the list of source credentials and re-use the first successful credential on subsequent calls to `GetToken`. + * If `ChainedTokenCredentialOptions.RetrySources` is true, `ChainedTokenCredential` will continue to try all of the originally provided credentials each time the `GetToken` method is called. + * `ChainedTokenCredential.successfulCredential` will contain a reference to the last successful credential. + * `DefaultAzureCredenial` will also re-use the first successful credential on subsequent calls to `GetToken`. + * `DefaultAzureCredential.chain.successfulCredential` will also contain a reference to the last successful credential. + ### Bugs Fixed ### Other Changes diff --git a/sdk/azidentity/chained_token_credential.go b/sdk/azidentity/chained_token_credential.go index 7e3c8aefbaaf..fe833d637674 100644 --- a/sdk/azidentity/chained_token_credential.go +++ b/sdk/azidentity/chained_token_credential.go @@ -14,12 +14,20 @@ import ( // ChainedTokenCredentialOptions contains optional parameters for ChainedTokenCredential. type ChainedTokenCredentialOptions struct { - // placeholder for future options + // RetrySources configures how the credential uses its sources. + // When true, the credential will always request a token from each source in turn, + // stopping when one provides a token. When false, the credential requests a token + // only from the source that previously retrieved a token--it never again tries the sources which failed. + RetrySources bool } // ChainedTokenCredential is a chain of credentials that enables fallback behavior when a credential can't authenticate. +// By default, this credential will assume that the first successful credential should be the only credential used on future requests. +// If the `RetrySources` option is set to true, it will always try to get a token using all of the originally provided credentials. type ChainedTokenCredential struct { - sources []azcore.TokenCredential + sources []azcore.TokenCredential + successfulCredential azcore.TokenCredential + retrySources bool } // NewChainedTokenCredential creates a ChainedTokenCredential. @@ -36,7 +44,10 @@ func NewChainedTokenCredential(sources []azcore.TokenCredential, options *Chaine } cp := make([]azcore.TokenCredential, len(sources)) copy(cp, sources) - return &ChainedTokenCredential{sources: cp}, nil + if options == nil { + options = &ChainedTokenCredentialOptions{} + } + return &ChainedTokenCredential{sources: cp, retrySources: options.RetrySources}, nil } // GetToken calls GetToken on the chained credentials in turn, stopping when one returns a token. This method is called automatically by Azure SDK clients. @@ -44,6 +55,10 @@ func NewChainedTokenCredential(sources []azcore.TokenCredential, options *Chaine // opts: Options for the token request, in particular the desired scope of the access token. func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (token *azcore.AccessToken, err error) { var errList []CredentialUnavailableError + + if c.successfulCredential != nil && !c.retrySources { + return c.successfulCredential.GetToken(ctx, opts) + } for _, cred := range c.sources { token, err = cred.GetToken(ctx, opts) var credErr CredentialUnavailableError @@ -59,9 +74,11 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token return nil, err } else { logGetTokenSuccess(c, opts) + c.successfulCredential = cred return token, nil } } + // if we reach this point it means that all of the credentials in the chain returned CredentialUnavailableError credErr := newCredentialUnavailableError("Chained Token Credential", createChainedErrorMessage(errList)) // skip adding the stack trace here as it was already logged by other calls to GetToken() diff --git a/sdk/azidentity/chained_token_credential_test.go b/sdk/azidentity/chained_token_credential_test.go index c0ad339139ed..1863df3d2189 100644 --- a/sdk/azidentity/chained_token_credential_test.go +++ b/sdk/azidentity/chained_token_credential_test.go @@ -113,30 +113,103 @@ func TestChainedTokenCredential_GetTokenFail(t *testing.T) { } } -type unavailableCredential struct{} +// TestCredential response +type testCredentialResponse struct { + token *azcore.AccessToken + err error +} -func (*unavailableCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (token *azcore.AccessToken, err error) { - return nil, newCredentialUnavailableError("unavailableCredential", "is unavailable") +// Credential used for testing +type TestCredential struct { + getTokenCalls int + responses []testCredentialResponse } -func TestChainedTokenCredential_GetTokenWithUnavailableCredentialInChain(t *testing.T) { - secCred, err := NewClientSecretCredential(fakeTenantID, fakeClientID, secret, nil) +func (c *TestCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (token *azcore.AccessToken, err error) { + index := c.getTokenCalls + c.getTokenCalls += 1 + response := c.responses[index] + return response.token, response.err +} + +func testGoodGetTokenResponse(t *testing.T, token *azcore.AccessToken, err error) { if err != nil { - t.Fatalf("Unable to create credential. Received: %v", err) + t.Fatalf("Received an error when attempting to get a token but expected none") + } + if token.Token != tokenValue { + t.Fatalf("Received an incorrect access token") } - secCred.client = fakeConfidentialClient{ar: confidential.AuthResult{AccessToken: tokenValue, ExpiresOn: time.Now().Add(time.Hour)}} - cred, err := NewChainedTokenCredential([]azcore.TokenCredential{&unavailableCredential{}, secCred}, nil) + if token.ExpiresOn.IsZero() { + t.Fatalf("Received an incorrect time in the response") + } +} + +func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *testing.T) { + failedCredential := &TestCredential{responses: []testCredentialResponse{ + {err: newCredentialUnavailableError("MockCredential", "Mocking a credential unavailable error")}, + {err: newCredentialUnavailableError("MockCredential", "Mocking a credential unavailable error")}, + }} + successfulCredential := &TestCredential{responses: []testCredentialResponse{ + {token: &azcore.AccessToken{Token: tokenValue, ExpiresOn: time.Now()}}, + {token: &azcore.AccessToken{Token: tokenValue, ExpiresOn: time.Now()}}, + }} + + cred, err := NewChainedTokenCredential([]azcore.TokenCredential{failedCredential, successfulCredential}, nil) if err != nil { t.Fatalf("unexpected error: %v", err) } - tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}}) + + getTokenOptions := policy.TokenRequestOptions{Scopes: []string{liveTestScope}} + + tk, err := cred.GetToken(context.Background(), getTokenOptions) + testGoodGetTokenResponse(t, tk, err) + if failedCredential.getTokenCalls != 1 { + t.Fatal("The failed credential getToken should have been called once") + } + if successfulCredential.getTokenCalls != 1 { + t.Fatalf("The successful credential getToken should have been called once") + } + tk2, err2 := cred.GetToken(context.Background(), getTokenOptions) + testGoodGetTokenResponse(t, tk2, err2) + if failedCredential.getTokenCalls != 1 { + t.Fatalf("The failed credential getToken should not have been called again") + } + if successfulCredential.getTokenCalls != 2 { + t.Fatalf("The successful credential getToken should have been called twice") + } +} + +func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredentialWithRetrySources(t *testing.T) { + failedCredential := &TestCredential{responses: []testCredentialResponse{ + {err: newCredentialUnavailableError("MockCredential", "Mocking a credential unavailable error")}, + {err: newCredentialUnavailableError("MockCredential", "Mocking a credential unavailable error")}, + }} + successfulCredential := &TestCredential{responses: []testCredentialResponse{ + {token: &azcore.AccessToken{Token: tokenValue, ExpiresOn: time.Now()}}, + {token: &azcore.AccessToken{Token: tokenValue, ExpiresOn: time.Now()}}, + }} + + cred, err := NewChainedTokenCredential([]azcore.TokenCredential{failedCredential, successfulCredential}, &ChainedTokenCredentialOptions{RetrySources: true}) if err != nil { - t.Fatalf("Received an error when attempting to get a token but expected none") + t.Fatalf("unexpected error: %v", err) } - if tk.Token != tokenValue { - t.Fatalf("Received an incorrect access token") + + getTokenOptions := policy.TokenRequestOptions{Scopes: []string{liveTestScope}} + + tk, err := cred.GetToken(context.Background(), getTokenOptions) + testGoodGetTokenResponse(t, tk, err) + if failedCredential.getTokenCalls != 1 { + t.Fatalf("The failed credential getToken should have been called once") } - if tk.ExpiresOn.IsZero() { - t.Fatalf("Received an incorrect time in the response") + if successfulCredential.getTokenCalls != 1 { + t.Fatalf("The successful credential getToken should have been called once") + } + tk2, err2 := cred.GetToken(context.Background(), getTokenOptions) + testGoodGetTokenResponse(t, tk2, err2) + if failedCredential.getTokenCalls != 2 { + t.Fatalf("The failed credential getToken should have been called twice") + } + if successfulCredential.getTokenCalls != 2 { + t.Fatalf("The successful credential getToken should have been called twice") } }