Skip to content

Commit

Permalink
[azidentity] Making ChainedTokenCredential re-use the first successfu…
Browse files Browse the repository at this point in the history
…l credential (#16392)

This PR makes it so instances of `ChainedTokenCredential` will now re-use the first successful credential on `GetToken` calls.

Fixed #16268
  • Loading branch information
sadasant authored Jan 11, 2022
1 parent 492a8f7 commit 774c062
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 17 deletions.
6 changes: 6 additions & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@

### Breaking Changes

* Instances of `ChainedTokenCredential` will now skip looping through the list of source credentials and re-use the first successful credential on subsequent calls to `GetToken`.
* If `ChainedTokenCredentialOptions.RetrySources` is true, `ChainedTokenCredential` will continue to try all of the originally provided credentials each time the `GetToken` method is called.
* `ChainedTokenCredential.successfulCredential` will contain a reference to the last successful credential.
* `DefaultAzureCredenial` will also re-use the first successful credential on subsequent calls to `GetToken`.
* `DefaultAzureCredential.chain.successfulCredential` will also contain a reference to the last successful credential.

### Bugs Fixed

### Other Changes
Expand Down
23 changes: 20 additions & 3 deletions sdk/azidentity/chained_token_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,20 @@ import (

// ChainedTokenCredentialOptions contains optional parameters for ChainedTokenCredential.
type ChainedTokenCredentialOptions struct {
// placeholder for future options
// RetrySources configures how the credential uses its sources.
// When true, the credential will always request a token from each source in turn,
// stopping when one provides a token. When false, the credential requests a token
// only from the source that previously retrieved a token--it never again tries the sources which failed.
RetrySources bool
}

// ChainedTokenCredential is a chain of credentials that enables fallback behavior when a credential can't authenticate.
// By default, this credential will assume that the first successful credential should be the only credential used on future requests.
// If the `RetrySources` option is set to true, it will always try to get a token using all of the originally provided credentials.
type ChainedTokenCredential struct {
sources []azcore.TokenCredential
sources []azcore.TokenCredential
successfulCredential azcore.TokenCredential
retrySources bool
}

// NewChainedTokenCredential creates a ChainedTokenCredential.
Expand All @@ -36,14 +44,21 @@ func NewChainedTokenCredential(sources []azcore.TokenCredential, options *Chaine
}
cp := make([]azcore.TokenCredential, len(sources))
copy(cp, sources)
return &ChainedTokenCredential{sources: cp}, nil
if options == nil {
options = &ChainedTokenCredentialOptions{}
}
return &ChainedTokenCredential{sources: cp, retrySources: options.RetrySources}, nil
}

// GetToken calls GetToken on the chained credentials in turn, stopping when one returns a token. This method is called automatically by Azure SDK clients.
// ctx: Context controlling the request lifetime.
// opts: Options for the token request, in particular the desired scope of the access token.
func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (token *azcore.AccessToken, err error) {
var errList []CredentialUnavailableError

if c.successfulCredential != nil && !c.retrySources {
return c.successfulCredential.GetToken(ctx, opts)
}
for _, cred := range c.sources {
token, err = cred.GetToken(ctx, opts)
var credErr CredentialUnavailableError
Expand All @@ -59,9 +74,11 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token
return nil, err
} else {
logGetTokenSuccess(c, opts)
c.successfulCredential = cred
return token, nil
}
}

// if we reach this point it means that all of the credentials in the chain returned CredentialUnavailableError
credErr := newCredentialUnavailableError("Chained Token Credential", createChainedErrorMessage(errList))
// skip adding the stack trace here as it was already logged by other calls to GetToken()
Expand Down
101 changes: 87 additions & 14 deletions sdk/azidentity/chained_token_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,30 +113,103 @@ func TestChainedTokenCredential_GetTokenFail(t *testing.T) {
}
}

type unavailableCredential struct{}
// TestCredential response
type testCredentialResponse struct {
token *azcore.AccessToken
err error
}

func (*unavailableCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (token *azcore.AccessToken, err error) {
return nil, newCredentialUnavailableError("unavailableCredential", "is unavailable")
// Credential used for testing
type TestCredential struct {
getTokenCalls int
responses []testCredentialResponse
}

func TestChainedTokenCredential_GetTokenWithUnavailableCredentialInChain(t *testing.T) {
secCred, err := NewClientSecretCredential(fakeTenantID, fakeClientID, secret, nil)
func (c *TestCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (token *azcore.AccessToken, err error) {
index := c.getTokenCalls
c.getTokenCalls += 1
response := c.responses[index]
return response.token, response.err
}

func testGoodGetTokenResponse(t *testing.T, token *azcore.AccessToken, err error) {
if err != nil {
t.Fatalf("Unable to create credential. Received: %v", err)
t.Fatalf("Received an error when attempting to get a token but expected none")
}
if token.Token != tokenValue {
t.Fatalf("Received an incorrect access token")
}
secCred.client = fakeConfidentialClient{ar: confidential.AuthResult{AccessToken: tokenValue, ExpiresOn: time.Now().Add(time.Hour)}}
cred, err := NewChainedTokenCredential([]azcore.TokenCredential{&unavailableCredential{}, secCred}, nil)
if token.ExpiresOn.IsZero() {
t.Fatalf("Received an incorrect time in the response")
}
}

func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *testing.T) {
failedCredential := &TestCredential{responses: []testCredentialResponse{
{err: newCredentialUnavailableError("MockCredential", "Mocking a credential unavailable error")},
{err: newCredentialUnavailableError("MockCredential", "Mocking a credential unavailable error")},
}}
successfulCredential := &TestCredential{responses: []testCredentialResponse{
{token: &azcore.AccessToken{Token: tokenValue, ExpiresOn: time.Now()}},
{token: &azcore.AccessToken{Token: tokenValue, ExpiresOn: time.Now()}},
}}

cred, err := NewChainedTokenCredential([]azcore.TokenCredential{failedCredential, successfulCredential}, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}})

getTokenOptions := policy.TokenRequestOptions{Scopes: []string{liveTestScope}}

tk, err := cred.GetToken(context.Background(), getTokenOptions)
testGoodGetTokenResponse(t, tk, err)
if failedCredential.getTokenCalls != 1 {
t.Fatal("The failed credential getToken should have been called once")
}
if successfulCredential.getTokenCalls != 1 {
t.Fatalf("The successful credential getToken should have been called once")
}
tk2, err2 := cred.GetToken(context.Background(), getTokenOptions)
testGoodGetTokenResponse(t, tk2, err2)
if failedCredential.getTokenCalls != 1 {
t.Fatalf("The failed credential getToken should not have been called again")
}
if successfulCredential.getTokenCalls != 2 {
t.Fatalf("The successful credential getToken should have been called twice")
}
}

func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredentialWithRetrySources(t *testing.T) {
failedCredential := &TestCredential{responses: []testCredentialResponse{
{err: newCredentialUnavailableError("MockCredential", "Mocking a credential unavailable error")},
{err: newCredentialUnavailableError("MockCredential", "Mocking a credential unavailable error")},
}}
successfulCredential := &TestCredential{responses: []testCredentialResponse{
{token: &azcore.AccessToken{Token: tokenValue, ExpiresOn: time.Now()}},
{token: &azcore.AccessToken{Token: tokenValue, ExpiresOn: time.Now()}},
}}

cred, err := NewChainedTokenCredential([]azcore.TokenCredential{failedCredential, successfulCredential}, &ChainedTokenCredentialOptions{RetrySources: true})
if err != nil {
t.Fatalf("Received an error when attempting to get a token but expected none")
t.Fatalf("unexpected error: %v", err)
}
if tk.Token != tokenValue {
t.Fatalf("Received an incorrect access token")

getTokenOptions := policy.TokenRequestOptions{Scopes: []string{liveTestScope}}

tk, err := cred.GetToken(context.Background(), getTokenOptions)
testGoodGetTokenResponse(t, tk, err)
if failedCredential.getTokenCalls != 1 {
t.Fatalf("The failed credential getToken should have been called once")
}
if tk.ExpiresOn.IsZero() {
t.Fatalf("Received an incorrect time in the response")
if successfulCredential.getTokenCalls != 1 {
t.Fatalf("The successful credential getToken should have been called once")
}
tk2, err2 := cred.GetToken(context.Background(), getTokenOptions)
testGoodGetTokenResponse(t, tk2, err2)
if failedCredential.getTokenCalls != 2 {
t.Fatalf("The failed credential getToken should have been called twice")
}
if successfulCredential.getTokenCalls != 2 {
t.Fatalf("The successful credential getToken should have been called twice")
}
}

0 comments on commit 774c062

Please sign in to comment.