From 709222ea4726ad30dd8622e8c29f0efae08067bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Rodr=C3=ADguez?= Date: Thu, 2 Dec 2021 18:50:26 +0000 Subject: [PATCH 01/22] wip --- sdk/azidentity/chained_token_credential.go | 48 ++++++---- .../chained_token_credential_test.go | 95 +++++++++++++++++++ 2 files changed, 127 insertions(+), 16 deletions(-) diff --git a/sdk/azidentity/chained_token_credential.go b/sdk/azidentity/chained_token_credential.go index 7e3c8aefbaaf..59a2a557a543 100644 --- a/sdk/azidentity/chained_token_credential.go +++ b/sdk/azidentity/chained_token_credential.go @@ -19,7 +19,8 @@ type ChainedTokenCredentialOptions struct { // ChainedTokenCredential is a chain of credentials that enables fallback behavior when a credential can't authenticate. type ChainedTokenCredential struct { - sources []azcore.TokenCredential + sources []azcore.TokenCredential + successfulCredential *azcore.TokenCredential } // NewChainedTokenCredential creates a ChainedTokenCredential. @@ -44,24 +45,29 @@ func NewChainedTokenCredential(sources []azcore.TokenCredential, options *Chaine // opts: Options for the token request, in particular the desired scope of the access token. func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (token *azcore.AccessToken, err error) { var errList []CredentialUnavailableError - for _, cred := range c.sources { - token, err = cred.GetToken(ctx, opts) - var credErr CredentialUnavailableError - if errors.As(err, &credErr) { - errList = append(errList, credErr) - } else if err != nil { - var authFailed AuthenticationFailedError - if errors.As(err, &authFailed) { - err = fmt.Errorf("Authentication failed:\n%s\n%s"+createChainedErrorMessage(errList), err) - authErr := newAuthenticationFailedError(err, authFailed.RawResponse()) - return nil, authErr + + if c.successfulCredential != nil { + successfulCredential := *c.successfulCredential + token, err = successfulCredential.GetToken(ctx, opts) + if err != nil { + return nil, formatError(err, errList) + } + } else { + for _, cred := range c.sources { + token, err = cred.GetToken(ctx, opts) + var credErr CredentialUnavailableError + if errors.As(err, &credErr) { + errList = append(errList, credErr) + } else if err != nil { + return nil, formatError(err, errList) + } else { + logGetTokenSuccess(c, opts) + c.successfulCredential = &cred + return token, nil } - return nil, err - } else { - logGetTokenSuccess(c, opts) - return token, nil } } + // if we reach this point it means that all of the credentials in the chain returned CredentialUnavailableError credErr := newCredentialUnavailableError("Chained Token Credential", createChainedErrorMessage(errList)) // skip adding the stack trace here as it was already logged by other calls to GetToken() @@ -78,4 +84,14 @@ func createChainedErrorMessage(errList []CredentialUnavailableError) string { return msg } +func formatError(err error, errList []CredentialUnavailableError) error { + var authFailed AuthenticationFailedError + if errors.As(err, &authFailed) { + err = fmt.Errorf("Authentication failed:\n%s\n%s"+createChainedErrorMessage(errList), err) + authErr := newAuthenticationFailedError(err, authFailed.RawResponse()) + return authErr + } + return err +} + var _ azcore.TokenCredential = (*ChainedTokenCredential)(nil) diff --git a/sdk/azidentity/chained_token_credential_test.go b/sdk/azidentity/chained_token_credential_test.go index 4d32c0b05bc1..3fc5a4d3fe0d 100644 --- a/sdk/azidentity/chained_token_credential_test.go +++ b/sdk/azidentity/chained_token_credential_test.go @@ -153,6 +153,101 @@ func TestChainedTokenCredential_GetTokenWithUnavailableCredentialInChain(t *test } } +func TestChainedTokenCredential_ChecksThatSuccessfulCredentialIsSet(t *testing.T) { + err := initEnvironmentVarsForTest() + if err != nil { + t.Fatalf("Could not set environment variables for testing: %v", err) + } + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + options := ClientSecretCredentialOptions{} + options.AuthorityHost = AuthorityHost(srv.URL()) + options.Transport = srv + secCred, err := NewClientSecretCredential(tenantID, clientID, secret, &options) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + envCred, err := NewEnvironmentCredential(&EnvironmentCredentialOptions{ + ClientOptions: azcore.ClientOptions{Transport: srv}, + AuthorityHost: AuthorityHost(srv.URL()), + }) + if err != nil { + t.Fatalf("Failed to create environment credential: %v", err) + } + cred, err := NewChainedTokenCredential([]azcore.TokenCredential{secCred, envCred}, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}}) + if err != nil { + t.Fatalf("Received an error when attempting to get a token but expected none") + } + if tk.Token != tokenValue { + t.Fatalf("Received an incorrect access token") + } + if tk.ExpiresOn.IsZero() { + t.Fatalf("Received an incorrect time in the response") + } + if cred.successfulCredential == nil { + t.Fatalf("The successful credential pointer was not assigned") + } +} + +func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *testing.T) { + fatalIf := func(err error, message string) { + if err != nil { + t.Fatalf("%s. Error: %v", message, err) + } + } + + err := initEnvironmentVarsForTest() + fatalIf(err, "Could not set environment variables for testing") + + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + options := ClientSecretCredentialOptions{} + options.AuthorityHost = AuthorityHost(srv.URL()) + options.Transport = srv + + secCred, err := NewClientSecretCredential(tenantID, clientID, secret, &options) + fatalIf(err, "Unable to create credential") + + envCred, err := NewEnvironmentCredential(&EnvironmentCredentialOptions{ + ClientOptions: azcore.ClientOptions{Transport: srv}, + AuthorityHost: AuthorityHost(srv.URL()), + }) + fatalIf(err, "Failed to create environment credential") + + cred, err := NewChainedTokenCredential([]azcore.TokenCredential{secCred, envCred}, nil) + fatalIf(err, "Unexpected error") + + tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}}) + fatalIf(err, "Received an error when attempting to get a token but expected none") + + if tk.Token != tokenValue { + t.Fatalf("Received an incorrect access token") + } + if tk.ExpiresOn.IsZero() { + t.Fatalf("Received an incorrect time in the response") + } + if cred.successfulCredential == nil { + t.Fatalf("The successful credential pointer was not assigned") + } + + tk2, err2 := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}}) + fatalIf(err2, "Received an error when attempting to get a token but expected none") + + if tk2.Token != tokenValue { + t.Fatalf("Received an incorrect access token") + } + if tk2.ExpiresOn.IsZero() { + t.Fatalf("Received an incorrect time in the response") + } +} + func TestBearerPolicy_ChainedTokenCredential(t *testing.T) { err := initEnvironmentVarsForTest() if err != nil { From fa5e6631b63133d8c1d391e91919f946bae0d5e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Rodr=C3=ADguez?= Date: Thu, 2 Dec 2021 20:15:05 +0000 Subject: [PATCH 02/22] tests passing --- sdk/azidentity/CHANGELOG.md | 2 + sdk/azidentity/chained_token_credential.go | 1 + .../chained_token_credential_test.go | 38 +++++++++---------- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index ce5be9f15710..05da98c6a655 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -6,6 +6,8 @@ ### Breaking Changes +- Instances of `ChainedTokenCredential` will now skip looping through the list of source credentials and re-use the last successful credential on subsequent calls to `GetToken`. + ### Bugs Fixed ### Other Changes diff --git a/sdk/azidentity/chained_token_credential.go b/sdk/azidentity/chained_token_credential.go index 59a2a557a543..eb3f2ac96c70 100644 --- a/sdk/azidentity/chained_token_credential.go +++ b/sdk/azidentity/chained_token_credential.go @@ -52,6 +52,7 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token if err != nil { return nil, formatError(err, errList) } + return token, nil } else { for _, cred := range c.sources { token, err = cred.GetToken(ctx, opts) diff --git a/sdk/azidentity/chained_token_credential_test.go b/sdk/azidentity/chained_token_credential_test.go index 3fc5a4d3fe0d..457f93829d38 100644 --- a/sdk/azidentity/chained_token_credential_test.go +++ b/sdk/azidentity/chained_token_credential_test.go @@ -195,15 +195,10 @@ func TestChainedTokenCredential_ChecksThatSuccessfulCredentialIsSet(t *testing.T } func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *testing.T) { - fatalIf := func(err error, message string) { - if err != nil { - t.Fatalf("%s. Error: %v", message, err) - } - } - err := initEnvironmentVarsForTest() - fatalIf(err, "Could not set environment variables for testing") - + if err != nil { + t.Fatalf("Could not set environment variables for testing: %v", err) + } srv, close := mock.NewTLSServer() defer close() srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) @@ -211,22 +206,25 @@ func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *test options := ClientSecretCredentialOptions{} options.AuthorityHost = AuthorityHost(srv.URL()) options.Transport = srv - secCred, err := NewClientSecretCredential(tenantID, clientID, secret, &options) - fatalIf(err, "Unable to create credential") - + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } envCred, err := NewEnvironmentCredential(&EnvironmentCredentialOptions{ ClientOptions: azcore.ClientOptions{Transport: srv}, AuthorityHost: AuthorityHost(srv.URL()), }) - fatalIf(err, "Failed to create environment credential") - + if err != nil { + t.Fatalf("Failed to create environment credential: %v", err) + } cred, err := NewChainedTokenCredential([]azcore.TokenCredential{secCred, envCred}, nil) - fatalIf(err, "Unexpected error") - + if err != nil { + t.Fatalf("unexpected error: %v", err) + } tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}}) - fatalIf(err, "Received an error when attempting to get a token but expected none") - + if err != nil { + t.Fatalf("Received an error when attempting to get a token but expected none") + } if tk.Token != tokenValue { t.Fatalf("Received an incorrect access token") } @@ -236,10 +234,10 @@ func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *test if cred.successfulCredential == nil { t.Fatalf("The successful credential pointer was not assigned") } - tk2, err2 := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}}) - fatalIf(err2, "Received an error when attempting to get a token but expected none") - + if err2 != nil { + t.Fatalf("Received an error when attempting to get a token but expected none. Error: %v", err2) + } if tk2.Token != tokenValue { t.Fatalf("Received an incorrect access token") } From 0438cc445b2f44bacebf33789f156c037c25e5d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Rodr=C3=ADguez?= Date: Thu, 2 Dec 2021 20:18:42 +0000 Subject: [PATCH 03/22] last -> first --- sdk/azidentity/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index 05da98c6a655..bce930f82568 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -6,7 +6,7 @@ ### Breaking Changes -- Instances of `ChainedTokenCredential` will now skip looping through the list of source credentials and re-use the last successful credential on subsequent calls to `GetToken`. +- Instances of `ChainedTokenCredential` will now skip looping through the list of source credentials and re-use the first successful credential on subsequent calls to `GetToken`. ### Bugs Fixed From 3f58cbc342caedd8887c5e5829c9d14ed5f83597 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Rodr=C3=ADguez?= Date: Thu, 2 Dec 2021 20:32:44 +0000 Subject: [PATCH 04/22] no need for an explicit pointer for successfulCredential --- sdk/azidentity/chained_token_credential.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sdk/azidentity/chained_token_credential.go b/sdk/azidentity/chained_token_credential.go index eb3f2ac96c70..520f68c6bf51 100644 --- a/sdk/azidentity/chained_token_credential.go +++ b/sdk/azidentity/chained_token_credential.go @@ -20,7 +20,7 @@ type ChainedTokenCredentialOptions struct { // ChainedTokenCredential is a chain of credentials that enables fallback behavior when a credential can't authenticate. type ChainedTokenCredential struct { sources []azcore.TokenCredential - successfulCredential *azcore.TokenCredential + successfulCredential azcore.TokenCredential } // NewChainedTokenCredential creates a ChainedTokenCredential. @@ -47,7 +47,7 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token var errList []CredentialUnavailableError if c.successfulCredential != nil { - successfulCredential := *c.successfulCredential + successfulCredential := c.successfulCredential token, err = successfulCredential.GetToken(ctx, opts) if err != nil { return nil, formatError(err, errList) @@ -63,7 +63,7 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token return nil, formatError(err, errList) } else { logGetTokenSuccess(c, opts) - c.successfulCredential = &cred + c.successfulCredential = cred return token, nil } } From d5d1ea7b72d05b3f8a462e9b5e96da0f83981b87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Rodr=C3=ADguez?= Date: Thu, 2 Dec 2021 20:34:07 +0000 Subject: [PATCH 05/22] better reference to successfulCredential --- sdk/azidentity/chained_token_credential.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sdk/azidentity/chained_token_credential.go b/sdk/azidentity/chained_token_credential.go index 520f68c6bf51..b979e7cb05be 100644 --- a/sdk/azidentity/chained_token_credential.go +++ b/sdk/azidentity/chained_token_credential.go @@ -47,8 +47,7 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token var errList []CredentialUnavailableError if c.successfulCredential != nil { - successfulCredential := c.successfulCredential - token, err = successfulCredential.GetToken(ctx, opts) + token, err = c.successfulCredential.GetToken(ctx, opts) if err != nil { return nil, formatError(err, errList) } From 3715c8a1746de653c571fe9bc8319195c8436f4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Rodr=C3=ADguez?= Date: Thu, 2 Dec 2021 20:36:19 +0000 Subject: [PATCH 06/22] removed unnecessary else --- sdk/azidentity/chained_token_credential.go | 25 +++++++++++----------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/sdk/azidentity/chained_token_credential.go b/sdk/azidentity/chained_token_credential.go index b979e7cb05be..bca5e8b20c86 100644 --- a/sdk/azidentity/chained_token_credential.go +++ b/sdk/azidentity/chained_token_credential.go @@ -52,19 +52,18 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token return nil, formatError(err, errList) } return token, nil - } else { - for _, cred := range c.sources { - token, err = cred.GetToken(ctx, opts) - var credErr CredentialUnavailableError - if errors.As(err, &credErr) { - errList = append(errList, credErr) - } else if err != nil { - return nil, formatError(err, errList) - } else { - logGetTokenSuccess(c, opts) - c.successfulCredential = cred - return token, nil - } + } + for _, cred := range c.sources { + token, err = cred.GetToken(ctx, opts) + var credErr CredentialUnavailableError + if errors.As(err, &credErr) { + errList = append(errList, credErr) + } else if err != nil { + return nil, formatError(err, errList) + } else { + logGetTokenSuccess(c, opts) + c.successfulCredential = cred + return token, nil } } From f55385d749046fd85a11407e2027abe93fdb460a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Rodr=C3=ADguez?= Date: Thu, 2 Dec 2021 20:53:19 +0000 Subject: [PATCH 07/22] moved formatError to be within the GetToken method --- sdk/azidentity/chained_token_credential.go | 24 +++++++++++----------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/sdk/azidentity/chained_token_credential.go b/sdk/azidentity/chained_token_credential.go index bca5e8b20c86..b520c3a4e964 100644 --- a/sdk/azidentity/chained_token_credential.go +++ b/sdk/azidentity/chained_token_credential.go @@ -46,10 +46,20 @@ func NewChainedTokenCredential(sources []azcore.TokenCredential, options *Chaine func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (token *azcore.AccessToken, err error) { var errList []CredentialUnavailableError + formatError := func(err error) error { + var authFailed AuthenticationFailedError + if errors.As(err, &authFailed) { + err = fmt.Errorf("Authentication failed:\n%s\n%s"+createChainedErrorMessage(errList), err) + authErr := newAuthenticationFailedError(err, authFailed.RawResponse()) + return authErr + } + return err + } + if c.successfulCredential != nil { token, err = c.successfulCredential.GetToken(ctx, opts) if err != nil { - return nil, formatError(err, errList) + return nil, formatError(err) } return token, nil } @@ -59,7 +69,7 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token if errors.As(err, &credErr) { errList = append(errList, credErr) } else if err != nil { - return nil, formatError(err, errList) + return nil, formatError(err) } else { logGetTokenSuccess(c, opts) c.successfulCredential = cred @@ -83,14 +93,4 @@ func createChainedErrorMessage(errList []CredentialUnavailableError) string { return msg } -func formatError(err error, errList []CredentialUnavailableError) error { - var authFailed AuthenticationFailedError - if errors.As(err, &authFailed) { - err = fmt.Errorf("Authentication failed:\n%s\n%s"+createChainedErrorMessage(errList), err) - authErr := newAuthenticationFailedError(err, authFailed.RawResponse()) - return authErr - } - return err -} - var _ azcore.TokenCredential = (*ChainedTokenCredential)(nil) From fa7d4609dbb2c85ed9f854b3e3b3acf0e2713369 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Rodr=C3=ADguez?= Date: Thu, 2 Dec 2021 21:10:53 +0000 Subject: [PATCH 08/22] Better tests Now I'm testing that the right credential is the one that appears as the successfulCredential, and that each credential is only called as many times as expected. --- .../chained_token_credential_test.go | 41 ++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/sdk/azidentity/chained_token_credential_test.go b/sdk/azidentity/chained_token_credential_test.go index 457f93829d38..54c4a5217832 100644 --- a/sdk/azidentity/chained_token_credential_test.go +++ b/sdk/azidentity/chained_token_credential_test.go @@ -192,6 +192,22 @@ func TestChainedTokenCredential_ChecksThatSuccessfulCredentialIsSet(t *testing.T if cred.successfulCredential == nil { t.Fatalf("The successful credential pointer was not assigned") } + if cred.successfulCredential != secCred { + t.Fatalf("The successful credential should have been the secret credential") + } +} + +/** + * Helps count the number of times a credential is called. + */ +type TestCountPolicy struct{ count int } + +/** + * Helps count the number of times a credential is called. + */ +func (p *TestCountPolicy) Do(req *policy.Request) (*http.Response, error) { + p.count += 1 + return req.Next() } func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *testing.T) { @@ -201,17 +217,25 @@ func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *test } srv, close := mock.NewTLSServer() defer close() + + secretCountPolicy := &TestCountPolicy{} + environmentCountPolicy := &TestCountPolicy{} + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) options := ClientSecretCredentialOptions{} options.AuthorityHost = AuthorityHost(srv.URL()) + options.PerCallPolicies = []policy.Policy{secretCountPolicy} options.Transport = srv secCred, err := NewClientSecretCredential(tenantID, clientID, secret, &options) if err != nil { t.Fatalf("Unable to create credential. Received: %v", err) } envCred, err := NewEnvironmentCredential(&EnvironmentCredentialOptions{ - ClientOptions: azcore.ClientOptions{Transport: srv}, + ClientOptions: azcore.ClientOptions{ + Transport: srv, + PerCallPolicies: []policy.Policy{environmentCountPolicy}, + }, AuthorityHost: AuthorityHost(srv.URL()), }) if err != nil { @@ -234,6 +258,15 @@ func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *test if cred.successfulCredential == nil { t.Fatalf("The successful credential pointer was not assigned") } + if cred.successfulCredential != secCred { + t.Fatalf("The successful credential should have been the secret credential") + } + if secretCountPolicy.count != 1 { + t.Fatalf("The secret credential policies should have been triggered once") + } + if environmentCountPolicy.count != 0 { + t.Fatalf("The environment credential policies should not have been triggered") + } tk2, err2 := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}}) if err2 != nil { t.Fatalf("Received an error when attempting to get a token but expected none. Error: %v", err2) @@ -244,6 +277,12 @@ func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *test if tk2.ExpiresOn.IsZero() { t.Fatalf("Received an incorrect time in the response") } + if secretCountPolicy.count != 2 { + t.Fatalf("The secret credential policies should have been triggered twice") + } + if environmentCountPolicy.count != 0 { + t.Fatalf("The environment credential policies should not have been triggered") + } } func TestBearerPolicy_ChainedTokenCredential(t *testing.T) { From cf1ce472c99ad79bdf25686cb8bf1b8621163a85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Rodr=C3=ADguez?= Date: Thu, 2 Dec 2021 22:23:27 +0000 Subject: [PATCH 09/22] RetryAllSources and related documentation --- sdk/azidentity/CHANGELOG.md | 3 +- sdk/azidentity/chained_token_credential.go | 18 ++++++-- .../chained_token_credential_test.go | 41 +++++++++++++++++++ 3 files changed, 58 insertions(+), 4 deletions(-) diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index bce930f82568..f874702734c5 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -6,7 +6,8 @@ ### Breaking Changes -- Instances of `ChainedTokenCredential` will now skip looping through the list of source credentials and re-use the first successful credential on subsequent calls to `GetToken`. +* Instances of `ChainedTokenCredential` will now skip looping through the list of source credentials and re-use the first successful credential on subsequent calls to `GetToken`. + * If the `RetryAllSources` option is set to true, it will not assume the first successful credential should be always used. ### Bugs Fixed diff --git a/sdk/azidentity/chained_token_credential.go b/sdk/azidentity/chained_token_credential.go index b520c3a4e964..ead07e0e59fa 100644 --- a/sdk/azidentity/chained_token_credential.go +++ b/sdk/azidentity/chained_token_credential.go @@ -14,16 +14,22 @@ import ( // ChainedTokenCredentialOptions contains optional parameters for ChainedTokenCredential. type ChainedTokenCredentialOptions struct { - // placeholder for future options + // If true, it will not assume the first successful credential should be always used. + RetryAllSources bool } // ChainedTokenCredential is a chain of credentials that enables fallback behavior when a credential can't authenticate. +// By default, this credential will assume that the first successful credential should be the only credential used on future requests. +// When `retryAllSources` is true, it will always try to get a token with every credential available on the `sources` array. type ChainedTokenCredential struct { sources []azcore.TokenCredential successfulCredential azcore.TokenCredential + retryAllSources bool } // NewChainedTokenCredential creates a ChainedTokenCredential. +// By default, this credential will assume that the first successful credential should be the only credential used on future requests. +// If the `RetryAllSources` option is set to true, it will always try to get a token using all of the available credentials. // sources: Credential instances to comprise the chain. GetToken() will invoke them in the given order. // options: Optional configuration. func NewChainedTokenCredential(sources []azcore.TokenCredential, options *ChainedTokenCredentialOptions) (*ChainedTokenCredential, error) { @@ -37,7 +43,11 @@ func NewChainedTokenCredential(sources []azcore.TokenCredential, options *Chaine } cp := make([]azcore.TokenCredential, len(sources)) copy(cp, sources) - return &ChainedTokenCredential{sources: cp}, nil + credentialOptions := ChainedTokenCredentialOptions{} + if options != nil { + credentialOptions = *options + } + return &ChainedTokenCredential{sources: cp, retryAllSources: credentialOptions.RetryAllSources}, nil } // GetToken calls GetToken on the chained credentials in turn, stopping when one returns a token. This method is called automatically by Azure SDK clients. @@ -72,7 +82,9 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token return nil, formatError(err) } else { logGetTokenSuccess(c, opts) - c.successfulCredential = cred + if !c.retryAllSources { + c.successfulCredential = cred + } return token, nil } } diff --git a/sdk/azidentity/chained_token_credential_test.go b/sdk/azidentity/chained_token_credential_test.go index 54c4a5217832..dbcd8388043d 100644 --- a/sdk/azidentity/chained_token_credential_test.go +++ b/sdk/azidentity/chained_token_credential_test.go @@ -197,6 +197,47 @@ func TestChainedTokenCredential_ChecksThatSuccessfulCredentialIsSet(t *testing.T } } +func TestChainedTokenCredential_ChecksThatSuccessfulCredentialIsNotSetIfRetryAllSourcesIsTrue(t *testing.T) { + err := initEnvironmentVarsForTest() + if err != nil { + t.Fatalf("Could not set environment variables for testing: %v", err) + } + srv, close := mock.NewTLSServer() + defer close() + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + options := ClientSecretCredentialOptions{} + options.AuthorityHost = AuthorityHost(srv.URL()) + options.Transport = srv + secCred, err := NewClientSecretCredential(tenantID, clientID, secret, &options) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + envCred, err := NewEnvironmentCredential(&EnvironmentCredentialOptions{ + ClientOptions: azcore.ClientOptions{Transport: srv}, + AuthorityHost: AuthorityHost(srv.URL()), + }) + if err != nil { + t.Fatalf("Failed to create environment credential: %v", err) + } + cred, err := NewChainedTokenCredential([]azcore.TokenCredential{secCred, envCred}, &ChainedTokenCredentialOptions{RetryAllSources: true}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}}) + if err != nil { + t.Fatalf("Received an error when attempting to get a token but expected none") + } + if tk.Token != tokenValue { + t.Fatalf("Received an incorrect access token") + } + if tk.ExpiresOn.IsZero() { + t.Fatalf("Received an incorrect time in the response") + } + if cred.successfulCredential != nil { + t.Fatalf("The successful credential pointer should not be assigned when RetryAllSources is provided as true") + } +} + /** * Helps count the number of times a credential is called. */ From a1bbdd1629549e699082252aefd84aa5376c3f83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Rodr=C3=ADguez?= Date: Fri, 3 Dec 2021 00:31:55 +0000 Subject: [PATCH 10/22] updated the logic so that successfulCredential is always populated. Improved the tests --- sdk/azidentity/chained_token_credential.go | 8 +- .../chained_token_credential_test.go | 111 ++++++++++++------ 2 files changed, 78 insertions(+), 41 deletions(-) diff --git a/sdk/azidentity/chained_token_credential.go b/sdk/azidentity/chained_token_credential.go index ead07e0e59fa..f35a4b7389eb 100644 --- a/sdk/azidentity/chained_token_credential.go +++ b/sdk/azidentity/chained_token_credential.go @@ -19,8 +19,6 @@ type ChainedTokenCredentialOptions struct { } // ChainedTokenCredential is a chain of credentials that enables fallback behavior when a credential can't authenticate. -// By default, this credential will assume that the first successful credential should be the only credential used on future requests. -// When `retryAllSources` is true, it will always try to get a token with every credential available on the `sources` array. type ChainedTokenCredential struct { sources []azcore.TokenCredential successfulCredential azcore.TokenCredential @@ -66,7 +64,7 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token return err } - if c.successfulCredential != nil { + if c.successfulCredential != nil && !c.retryAllSources { token, err = c.successfulCredential.GetToken(ctx, opts) if err != nil { return nil, formatError(err) @@ -82,9 +80,7 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token return nil, formatError(err) } else { logGetTokenSuccess(c, opts) - if !c.retryAllSources { - c.successfulCredential = cred - } + c.successfulCredential = cred return token, nil } } diff --git a/sdk/azidentity/chained_token_credential_test.go b/sdk/azidentity/chained_token_credential_test.go index dbcd8388043d..3d5b3125dc92 100644 --- a/sdk/azidentity/chained_token_credential_test.go +++ b/sdk/azidentity/chained_token_credential_test.go @@ -190,36 +190,58 @@ func TestChainedTokenCredential_ChecksThatSuccessfulCredentialIsSet(t *testing.T t.Fatalf("Received an incorrect time in the response") } if cred.successfulCredential == nil { - t.Fatalf("The successful credential pointer was not assigned") + t.Fatalf("The successful credential was not assigned") } if cred.successfulCredential != secCred { t.Fatalf("The successful credential should have been the secret credential") } } -func TestChainedTokenCredential_ChecksThatSuccessfulCredentialIsNotSetIfRetryAllSourcesIsTrue(t *testing.T) { +/** + * Helps count the number of times a credential is called. + */ +type TestCountPolicy struct{ count int } + +/** + * Helps count the number of times a credential is called. + */ +func (p *TestCountPolicy) Do(req *policy.Request) (*http.Response, error) { + p.count += 1 + return req.Next() +} + +func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *testing.T) { err := initEnvironmentVarsForTest() if err != nil { t.Fatalf("Could not set environment variables for testing: %v", err) } srv, close := mock.NewTLSServer() defer close() + + secretCountPolicy := &TestCountPolicy{} + environmentCountPolicy := &TestCountPolicy{} + + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) options := ClientSecretCredentialOptions{} options.AuthorityHost = AuthorityHost(srv.URL()) + options.PerCallPolicies = []policy.Policy{secretCountPolicy} options.Transport = srv secCred, err := NewClientSecretCredential(tenantID, clientID, secret, &options) if err != nil { t.Fatalf("Unable to create credential. Received: %v", err) } envCred, err := NewEnvironmentCredential(&EnvironmentCredentialOptions{ - ClientOptions: azcore.ClientOptions{Transport: srv}, + ClientOptions: azcore.ClientOptions{ + Transport: srv, + PerCallPolicies: []policy.Policy{environmentCountPolicy}, + }, AuthorityHost: AuthorityHost(srv.URL()), }) if err != nil { t.Fatalf("Failed to create environment credential: %v", err) } - cred, err := NewChainedTokenCredential([]azcore.TokenCredential{secCred, envCred}, &ChainedTokenCredentialOptions{RetryAllSources: true}) + cred, err := NewChainedTokenCredential([]azcore.TokenCredential{secCred, envCred}, nil) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -233,25 +255,50 @@ func TestChainedTokenCredential_ChecksThatSuccessfulCredentialIsNotSetIfRetryAll if tk.ExpiresOn.IsZero() { t.Fatalf("Received an incorrect time in the response") } - if cred.successfulCredential != nil { - t.Fatalf("The successful credential pointer should not be assigned when RetryAllSources is provided as true") + if cred.successfulCredential == nil { + t.Fatalf("The successful credential was not assigned") + } + if cred.successfulCredential != secCred { + t.Fatalf("The successful credential should have been the secret credential") + } + if secretCountPolicy.count != 1 { + t.Fatalf("The secret credential policies should have been triggered once") + } + if environmentCountPolicy.count != 0 { + t.Fatalf("The environment credential policies should not have been triggered") + } + tk2, err2 := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}}) + if err2 != nil { + t.Fatalf("Received an error when attempting to get a token but expected none. Error: %v", err2) + } + if tk2.Token != tokenValue { + t.Fatalf("Received an incorrect access token") + } + if tk2.ExpiresOn.IsZero() { + t.Fatalf("Received an incorrect time in the response") + } + if secretCountPolicy.count != 2 { + t.Fatalf("The secret credential policies should have been triggered twice") + } + if environmentCountPolicy.count != 0 { + t.Fatalf("The environment credential policies should not have been triggered") } } -/** - * Helps count the number of times a credential is called. - */ -type TestCountPolicy struct{ count int } +// A credential that always throws a CredentialUnavailableError +type UnavailableCredential struct { + callCount int +} -/** - * Helps count the number of times a credential is called. - */ -func (p *TestCountPolicy) Do(req *policy.Request) (*http.Response, error) { - p.count += 1 - return req.Next() +func NewUnavailableCredential() (*UnavailableCredential, error) { + return &UnavailableCredential{callCount: 0}, nil +} +func (c *UnavailableCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (token *azcore.AccessToken, err error) { + c.callCount += 1 + return nil, newCredentialUnavailableError("UnavailableCredential", "Expected CredentialUnavailableError") } -func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *testing.T) { +func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredentialWithRetryAllSources(t *testing.T) { err := initEnvironmentVarsForTest() if err != nil { t.Fatalf("Could not set environment variables for testing: %v", err) @@ -260,35 +307,29 @@ func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *test defer close() secretCountPolicy := &TestCountPolicy{} - environmentCountPolicy := &TestCountPolicy{} srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + options := ClientSecretCredentialOptions{} options.AuthorityHost = AuthorityHost(srv.URL()) options.PerCallPolicies = []policy.Policy{secretCountPolicy} options.Transport = srv + + unavailableCred, _ := NewUnavailableCredential() secCred, err := NewClientSecretCredential(tenantID, clientID, secret, &options) if err != nil { t.Fatalf("Unable to create credential. Received: %v", err) } - envCred, err := NewEnvironmentCredential(&EnvironmentCredentialOptions{ - ClientOptions: azcore.ClientOptions{ - Transport: srv, - PerCallPolicies: []policy.Policy{environmentCountPolicy}, - }, - AuthorityHost: AuthorityHost(srv.URL()), - }) - if err != nil { - t.Fatalf("Failed to create environment credential: %v", err) - } - cred, err := NewChainedTokenCredential([]azcore.TokenCredential{secCred, envCred}, nil) + + // Backwards order: envCred first, secCred later, to check that envCred is always called when RetryAllSources is set to true. + cred, err := NewChainedTokenCredential([]azcore.TokenCredential{unavailableCred, secCred}, &ChainedTokenCredentialOptions{RetryAllSources: true}) if err != nil { t.Fatalf("unexpected error: %v", err) } tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}}) if err != nil { - t.Fatalf("Received an error when attempting to get a token but expected none") + t.Fatalf("Received an error when attempting to get a token but expected none. Error: %v", err) } if tk.Token != tokenValue { t.Fatalf("Received an incorrect access token") @@ -297,7 +338,7 @@ func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *test t.Fatalf("Received an incorrect time in the response") } if cred.successfulCredential == nil { - t.Fatalf("The successful credential pointer was not assigned") + t.Fatalf("The successful credential was not assigned") } if cred.successfulCredential != secCred { t.Fatalf("The successful credential should have been the secret credential") @@ -305,8 +346,8 @@ func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *test if secretCountPolicy.count != 1 { t.Fatalf("The secret credential policies should have been triggered once") } - if environmentCountPolicy.count != 0 { - t.Fatalf("The environment credential policies should not have been triggered") + if unavailableCred.callCount != 1 { + t.Fatalf("The environment credential policies should have been triggered once") } tk2, err2 := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}}) if err2 != nil { @@ -321,8 +362,8 @@ func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *test if secretCountPolicy.count != 2 { t.Fatalf("The secret credential policies should have been triggered twice") } - if environmentCountPolicy.count != 0 { - t.Fatalf("The environment credential policies should not have been triggered") + if unavailableCred.callCount != 2 { + t.Fatalf("The environment credential policies should have been triggered twice") } } From dab40674a0cb8e372c2b65d39830d2129465bb23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Rodr=C3=ADguez?= Date: Thu, 2 Dec 2021 20:59:32 -0500 Subject: [PATCH 11/22] Apply suggestions from code review --- sdk/azidentity/CHANGELOG.md | 2 +- sdk/azidentity/chained_token_credential.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index f874702734c5..c43f225f20a0 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -7,7 +7,7 @@ ### Breaking Changes * Instances of `ChainedTokenCredential` will now skip looping through the list of source credentials and re-use the first successful credential on subsequent calls to `GetToken`. - * If the `RetryAllSources` option is set to true, it will not assume the first successful credential should be always used. + * If the `RetryAllSources` option is set to true, it will not assume the first successful credential should be always used. It will continue to try all of the originally provided credentials each time the `GetToken` method is called. ### Bugs Fixed diff --git a/sdk/azidentity/chained_token_credential.go b/sdk/azidentity/chained_token_credential.go index f35a4b7389eb..76fcfb885a4c 100644 --- a/sdk/azidentity/chained_token_credential.go +++ b/sdk/azidentity/chained_token_credential.go @@ -27,7 +27,7 @@ type ChainedTokenCredential struct { // NewChainedTokenCredential creates a ChainedTokenCredential. // By default, this credential will assume that the first successful credential should be the only credential used on future requests. -// If the `RetryAllSources` option is set to true, it will always try to get a token using all of the available credentials. +// If the `RetryAllSources` option is set to true, it will always try to get a token using all of the originally provided credentials. // sources: Credential instances to comprise the chain. GetToken() will invoke them in the given order. // options: Optional configuration. func NewChainedTokenCredential(sources []azcore.TokenCredential, options *ChainedTokenCredentialOptions) (*ChainedTokenCredential, error) { From e288440b96d230c971216758a8bc5881d9241141 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Rodr=C3=ADguez?= Date: Thu, 9 Dec 2021 20:36:22 +0000 Subject: [PATCH 12/22] RetryAllSources to RetrySources --- sdk/azidentity/CHANGELOG.md | 2 +- sdk/azidentity/chained_token_credential.go | 10 +++++----- sdk/azidentity/chained_token_credential_test.go | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index c43f225f20a0..1be2169bdfd2 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -7,7 +7,7 @@ ### Breaking Changes * Instances of `ChainedTokenCredential` will now skip looping through the list of source credentials and re-use the first successful credential on subsequent calls to `GetToken`. - * If the `RetryAllSources` option is set to true, it will not assume the first successful credential should be always used. It will continue to try all of the originally provided credentials each time the `GetToken` method is called. + * If the `RetrySources` option is set to true, it will not assume the first successful credential should be always used. It will continue to try all of the originally provided credentials each time the `GetToken` method is called. ### Bugs Fixed diff --git a/sdk/azidentity/chained_token_credential.go b/sdk/azidentity/chained_token_credential.go index 76fcfb885a4c..ba44ab366886 100644 --- a/sdk/azidentity/chained_token_credential.go +++ b/sdk/azidentity/chained_token_credential.go @@ -15,19 +15,19 @@ import ( // ChainedTokenCredentialOptions contains optional parameters for ChainedTokenCredential. type ChainedTokenCredentialOptions struct { // If true, it will not assume the first successful credential should be always used. - RetryAllSources bool + RetrySources bool } // ChainedTokenCredential is a chain of credentials that enables fallback behavior when a credential can't authenticate. type ChainedTokenCredential struct { sources []azcore.TokenCredential successfulCredential azcore.TokenCredential - retryAllSources bool + retrySources bool } // NewChainedTokenCredential creates a ChainedTokenCredential. // By default, this credential will assume that the first successful credential should be the only credential used on future requests. -// If the `RetryAllSources` option is set to true, it will always try to get a token using all of the originally provided credentials. +// If the `RetrySources` option is set to true, it will always try to get a token using all of the originally provided credentials. // sources: Credential instances to comprise the chain. GetToken() will invoke them in the given order. // options: Optional configuration. func NewChainedTokenCredential(sources []azcore.TokenCredential, options *ChainedTokenCredentialOptions) (*ChainedTokenCredential, error) { @@ -45,7 +45,7 @@ func NewChainedTokenCredential(sources []azcore.TokenCredential, options *Chaine if options != nil { credentialOptions = *options } - return &ChainedTokenCredential{sources: cp, retryAllSources: credentialOptions.RetryAllSources}, nil + return &ChainedTokenCredential{sources: cp, retrySources: credentialOptions.RetrySources}, nil } // GetToken calls GetToken on the chained credentials in turn, stopping when one returns a token. This method is called automatically by Azure SDK clients. @@ -64,7 +64,7 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token return err } - if c.successfulCredential != nil && !c.retryAllSources { + if c.successfulCredential != nil && !c.retrySources { token, err = c.successfulCredential.GetToken(ctx, opts) if err != nil { return nil, formatError(err) diff --git a/sdk/azidentity/chained_token_credential_test.go b/sdk/azidentity/chained_token_credential_test.go index 3d5b3125dc92..3f7d05d1859b 100644 --- a/sdk/azidentity/chained_token_credential_test.go +++ b/sdk/azidentity/chained_token_credential_test.go @@ -298,7 +298,7 @@ func (c *UnavailableCredential) GetToken(ctx context.Context, opts policy.TokenR return nil, newCredentialUnavailableError("UnavailableCredential", "Expected CredentialUnavailableError") } -func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredentialWithRetryAllSources(t *testing.T) { +func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredentialWithRetrySources(t *testing.T) { err := initEnvironmentVarsForTest() if err != nil { t.Fatalf("Could not set environment variables for testing: %v", err) @@ -322,8 +322,8 @@ func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredentialWithRetr t.Fatalf("Unable to create credential. Received: %v", err) } - // Backwards order: envCred first, secCred later, to check that envCred is always called when RetryAllSources is set to true. - cred, err := NewChainedTokenCredential([]azcore.TokenCredential{unavailableCred, secCred}, &ChainedTokenCredentialOptions{RetryAllSources: true}) + // Backwards order: envCred first, secCred later, to check that envCred is always called when RetrySources is set to true. + cred, err := NewChainedTokenCredential([]azcore.TokenCredential{unavailableCred, secCred}, &ChainedTokenCredentialOptions{RetrySources: true}) if err != nil { t.Fatalf("unexpected error: %v", err) } From cdd4bb63b856cf7090abd6912dda6f4b19ab5bd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Rodr=C3=ADguez?= Date: Thu, 9 Dec 2021 17:20:42 -0500 Subject: [PATCH 13/22] Apply suggestions from code review Co-authored-by: Charles Lowell <10964656+chlowell@users.noreply.github.com> --- sdk/azidentity/CHANGELOG.md | 2 +- sdk/azidentity/chained_token_credential.go | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index 1be2169bdfd2..8978422bb250 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -7,7 +7,7 @@ ### Breaking Changes * Instances of `ChainedTokenCredential` will now skip looping through the list of source credentials and re-use the first successful credential on subsequent calls to `GetToken`. - * If the `RetrySources` option is set to true, it will not assume the first successful credential should be always used. It will continue to try all of the originally provided credentials each time the `GetToken` method is called. + * If `ChainedTokenCredentialOptions.RetrySources` is true, `ChainedTokenCredential` will continue to try all of the originally provided credentials each time the `GetToken` method is called. ### Bugs Fixed diff --git a/sdk/azidentity/chained_token_credential.go b/sdk/azidentity/chained_token_credential.go index ba44ab366886..4c6413f5ef7d 100644 --- a/sdk/azidentity/chained_token_credential.go +++ b/sdk/azidentity/chained_token_credential.go @@ -14,7 +14,7 @@ import ( // ChainedTokenCredentialOptions contains optional parameters for ChainedTokenCredential. type ChainedTokenCredentialOptions struct { - // If true, it will not assume the first successful credential should be always used. + // RetrySources configures how the credential uses its sources. When true, the credential will always request a token from each source in turn, stopping when one provides a token. When false, the credential requests a token from only that source--it never again tries the sources which failed. RetrySources bool } @@ -41,11 +41,10 @@ func NewChainedTokenCredential(sources []azcore.TokenCredential, options *Chaine } cp := make([]azcore.TokenCredential, len(sources)) copy(cp, sources) - credentialOptions := ChainedTokenCredentialOptions{} - if options != nil { - credentialOptions = *options + if options == nil { + options = &ChainedTokenCredentialOptions{} } - return &ChainedTokenCredential{sources: cp, retrySources: credentialOptions.RetrySources}, nil + return &ChainedTokenCredential{sources: cp, retrySources: options.RetrySources}, nil } // GetToken calls GetToken on the chained credentials in turn, stopping when one returns a token. This method is called automatically by Azure SDK clients. From e2d1201e28ba824b2b640a5eda0b1be1c5120fbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Rodr=C3=ADguez?= Date: Thu, 9 Dec 2021 22:42:50 +0000 Subject: [PATCH 14/22] Feedback from Charles and adding RetrySources to the DefaultAzureCredential --- sdk/azidentity/CHANGELOG.md | 5 +++- sdk/azidentity/chained_token_credential.go | 11 ++++--- sdk/azidentity/default_azure_credential.go | 8 ++++- .../default_azure_credential_test.go | 29 +++++++++++++++++++ 4 files changed, 47 insertions(+), 6 deletions(-) diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index 8978422bb250..e276c3b1cd18 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -6,8 +6,11 @@ ### Breaking Changes -* Instances of `ChainedTokenCredential` will now skip looping through the list of source credentials and re-use the first successful credential on subsequent calls to `GetToken`. +* Instances of `ChainedTokenCredential` and `DefaultAzureCredential` will now skip looping through the list of source credentials and re-use the first successful credential on subsequent calls to `GetToken`. * If `ChainedTokenCredentialOptions.RetrySources` is true, `ChainedTokenCredential` will continue to try all of the originally provided credentials each time the `GetToken` method is called. + * `ChainedTokenCredential.successfulCredential` will contain a reference to the last successful credential. + * If `DefaultAzureCredential.RetrySources` is true, `DefaultAzureCredential` will continue to try all of the underlying credentials (`EnvironmentCredential`, `ManagedIdentityCredential`, `AzureCLICredential`) each time the `GetToken` method is called. + * `DefaultAzureCredential.chain.successfulCredential` will contain a reference to the last successful credential. ### Bugs Fixed diff --git a/sdk/azidentity/chained_token_credential.go b/sdk/azidentity/chained_token_credential.go index 4c6413f5ef7d..960029ab57f9 100644 --- a/sdk/azidentity/chained_token_credential.go +++ b/sdk/azidentity/chained_token_credential.go @@ -14,11 +14,16 @@ import ( // ChainedTokenCredentialOptions contains optional parameters for ChainedTokenCredential. type ChainedTokenCredentialOptions struct { - // RetrySources configures how the credential uses its sources. When true, the credential will always request a token from each source in turn, stopping when one provides a token. When false, the credential requests a token from only that source--it never again tries the sources which failed. + // RetrySources configures how the credential uses its sources. + // When true, the credential will always request a token from each source in turn, + // stopping when one provides a token. When false, the credential requests a token + // only from the source that previously retrieved a token--it never again tries the sources which failed. RetrySources bool } // ChainedTokenCredential is a chain of credentials that enables fallback behavior when a credential can't authenticate. +// By default, this credential will assume that the first successful credential should be the only credential used on future requests. +// If the `RetrySources` option is set to true, it will always try to get a token using all of the originally provided credentials. type ChainedTokenCredential struct { sources []azcore.TokenCredential successfulCredential azcore.TokenCredential @@ -26,8 +31,6 @@ type ChainedTokenCredential struct { } // NewChainedTokenCredential creates a ChainedTokenCredential. -// By default, this credential will assume that the first successful credential should be the only credential used on future requests. -// If the `RetrySources` option is set to true, it will always try to get a token using all of the originally provided credentials. // sources: Credential instances to comprise the chain. GetToken() will invoke them in the given order. // options: Optional configuration. func NewChainedTokenCredential(sources []azcore.TokenCredential, options *ChainedTokenCredentialOptions) (*ChainedTokenCredential, error) { @@ -66,7 +69,7 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token if c.successfulCredential != nil && !c.retrySources { token, err = c.successfulCredential.GetToken(ctx, opts) if err != nil { - return nil, formatError(err) + return nil, err } return token, nil } diff --git a/sdk/azidentity/default_azure_credential.go b/sdk/azidentity/default_azure_credential.go index b4374acd3553..941757c13b76 100644 --- a/sdk/azidentity/default_azure_credential.go +++ b/sdk/azidentity/default_azure_credential.go @@ -27,6 +27,12 @@ type DefaultAzureCredentialOptions struct { // TenantID identifies the tenant the Azure CLI should authenticate in. // Defaults to the CLI's default tenant, which is typically the home tenant of the user logged in to the CLI. TenantID string + + // RetrySources configures how the credential uses its sources. + // When true, the credential will always request a token from each underling credential in turn, + // stopping when one provides a token. When false, the credential requests a token + // only from the credential that previously retrieved a token--it never again tries the sources which failed. + RetrySources bool } // DefaultAzureCredential is a default credential chain for applications that will deploy to Azure. @@ -78,7 +84,7 @@ func NewDefaultAzureCredential(options *DefaultAzureCredentialOptions) (*Default logCredentialError("Default Azure Credential", err) return nil, err } - chain, err := NewChainedTokenCredential(creds, nil) + chain, err := NewChainedTokenCredential(creds, &ChainedTokenCredentialOptions{RetrySources: options.RetrySources}) if err != nil { return nil, err } diff --git a/sdk/azidentity/default_azure_credential_test.go b/sdk/azidentity/default_azure_credential_test.go index 7031a1c8dfed..601cd9a5ef1e 100644 --- a/sdk/azidentity/default_azure_credential_test.go +++ b/sdk/azidentity/default_azure_credential_test.go @@ -27,3 +27,32 @@ func TestDefaultAzureCredential_GetTokenSuccess(t *testing.T) { t.Fatalf("GetToken error: %v", err) } } + +func TestDefaultAzureCredential_WithRetrySources(t *testing.T) { + env := map[string]string{"AZURE_TENANT_ID": tenantID, "AZURE_CLIENT_ID": clientID, "AZURE_CLIENT_SECRET": secret} + setEnvironmentVariables(t, env) + srv, close := mock.NewTLSServer() + defer close() + + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + + cred, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{AuthorityHost: AuthorityHost(srv.URL()), ClientOptions: policy.ClientOptions{Transport: srv}, RetrySources: true}) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + + tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}}) + if err != nil { + t.Fatalf("Received an error when attempting to get a token but expected none. Error: %v", err) + } + if tk.Token != tokenValue { + t.Fatalf("Received an incorrect access token") + } + if tk.ExpiresOn.IsZero() { + t.Fatalf("Received an incorrect time in the response") + } + if cred.chain.successfulCredential == nil { + t.Fatalf("The successful credential was not assigned") + } +} From 8e51b2cbe871a31d4b5fe7145bb5ff836227a26f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Rodr=C3=ADguez?= Date: Fri, 10 Dec 2021 17:57:05 +0000 Subject: [PATCH 15/22] undoing changes to the DefaultAzureCredential --- sdk/azidentity/CHANGELOG.md | 6 ++-- sdk/azidentity/default_azure_credential.go | 8 +---- .../default_azure_credential_test.go | 29 ------------------- 3 files changed, 4 insertions(+), 39 deletions(-) diff --git a/sdk/azidentity/CHANGELOG.md b/sdk/azidentity/CHANGELOG.md index e276c3b1cd18..ffbdc09662b4 100644 --- a/sdk/azidentity/CHANGELOG.md +++ b/sdk/azidentity/CHANGELOG.md @@ -6,11 +6,11 @@ ### Breaking Changes -* Instances of `ChainedTokenCredential` and `DefaultAzureCredential` will now skip looping through the list of source credentials and re-use the first successful credential on subsequent calls to `GetToken`. +* Instances of `ChainedTokenCredential` will now skip looping through the list of source credentials and re-use the first successful credential on subsequent calls to `GetToken`. * If `ChainedTokenCredentialOptions.RetrySources` is true, `ChainedTokenCredential` will continue to try all of the originally provided credentials each time the `GetToken` method is called. * `ChainedTokenCredential.successfulCredential` will contain a reference to the last successful credential. - * If `DefaultAzureCredential.RetrySources` is true, `DefaultAzureCredential` will continue to try all of the underlying credentials (`EnvironmentCredential`, `ManagedIdentityCredential`, `AzureCLICredential`) each time the `GetToken` method is called. - * `DefaultAzureCredential.chain.successfulCredential` will contain a reference to the last successful credential. + * `DefaultAzureCredenial` will also re-use the first successful credential on subsequent calls to `GetToken`. + * `DefaultAzureCredential.chain.successfulCredential` will also contain a reference to the last successful credential. ### Bugs Fixed diff --git a/sdk/azidentity/default_azure_credential.go b/sdk/azidentity/default_azure_credential.go index 941757c13b76..b4374acd3553 100644 --- a/sdk/azidentity/default_azure_credential.go +++ b/sdk/azidentity/default_azure_credential.go @@ -27,12 +27,6 @@ type DefaultAzureCredentialOptions struct { // TenantID identifies the tenant the Azure CLI should authenticate in. // Defaults to the CLI's default tenant, which is typically the home tenant of the user logged in to the CLI. TenantID string - - // RetrySources configures how the credential uses its sources. - // When true, the credential will always request a token from each underling credential in turn, - // stopping when one provides a token. When false, the credential requests a token - // only from the credential that previously retrieved a token--it never again tries the sources which failed. - RetrySources bool } // DefaultAzureCredential is a default credential chain for applications that will deploy to Azure. @@ -84,7 +78,7 @@ func NewDefaultAzureCredential(options *DefaultAzureCredentialOptions) (*Default logCredentialError("Default Azure Credential", err) return nil, err } - chain, err := NewChainedTokenCredential(creds, &ChainedTokenCredentialOptions{RetrySources: options.RetrySources}) + chain, err := NewChainedTokenCredential(creds, nil) if err != nil { return nil, err } diff --git a/sdk/azidentity/default_azure_credential_test.go b/sdk/azidentity/default_azure_credential_test.go index 601cd9a5ef1e..7031a1c8dfed 100644 --- a/sdk/azidentity/default_azure_credential_test.go +++ b/sdk/azidentity/default_azure_credential_test.go @@ -27,32 +27,3 @@ func TestDefaultAzureCredential_GetTokenSuccess(t *testing.T) { t.Fatalf("GetToken error: %v", err) } } - -func TestDefaultAzureCredential_WithRetrySources(t *testing.T) { - env := map[string]string{"AZURE_TENANT_ID": tenantID, "AZURE_CLIENT_ID": clientID, "AZURE_CLIENT_SECRET": secret} - setEnvironmentVariables(t, env) - srv, close := mock.NewTLSServer() - defer close() - - srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) - srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) - - cred, err := NewDefaultAzureCredential(&DefaultAzureCredentialOptions{AuthorityHost: AuthorityHost(srv.URL()), ClientOptions: policy.ClientOptions{Transport: srv}, RetrySources: true}) - if err != nil { - t.Fatalf("Unable to create credential. Received: %v", err) - } - - tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}}) - if err != nil { - t.Fatalf("Received an error when attempting to get a token but expected none. Error: %v", err) - } - if tk.Token != tokenValue { - t.Fatalf("Received an incorrect access token") - } - if tk.ExpiresOn.IsZero() { - t.Fatalf("Received an incorrect time in the response") - } - if cred.chain.successfulCredential == nil { - t.Fatalf("The successful credential was not assigned") - } -} From 20197a94527b5dc38f8d0270df23c0fccff47e68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Rodr=C3=ADguez?= Date: Fri, 10 Dec 2021 18:39:52 +0000 Subject: [PATCH 16/22] simplified tests with new test credentials --- .../chained_token_credential_test.go | 192 +++++------------- 1 file changed, 53 insertions(+), 139 deletions(-) diff --git a/sdk/azidentity/chained_token_credential_test.go b/sdk/azidentity/chained_token_credential_test.go index 3f7d05d1859b..d4089d773123 100644 --- a/sdk/azidentity/chained_token_credential_test.go +++ b/sdk/azidentity/chained_token_credential_test.go @@ -8,6 +8,7 @@ import ( "errors" "net/http" "testing" + "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" @@ -153,95 +154,36 @@ func TestChainedTokenCredential_GetTokenWithUnavailableCredentialInChain(t *test } } -func TestChainedTokenCredential_ChecksThatSuccessfulCredentialIsSet(t *testing.T) { - err := initEnvironmentVarsForTest() - if err != nil { - t.Fatalf("Could not set environment variables for testing: %v", err) - } - srv, close := mock.NewTLSServer() - defer close() - srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) - options := ClientSecretCredentialOptions{} - options.AuthorityHost = AuthorityHost(srv.URL()) - options.Transport = srv - secCred, err := NewClientSecretCredential(tenantID, clientID, secret, &options) - if err != nil { - t.Fatalf("Unable to create credential. Received: %v", err) - } - envCred, err := NewEnvironmentCredential(&EnvironmentCredentialOptions{ - ClientOptions: azcore.ClientOptions{Transport: srv}, - AuthorityHost: AuthorityHost(srv.URL()), - }) - if err != nil { - t.Fatalf("Failed to create environment credential: %v", err) - } - cred, err := NewChainedTokenCredential([]azcore.TokenCredential{secCred, envCred}, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}}) - if err != nil { - t.Fatalf("Received an error when attempting to get a token but expected none") - } - if tk.Token != tokenValue { - t.Fatalf("Received an incorrect access token") - } - if tk.ExpiresOn.IsZero() { - t.Fatalf("Received an incorrect time in the response") - } - if cred.successfulCredential == nil { - t.Fatalf("The successful credential was not assigned") - } - if cred.successfulCredential != secCred { - t.Fatalf("The successful credential should have been the secret credential") - } +// TestCredential response +type TestCredentialResponse struct { + token *azcore.AccessToken + err error } -/** - * Helps count the number of times a credential is called. - */ -type TestCountPolicy struct{ count int } +// Credential used for testing +type TestCredential struct { + getTokenCalls int + responses []TestCredentialResponse +} -/** - * Helps count the number of times a credential is called. - */ -func (p *TestCountPolicy) Do(req *policy.Request) (*http.Response, error) { - p.count += 1 - return req.Next() +func (c *TestCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (token *azcore.AccessToken, err error) { + index := c.getTokenCalls + c.getTokenCalls += 1 + response := c.responses[index] + return response.token, response.err } func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *testing.T) { - err := initEnvironmentVarsForTest() - if err != nil { - t.Fatalf("Could not set environment variables for testing: %v", err) - } - srv, close := mock.NewTLSServer() - defer close() - - secretCountPolicy := &TestCountPolicy{} - environmentCountPolicy := &TestCountPolicy{} + failedCredential := &TestCredential{responses: []TestCredentialResponse{ + {err: newCredentialUnavailableError("MockCredential", "Mocking a credential unavailable error")}, + {err: newCredentialUnavailableError("MockCredential", "Mocking a credential unavailable error")}, + }} + successfulCredential := &TestCredential{responses: []TestCredentialResponse{ + {token: &azcore.AccessToken{Token: tokenValue, ExpiresOn: time.Now()}}, + {token: &azcore.AccessToken{Token: tokenValue, ExpiresOn: time.Now()}}, + }} - srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) - srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) - options := ClientSecretCredentialOptions{} - options.AuthorityHost = AuthorityHost(srv.URL()) - options.PerCallPolicies = []policy.Policy{secretCountPolicy} - options.Transport = srv - secCred, err := NewClientSecretCredential(tenantID, clientID, secret, &options) - if err != nil { - t.Fatalf("Unable to create credential. Received: %v", err) - } - envCred, err := NewEnvironmentCredential(&EnvironmentCredentialOptions{ - ClientOptions: azcore.ClientOptions{ - Transport: srv, - PerCallPolicies: []policy.Policy{environmentCountPolicy}, - }, - AuthorityHost: AuthorityHost(srv.URL()), - }) - if err != nil { - t.Fatalf("Failed to create environment credential: %v", err) - } - cred, err := NewChainedTokenCredential([]azcore.TokenCredential{secCred, envCred}, nil) + cred, err := NewChainedTokenCredential([]azcore.TokenCredential{failedCredential, successfulCredential}, nil) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -258,14 +200,14 @@ func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *test if cred.successfulCredential == nil { t.Fatalf("The successful credential was not assigned") } - if cred.successfulCredential != secCred { - t.Fatalf("The successful credential should have been the secret credential") + if cred.successfulCredential != successfulCredential { + t.Fatalf("The successful credential should have been the successfulCredential") } - if secretCountPolicy.count != 1 { - t.Fatalf("The secret credential policies should have been triggered once") + if failedCredential.getTokenCalls != 1 { + t.Fatalf("The failed credential getToken should have been called once") } - if environmentCountPolicy.count != 0 { - t.Fatalf("The environment credential policies should not have been triggered") + if successfulCredential.getTokenCalls != 1 { + t.Fatalf("The successful credential getToken should have been called once") } tk2, err2 := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}}) if err2 != nil { @@ -277,53 +219,25 @@ func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *test if tk2.ExpiresOn.IsZero() { t.Fatalf("Received an incorrect time in the response") } - if secretCountPolicy.count != 2 { - t.Fatalf("The secret credential policies should have been triggered twice") + if failedCredential.getTokenCalls != 1 { + t.Fatalf("The failed credential getToken should not have been called again") } - if environmentCountPolicy.count != 0 { - t.Fatalf("The environment credential policies should not have been triggered") + if successfulCredential.getTokenCalls != 2 { + t.Fatalf("The successful credential getToken should have been called twice") } } -// A credential that always throws a CredentialUnavailableError -type UnavailableCredential struct { - callCount int -} - -func NewUnavailableCredential() (*UnavailableCredential, error) { - return &UnavailableCredential{callCount: 0}, nil -} -func (c *UnavailableCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (token *azcore.AccessToken, err error) { - c.callCount += 1 - return nil, newCredentialUnavailableError("UnavailableCredential", "Expected CredentialUnavailableError") -} - func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredentialWithRetrySources(t *testing.T) { - err := initEnvironmentVarsForTest() - if err != nil { - t.Fatalf("Could not set environment variables for testing: %v", err) - } - srv, close := mock.NewTLSServer() - defer close() - - secretCountPolicy := &TestCountPolicy{} - - srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) - srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) - - options := ClientSecretCredentialOptions{} - options.AuthorityHost = AuthorityHost(srv.URL()) - options.PerCallPolicies = []policy.Policy{secretCountPolicy} - options.Transport = srv - - unavailableCred, _ := NewUnavailableCredential() - secCred, err := NewClientSecretCredential(tenantID, clientID, secret, &options) - if err != nil { - t.Fatalf("Unable to create credential. Received: %v", err) - } + failedCredential := &TestCredential{responses: []TestCredentialResponse{ + {err: newCredentialUnavailableError("MockCredential", "Mocking a credential unavailable error")}, + {err: newCredentialUnavailableError("MockCredential", "Mocking a credential unavailable error")}, + }} + successfulCredential := &TestCredential{responses: []TestCredentialResponse{ + {token: &azcore.AccessToken{Token: tokenValue, ExpiresOn: time.Now()}}, + {token: &azcore.AccessToken{Token: tokenValue, ExpiresOn: time.Now()}}, + }} - // Backwards order: envCred first, secCred later, to check that envCred is always called when RetrySources is set to true. - cred, err := NewChainedTokenCredential([]azcore.TokenCredential{unavailableCred, secCred}, &ChainedTokenCredentialOptions{RetrySources: true}) + cred, err := NewChainedTokenCredential([]azcore.TokenCredential{failedCredential, successfulCredential}, &ChainedTokenCredentialOptions{RetrySources: true}) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -340,14 +254,14 @@ func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredentialWithRetr if cred.successfulCredential == nil { t.Fatalf("The successful credential was not assigned") } - if cred.successfulCredential != secCred { - t.Fatalf("The successful credential should have been the secret credential") + if cred.successfulCredential != successfulCredential { + t.Fatalf("The successful credential should have been the successfulCredential") } - if secretCountPolicy.count != 1 { - t.Fatalf("The secret credential policies should have been triggered once") + if failedCredential.getTokenCalls != 1 { + t.Fatalf("The failed credential getToken should have been called once") } - if unavailableCred.callCount != 1 { - t.Fatalf("The environment credential policies should have been triggered once") + if successfulCredential.getTokenCalls != 1 { + t.Fatalf("The successful credential getToken should have been called once") } tk2, err2 := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{scope}}) if err2 != nil { @@ -359,11 +273,11 @@ func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredentialWithRetr if tk2.ExpiresOn.IsZero() { t.Fatalf("Received an incorrect time in the response") } - if secretCountPolicy.count != 2 { - t.Fatalf("The secret credential policies should have been triggered twice") + if failedCredential.getTokenCalls != 2 { + t.Fatalf("The failed credential getToken should have been called twice") } - if unavailableCred.callCount != 2 { - t.Fatalf("The environment credential policies should have been triggered twice") + if successfulCredential.getTokenCalls != 2 { + t.Fatalf("The successful credential getToken should have been called twice") } } From e14544a86e1c0e5344a0efe86e158af843dd46f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Rodr=C3=ADguez?= Date: Fri, 10 Dec 2021 22:30:35 +0000 Subject: [PATCH 17/22] removed the formatError function --- sdk/azidentity/chained_token_credential.go | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/sdk/azidentity/chained_token_credential.go b/sdk/azidentity/chained_token_credential.go index 960029ab57f9..056431b4b629 100644 --- a/sdk/azidentity/chained_token_credential.go +++ b/sdk/azidentity/chained_token_credential.go @@ -56,16 +56,6 @@ func NewChainedTokenCredential(sources []azcore.TokenCredential, options *Chaine func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (token *azcore.AccessToken, err error) { var errList []CredentialUnavailableError - formatError := func(err error) error { - var authFailed AuthenticationFailedError - if errors.As(err, &authFailed) { - err = fmt.Errorf("Authentication failed:\n%s\n%s"+createChainedErrorMessage(errList), err) - authErr := newAuthenticationFailedError(err, authFailed.RawResponse()) - return authErr - } - return err - } - if c.successfulCredential != nil && !c.retrySources { token, err = c.successfulCredential.GetToken(ctx, opts) if err != nil { @@ -79,7 +69,13 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token if errors.As(err, &credErr) { errList = append(errList, credErr) } else if err != nil { - return nil, formatError(err) + var authFailed AuthenticationFailedError + if errors.As(err, &authFailed) { + err = fmt.Errorf("Authentication failed:\n%s\n%s"+createChainedErrorMessage(errList), err) + authErr := newAuthenticationFailedError(err, authFailed.RawResponse()) + return nil, authErr + } + return nil, err } else { logGetTokenSuccess(c, opts) c.successfulCredential = cred From a19c2f27dcbdb3efe444e80cdf7bc59829f0a577 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Rodr=C3=ADguez?= Date: Fri, 10 Dec 2021 22:41:52 +0000 Subject: [PATCH 18/22] Avoiding testing that .successfulCredential is set --- sdk/azidentity/chained_token_credential_test.go | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/sdk/azidentity/chained_token_credential_test.go b/sdk/azidentity/chained_token_credential_test.go index d99a9a21725e..45fda2e3ed82 100644 --- a/sdk/azidentity/chained_token_credential_test.go +++ b/sdk/azidentity/chained_token_credential_test.go @@ -184,12 +184,6 @@ func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *test if tk.ExpiresOn.IsZero() { t.Fatalf("Received an incorrect time in the response") } - if cred.successfulCredential == nil { - t.Fatalf("The successful credential was not assigned") - } - if cred.successfulCredential != successfulCredential { - t.Fatalf("The successful credential should have been the successfulCredential") - } if failedCredential.getTokenCalls != 1 { t.Fatalf("The failed credential getToken should have been called once") } @@ -238,12 +232,6 @@ func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredentialWithRetr if tk.ExpiresOn.IsZero() { t.Fatalf("Received an incorrect time in the response") } - if cred.successfulCredential == nil { - t.Fatalf("The successful credential was not assigned") - } - if cred.successfulCredential != successfulCredential { - t.Fatalf("The successful credential should have been the successfulCredential") - } if failedCredential.getTokenCalls != 1 { t.Fatalf("The failed credential getToken should have been called once") } From af22a8125be0b1e4c073d7bbb7f29a5c2d35ebdf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Rodr=C3=ADguez?= Date: Mon, 13 Dec 2021 16:41:51 -0500 Subject: [PATCH 19/22] Apply suggestions from code review Co-authored-by: Charles Lowell <10964656+chlowell@users.noreply.github.com> --- sdk/azidentity/chained_token_credential.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/sdk/azidentity/chained_token_credential.go b/sdk/azidentity/chained_token_credential.go index 056431b4b629..3fee9cd12cbb 100644 --- a/sdk/azidentity/chained_token_credential.go +++ b/sdk/azidentity/chained_token_credential.go @@ -57,11 +57,7 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token var errList []CredentialUnavailableError if c.successfulCredential != nil && !c.retrySources { - token, err = c.successfulCredential.GetToken(ctx, opts) - if err != nil { - return nil, err - } - return token, nil + return c.successfulCredential.GetToken(ctx, opts) } for _, cred := range c.sources { token, err = cred.GetToken(ctx, opts) From 2080f9635cba20b79f2fc07929a6dd0f9e108ac4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Rodr=C3=ADguez?= Date: Mon, 13 Dec 2021 22:48:35 +0000 Subject: [PATCH 20/22] removed some redundant tests, the unavailableCredential, unexposed the testCredentialResponse and added a re-usable function to test a valid token response --- .../chained_token_credential_test.go | 106 ++++++------------ 1 file changed, 32 insertions(+), 74 deletions(-) diff --git a/sdk/azidentity/chained_token_credential_test.go b/sdk/azidentity/chained_token_credential_test.go index 45fda2e3ed82..8d5cd15508e3 100644 --- a/sdk/azidentity/chained_token_credential_test.go +++ b/sdk/azidentity/chained_token_credential_test.go @@ -113,36 +113,8 @@ func TestChainedTokenCredential_GetTokenFail(t *testing.T) { } } -type unavailableCredential struct{} - -func (*unavailableCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (token *azcore.AccessToken, err error) { - return nil, newCredentialUnavailableError("unavailableCredential", "is unavailable") -} - -func TestChainedTokenCredential_GetTokenWithUnavailableCredentialInChain(t *testing.T) { - secCred, err := NewClientSecretCredential(fakeTenantID, fakeClientID, secret, nil) - if err != nil { - t.Fatalf("Unable to create credential. Received: %v", err) - } - secCred.client = fakeConfidentialClient{ar: confidential.AuthResult{AccessToken: tokenValue, ExpiresOn: time.Now().Add(time.Hour)}} - cred, err := NewChainedTokenCredential([]azcore.TokenCredential{&unavailableCredential{}, secCred}, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}}) - if err != nil { - t.Fatalf("Received an error when attempting to get a token but expected none") - } - if tk.Token != tokenValue { - t.Fatalf("Received an incorrect access token") - } - if tk.ExpiresOn.IsZero() { - t.Fatalf("Received an incorrect time in the response") - } -} - // TestCredential response -type TestCredentialResponse struct { +type testCredentialResponse struct { token *azcore.AccessToken err error } @@ -150,7 +122,7 @@ type TestCredentialResponse struct { // Credential used for testing type TestCredential struct { getTokenCalls int - responses []TestCredentialResponse + responses []testCredentialResponse } func (c *TestCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (token *azcore.AccessToken, err error) { @@ -160,12 +132,24 @@ func (c *TestCredential) GetToken(ctx context.Context, opts policy.TokenRequestO return response.token, response.err } +func testGoodGetTokenResponse(t *testing.T, token *azcore.AccessToken, err error) { + if err != nil { + t.Fatalf("Received an error when attempting to get a token but expected none") + } + if token.Token != tokenValue { + t.Fatalf("Received an incorrect access token") + } + if token.ExpiresOn.IsZero() { + t.Fatalf("Received an incorrect time in the response") + } +} + func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *testing.T) { - failedCredential := &TestCredential{responses: []TestCredentialResponse{ + failedCredential := &TestCredential{responses: []testCredentialResponse{ {err: newCredentialUnavailableError("MockCredential", "Mocking a credential unavailable error")}, {err: newCredentialUnavailableError("MockCredential", "Mocking a credential unavailable error")}, }} - successfulCredential := &TestCredential{responses: []TestCredentialResponse{ + successfulCredential := &TestCredential{responses: []testCredentialResponse{ {token: &azcore.AccessToken{Token: tokenValue, ExpiresOn: time.Now()}}, {token: &azcore.AccessToken{Token: tokenValue, ExpiresOn: time.Now()}}, }} @@ -174,32 +158,19 @@ func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *test if err != nil { t.Fatalf("unexpected error: %v", err) } - tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}}) - if err != nil { - t.Fatalf("Received an error when attempting to get a token but expected none") - } - if tk.Token != tokenValue { - t.Fatalf("Received an incorrect access token") - } - if tk.ExpiresOn.IsZero() { - t.Fatalf("Received an incorrect time in the response") - } + + getTokenOptions := policy.TokenRequestOptions{Scopes: []string{liveTestScope}} + + tk, err := cred.GetToken(context.Background(), getTokenOptions) + testGoodGetTokenResponse(t, tk, err) if failedCredential.getTokenCalls != 1 { t.Fatalf("The failed credential getToken should have been called once") } if successfulCredential.getTokenCalls != 1 { t.Fatalf("The successful credential getToken should have been called once") } - tk2, err2 := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}}) - if err2 != nil { - t.Fatalf("Received an error when attempting to get a token but expected none. Error: %v", err2) - } - if tk2.Token != tokenValue { - t.Fatalf("Received an incorrect access token") - } - if tk2.ExpiresOn.IsZero() { - t.Fatalf("Received an incorrect time in the response") - } + tk2, err2 := cred.GetToken(context.Background(), getTokenOptions) + testGoodGetTokenResponse(t, tk2, err2) if failedCredential.getTokenCalls != 1 { t.Fatalf("The failed credential getToken should not have been called again") } @@ -209,11 +180,11 @@ func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *test } func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredentialWithRetrySources(t *testing.T) { - failedCredential := &TestCredential{responses: []TestCredentialResponse{ + failedCredential := &TestCredential{responses: []testCredentialResponse{ {err: newCredentialUnavailableError("MockCredential", "Mocking a credential unavailable error")}, {err: newCredentialUnavailableError("MockCredential", "Mocking a credential unavailable error")}, }} - successfulCredential := &TestCredential{responses: []TestCredentialResponse{ + successfulCredential := &TestCredential{responses: []testCredentialResponse{ {token: &azcore.AccessToken{Token: tokenValue, ExpiresOn: time.Now()}}, {token: &azcore.AccessToken{Token: tokenValue, ExpiresOn: time.Now()}}, }} @@ -222,32 +193,19 @@ func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredentialWithRetr if err != nil { t.Fatalf("unexpected error: %v", err) } - tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}}) - if err != nil { - t.Fatalf("Received an error when attempting to get a token but expected none. Error: %v", err) - } - if tk.Token != tokenValue { - t.Fatalf("Received an incorrect access token") - } - if tk.ExpiresOn.IsZero() { - t.Fatalf("Received an incorrect time in the response") - } + + getTokenOptions := policy.TokenRequestOptions{Scopes: []string{liveTestScope}} + + tk, err := cred.GetToken(context.Background(), getTokenOptions) + testGoodGetTokenResponse(t, tk, err) if failedCredential.getTokenCalls != 1 { t.Fatalf("The failed credential getToken should have been called once") } if successfulCredential.getTokenCalls != 1 { t.Fatalf("The successful credential getToken should have been called once") } - tk2, err2 := cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{liveTestScope}}) - if err2 != nil { - t.Fatalf("Received an error when attempting to get a token but expected none. Error: %v", err2) - } - if tk2.Token != tokenValue { - t.Fatalf("Received an incorrect access token") - } - if tk2.ExpiresOn.IsZero() { - t.Fatalf("Received an incorrect time in the response") - } + tk2, err2 := cred.GetToken(context.Background(), getTokenOptions) + testGoodGetTokenResponse(t, tk2, err2) if failedCredential.getTokenCalls != 2 { t.Fatalf("The failed credential getToken should have been called twice") } From ae92c66773c9b59e61681dcfdbea4ea98398bc8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Rodr=C3=ADguez?= Date: Mon, 13 Dec 2021 23:02:01 +0000 Subject: [PATCH 21/22] formatting --- sdk/azidentity/chained_token_credential.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/azidentity/chained_token_credential.go b/sdk/azidentity/chained_token_credential.go index 3fee9cd12cbb..fe833d637674 100644 --- a/sdk/azidentity/chained_token_credential.go +++ b/sdk/azidentity/chained_token_credential.go @@ -57,7 +57,7 @@ func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.Token var errList []CredentialUnavailableError if c.successfulCredential != nil && !c.retrySources { - return c.successfulCredential.GetToken(ctx, opts) + return c.successfulCredential.GetToken(ctx, opts) } for _, cred := range c.sources { token, err = cred.GetToken(ctx, opts) From 0cac965be1aad0bcba7eaf90900f5e20387e327c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Rodr=C3=ADguez?= Date: Mon, 10 Jan 2022 19:33:24 -0500 Subject: [PATCH 22/22] Update sdk/azidentity/chained_token_credential_test.go Co-authored-by: Charles Lowell <10964656+chlowell@users.noreply.github.com> --- sdk/azidentity/chained_token_credential_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/azidentity/chained_token_credential_test.go b/sdk/azidentity/chained_token_credential_test.go index 8d5cd15508e3..1863df3d2189 100644 --- a/sdk/azidentity/chained_token_credential_test.go +++ b/sdk/azidentity/chained_token_credential_test.go @@ -164,7 +164,7 @@ func TestChainedTokenCredential_RepeatedGetTokenWithSuccessfulCredential(t *test tk, err := cred.GetToken(context.Background(), getTokenOptions) testGoodGetTokenResponse(t, tk, err) if failedCredential.getTokenCalls != 1 { - t.Fatalf("The failed credential getToken should have been called once") + t.Fatal("The failed credential getToken should have been called once") } if successfulCredential.getTokenCalls != 1 { t.Fatalf("The successful credential getToken should have been called once")