Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[azidentity] Making ChainedTokenCredential re-use the first successful credential #16392

Merged
merged 24 commits into from
Jan 11, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
709222e
wip
sadasant Dec 2, 2021
fa5e663
tests passing
sadasant Dec 2, 2021
0438cc4
last -> first
sadasant Dec 2, 2021
3f58cbc
no need for an explicit pointer for successfulCredential
sadasant Dec 2, 2021
d5d1ea7
better reference to successfulCredential
sadasant Dec 2, 2021
3715c8a
removed unnecessary else
sadasant Dec 2, 2021
f55385d
moved formatError to be within the GetToken method
sadasant Dec 2, 2021
fa7d460
Better tests
sadasant Dec 2, 2021
cf1ce47
RetryAllSources and related documentation
sadasant Dec 2, 2021
a1bbdd1
updated the logic so that successfulCredential is always populated. I…
sadasant Dec 3, 2021
dab4067
Apply suggestions from code review
sadasant Dec 3, 2021
2183d96
Merge remote-tracking branch 'Azure/main' into azidentity/fix16268
sadasant Dec 9, 2021
e288440
RetryAllSources to RetrySources
sadasant Dec 9, 2021
cdd4bb6
Apply suggestions from code review
sadasant Dec 9, 2021
e2d1201
Feedback from Charles and adding RetrySources to the DefaultAzureCred…
sadasant Dec 9, 2021
8e51b2c
undoing changes to the DefaultAzureCredential
sadasant Dec 10, 2021
20197a9
simplified tests with new test credentials
sadasant Dec 10, 2021
02bb755
Merge remote-tracking branch 'Azure/main' into azidentity/fix16268
sadasant Dec 10, 2021
e14544a
removed the formatError function
sadasant Dec 10, 2021
a19c2f2
Avoiding testing that .successfulCredential is set
sadasant Dec 10, 2021
af22a81
Apply suggestions from code review
sadasant Dec 13, 2021
2080f96
removed some redundant tests, the unavailableCredential, unexposed th…
sadasant Dec 13, 2021
ae92c66
formatting
sadasant Dec 13, 2021
0cac965
Update sdk/azidentity/chained_token_credential_test.go
sadasant Jan 11, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sdk/azidentity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

### Breaking Changes

* Instances of `ChainedTokenCredential` will now skip looping through the list of source credentials and re-use the first successful credential on subsequent calls to `GetToken`.
sadasant marked this conversation as resolved.
Show resolved Hide resolved
* If the `RetryAllSources` option is set to true, it will not assume the first successful credential should be always used.
sadasant marked this conversation as resolved.
Show resolved Hide resolved

### Bugs Fixed

### Other Changes
Expand Down
43 changes: 33 additions & 10 deletions sdk/azidentity/chained_token_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,20 @@ import (

// ChainedTokenCredentialOptions contains optional parameters for ChainedTokenCredential.
type ChainedTokenCredentialOptions struct {
// placeholder for future options
// If true, it will not assume the first successful credential should be always used.
sadasant marked this conversation as resolved.
Show resolved Hide resolved
RetryAllSources bool
sadasant marked this conversation as resolved.
Show resolved Hide resolved
}

// ChainedTokenCredential is a chain of credentials that enables fallback behavior when a credential can't authenticate.
type ChainedTokenCredential struct {
sources []azcore.TokenCredential
sources []azcore.TokenCredential
successfulCredential azcore.TokenCredential
retryAllSources bool
}

// NewChainedTokenCredential creates a ChainedTokenCredential.
// By default, this credential will assume that the first successful credential should be the only credential used on future requests.
sadasant marked this conversation as resolved.
Show resolved Hide resolved
// If the `RetryAllSources` option is set to true, it will always try to get a token using all of the available credentials.
sadasant marked this conversation as resolved.
Show resolved Hide resolved
// sources: Credential instances to comprise the chain. GetToken() will invoke them in the given order.
// options: Optional configuration.
func NewChainedTokenCredential(sources []azcore.TokenCredential, options *ChainedTokenCredentialOptions) (*ChainedTokenCredential, error) {
Expand All @@ -36,32 +41,50 @@ func NewChainedTokenCredential(sources []azcore.TokenCredential, options *Chaine
}
cp := make([]azcore.TokenCredential, len(sources))
copy(cp, sources)
return &ChainedTokenCredential{sources: cp}, nil
credentialOptions := ChainedTokenCredentialOptions{}
if options != nil {
credentialOptions = *options
}
return &ChainedTokenCredential{sources: cp, retryAllSources: credentialOptions.RetryAllSources}, nil
}

// GetToken calls GetToken on the chained credentials in turn, stopping when one returns a token. This method is called automatically by Azure SDK clients.
// ctx: Context controlling the request lifetime.
// opts: Options for the token request, in particular the desired scope of the access token.
func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (token *azcore.AccessToken, err error) {
var errList []CredentialUnavailableError

formatError := func(err error) error {
sadasant marked this conversation as resolved.
Show resolved Hide resolved
var authFailed AuthenticationFailedError
if errors.As(err, &authFailed) {
err = fmt.Errorf("Authentication failed:\n%s\n%s"+createChainedErrorMessage(errList), err)
authErr := newAuthenticationFailedError(err, authFailed.RawResponse())
return authErr
}
return err
}

if c.successfulCredential != nil && !c.retryAllSources {
token, err = c.successfulCredential.GetToken(ctx, opts)
if err != nil {
sadasant marked this conversation as resolved.
Show resolved Hide resolved
return nil, formatError(err)
sadasant marked this conversation as resolved.
Show resolved Hide resolved
}
return token, nil
}
for _, cred := range c.sources {
token, err = cred.GetToken(ctx, opts)
var credErr CredentialUnavailableError
if errors.As(err, &credErr) {
errList = append(errList, credErr)
} else if err != nil {
var authFailed AuthenticationFailedError
if errors.As(err, &authFailed) {
err = fmt.Errorf("Authentication failed:\n%s\n%s"+createChainedErrorMessage(errList), err)
authErr := newAuthenticationFailedError(err, authFailed.RawResponse())
return nil, authErr
}
return nil, err
return nil, formatError(err)
} else {
logGetTokenSuccess(c, opts)
c.successfulCredential = cred
return token, nil
}
}

// if we reach this point it means that all of the credentials in the chain returned CredentialUnavailableError
credErr := newCredentialUnavailableError("Chained Token Credential", createChainedErrorMessage(errList))
// skip adding the stack trace here as it was already logged by other calls to GetToken()
Expand Down
214 changes: 214 additions & 0 deletions sdk/azidentity/chained_token_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,220 @@ func TestChainedTokenCredential_GetTokenWithUnavailableCredentialInChain(t *test
}
}

func TestChainedTokenCredential_ChecksThatSuccessfulCredentialIsSet(t *testing.T) {
sadasant marked this conversation as resolved.
Show resolved Hide resolved
err := initEnvironmentVarsForTest()
if err != nil {
t.Fatalf("Could not set environment variables for testing: %v", err)
}
srv, close := mock.NewTLSServer()
defer close()
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
options := ClientSecretCredentialOptions{}
options.AuthorityHost = AuthorityHost(srv.URL())
options.Transport = srv
secCred, err := NewClientSecretCredential(tenantID, clientID, secret, &options)
if err != nil {
t.Fatalf("Unable to create credential. Received: %v", err)
}
envCred, err := NewEnvironmentCredential(&EnvironmentCredentialOptions{
ClientOptions: azcore.ClientOptions{Transport: srv},
AuthorityHost: AuthorityHost(srv.URL()),
})
if err != nil {
t.Fatalf("Failed to create environment credential: %v", err)
}
cred, err := NewChainedTokenCredential([]azcore.TokenCredential{secCred, envCred}, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}})
if err != nil {
t.Fatalf("Received an error when attempting to get a token but expected none")
}
if tk.Token != tokenValue {
t.Fatalf("Received an incorrect access token")
}
if tk.ExpiresOn.IsZero() {
t.Fatalf("Received an incorrect time in the response")
}
if cred.successfulCredential == nil {
t.Fatalf("The successful credential was not assigned")
}
if cred.successfulCredential != secCred {
t.Fatalf("The successful credential should have been the secret credential")
}
}

/**
* Helps count the number of times a credential is called.
*/
type TestCountPolicy struct{ count int }

/**
* Helps count the number of times a credential is called.
*/
func (p *TestCountPolicy) Do(req *policy.Request) (*http.Response, error) {
p.count += 1
return req.Next()
}

func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *testing.T) {
err := initEnvironmentVarsForTest()
if err != nil {
t.Fatalf("Could not set environment variables for testing: %v", err)
}
srv, close := mock.NewTLSServer()
defer close()

secretCountPolicy := &TestCountPolicy{}
environmentCountPolicy := &TestCountPolicy{}

srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
options := ClientSecretCredentialOptions{}
options.AuthorityHost = AuthorityHost(srv.URL())
options.PerCallPolicies = []policy.Policy{secretCountPolicy}
options.Transport = srv
secCred, err := NewClientSecretCredential(tenantID, clientID, secret, &options)
sadasant marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
t.Fatalf("Unable to create credential. Received: %v", err)
}
envCred, err := NewEnvironmentCredential(&EnvironmentCredentialOptions{
ClientOptions: azcore.ClientOptions{
Transport: srv,
PerCallPolicies: []policy.Policy{environmentCountPolicy},
},
AuthorityHost: AuthorityHost(srv.URL()),
})
if err != nil {
t.Fatalf("Failed to create environment credential: %v", err)
}
cred, err := NewChainedTokenCredential([]azcore.TokenCredential{secCred, envCred}, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}})
if err != nil {
t.Fatalf("Received an error when attempting to get a token but expected none")
}
if tk.Token != tokenValue {
t.Fatalf("Received an incorrect access token")
}
if tk.ExpiresOn.IsZero() {
t.Fatalf("Received an incorrect time in the response")
}
if cred.successfulCredential == nil {
sadasant marked this conversation as resolved.
Show resolved Hide resolved
t.Fatalf("The successful credential was not assigned")
}
if cred.successfulCredential != secCred {
t.Fatalf("The successful credential should have been the secret credential")
}
if secretCountPolicy.count != 1 {
t.Fatalf("The secret credential policies should have been triggered once")
}
if environmentCountPolicy.count != 0 {
t.Fatalf("The environment credential policies should not have been triggered")
}
tk2, err2 := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}})
if err2 != nil {
t.Fatalf("Received an error when attempting to get a token but expected none. Error: %v", err2)
}
if tk2.Token != tokenValue {
t.Fatalf("Received an incorrect access token")
}
if tk2.ExpiresOn.IsZero() {
t.Fatalf("Received an incorrect time in the response")
}
if secretCountPolicy.count != 2 {
t.Fatalf("The secret credential policies should have been triggered twice")
}
if environmentCountPolicy.count != 0 {
t.Fatalf("The environment credential policies should not have been triggered")
}
}

// A credential that always throws a CredentialUnavailableError
type UnavailableCredential struct {
callCount int
}

func NewUnavailableCredential() (*UnavailableCredential, error) {
return &UnavailableCredential{callCount: 0}, nil
}
func (c *UnavailableCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (token *azcore.AccessToken, err error) {
c.callCount += 1
return nil, newCredentialUnavailableError("UnavailableCredential", "Expected CredentialUnavailableError")
}

func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredentialWithRetryAllSources(t *testing.T) {
err := initEnvironmentVarsForTest()
if err != nil {
t.Fatalf("Could not set environment variables for testing: %v", err)
}
srv, close := mock.NewTLSServer()
defer close()

secretCountPolicy := &TestCountPolicy{}

srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))

options := ClientSecretCredentialOptions{}
options.AuthorityHost = AuthorityHost(srv.URL())
options.PerCallPolicies = []policy.Policy{secretCountPolicy}
options.Transport = srv

unavailableCred, _ := NewUnavailableCredential()
secCred, err := NewClientSecretCredential(tenantID, clientID, secret, &options)
if err != nil {
t.Fatalf("Unable to create credential. Received: %v", err)
}

// Backwards order: envCred first, secCred later, to check that envCred is always called when RetryAllSources is set to true.
cred, err := NewChainedTokenCredential([]azcore.TokenCredential{unavailableCred, secCred}, &ChainedTokenCredentialOptions{RetryAllSources: true})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}})
if err != nil {
t.Fatalf("Received an error when attempting to get a token but expected none. Error: %v", err)
}
if tk.Token != tokenValue {
t.Fatalf("Received an incorrect access token")
}
if tk.ExpiresOn.IsZero() {
t.Fatalf("Received an incorrect time in the response")
}
if cred.successfulCredential == nil {
t.Fatalf("The successful credential was not assigned")
}
if cred.successfulCredential != secCred {
t.Fatalf("The successful credential should have been the secret credential")
}
if secretCountPolicy.count != 1 {
t.Fatalf("The secret credential policies should have been triggered once")
}
if unavailableCred.callCount != 1 {
t.Fatalf("The environment credential policies should have been triggered once")
}
sadasant marked this conversation as resolved.
Show resolved Hide resolved
tk2, err2 := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}})
if err2 != nil {
t.Fatalf("Received an error when attempting to get a token but expected none. Error: %v", err2)
}
sadasant marked this conversation as resolved.
Show resolved Hide resolved
if tk2.Token != tokenValue {
t.Fatalf("Received an incorrect access token")
}
if tk2.ExpiresOn.IsZero() {
t.Fatalf("Received an incorrect time in the response")
}
if secretCountPolicy.count != 2 {
t.Fatalf("The secret credential policies should have been triggered twice")
}
if unavailableCred.callCount != 2 {
t.Fatalf("The environment credential policies should have been triggered twice")
}
}

func TestBearerPolicy_ChainedTokenCredential(t *testing.T) {
err := initEnvironmentVarsForTest()
if err != nil {
Expand Down