diff --git a/sdk/azidentity/bearer_token_policy.go b/sdk/azidentity/bearer_token_policy.go index d330afe98cc6..50b13cde5317 100644 --- a/sdk/azidentity/bearer_token_policy.go +++ b/sdk/azidentity/bearer_token_policy.go @@ -4,7 +4,9 @@ package azidentity import ( + "fmt" "net/http" + "strings" "sync" "time" @@ -16,97 +18,158 @@ const ( ) type bearerTokenPolicy struct { - // cond is used to synchronize token refresh. the locker - // must be locked when updating the following shared state. + // mainResource is the resource to be retreived using the tenant specified in the credential + mainResource *expiringResource + // auxResources are additional resources that are required for cross-tenant applications + auxResources map[string]*expiringResource + // the following fields are read-only + creds azcore.TokenCredential + options azcore.TokenRequestOptions +} + +type expiringResource struct { + // cond is used to synchronize access to the shared resource embodied by the remaining fields cond *sync.Cond - // renewing indicates that the token is in the process of being refreshed - renewing bool + // acquiring indicates that some thread/goroutine is in the process of acquiring/updating the resource + acquiring bool - // header contains the authorization header value - header string + // resource contains the value of the shared resource + resource interface{} - // expiresOn is when the token will expire - expiresOn time.Time + // expiration indicates when the shared resource expires; it is 0 if the resource was never acquired + expiration time.Time - // the following fields are read-only - creds azcore.TokenCredential - options azcore.TokenRequestOptions + // acquireResource is the callback function that actually acquires the resource + acquireResource acquireResource } -func newBearerTokenPolicy(creds azcore.TokenCredential, opts azcore.AuthenticationPolicyOptions) *bearerTokenPolicy { - return &bearerTokenPolicy{ - cond: sync.NewCond(&sync.Mutex{}), - creds: creds, - options: opts.Options, - } +type acquireResource func(state interface{}) (newResource interface{}, newExpiration time.Time, err error) + +type acquiringResourceState struct { + req *azcore.Request + p bearerTokenPolicy } -func (b *bearerTokenPolicy) Do(req *azcore.Request) (*azcore.Response, error) { - if req.URL.Scheme != "https" { - // HTTPS must be used, otherwise the tokens are at the risk of being exposed - return nil, &AuthenticationFailedError{msg: "token credentials require a URL using the HTTPS protocol scheme"} +// acquire acquires or updates the resource; only one +// thread/goroutine at a time ever calls this function +func acquire(state interface{}) (newResource interface{}, newExpiration time.Time, err error) { + s := state.(acquiringResourceState) + tk, err := s.p.creds.GetToken(s.req.Context(), s.p.options) + if err != nil { + return nil, time.Time{}, err } - // create a "refresh window" before the token's real expiration date. - // this allows callers to continue to use the old token while the - // refresh is in progress. - const window = 2 * time.Minute - now, getToken, header := time.Now(), false, "" + return tk, tk.ExpiresOn, nil +} + +func newExpiringResource(ar acquireResource) *expiringResource { + return &expiringResource{cond: sync.NewCond(&sync.Mutex{}), acquireResource: ar} +} + +func (er *expiringResource) GetResource(state interface{}) (interface{}, error) { + // If the resource is expiring within this time window, update it eagerly. + // This allows other threads/goroutines to keep running by using the not-yet-expired + // resource value while one thread/goroutine updates the resource. + const window = 2 * time.Minute // This example updates the resource 2 minutes prior to expiration + + now, acquire, resource := time.Now(), false, er.resource // acquire exclusive lock - b.cond.L.Lock() + er.cond.L.Lock() for { - if b.expiresOn.IsZero() || b.expiresOn.Before(now) { - // token was never obtained or has expired - if !b.renewing { - // another go routine isn't refreshing the token so this one will - b.renewing = true - getToken = true + if er.expiration.IsZero() || er.expiration.Before(now) { + // The resource was never acquired or has expired + if !er.acquiring { + // If another thread/goroutine is not acquiring/updating the resource, this thread/goroutine will do it + er.acquiring, acquire = true, true break } - // getting here means this go routine will wait for the token to refresh - } else if b.expiresOn.Add(-window).Before(now) { - // token is within the expiration window - if !b.renewing { - // another go routine isn't refreshing the token so this one will - b.renewing = true - getToken = true + // Getting here means that this thread/goroutine will wait for the updated resource + } else if er.expiration.Add(-window).Before(now) { + // The resource is valid but is expiring within the time window + if !er.acquiring { + // If another thread/goroutine is not acquiring/renewing the resource, this thread/goroutine will do it + er.acquiring, acquire = true, true break } - // this go routine will use the existing token while another refreshes it - header = b.header + // This thread/goroutine will use the existing resource value while another updates it + resource = er.resource break } else { - // token is not expiring yet so use it as-is - header = b.header + // The resource is not close to expiring, this thread/goroutine should use its current value + resource = er.resource break } - // wait for the token to refresh - b.cond.Wait() + // If we get here, wait for the new resource value to be acquired/updated + er.cond.Wait() } - b.cond.L.Unlock() - if getToken { - // this go routine has been elected to refresh the token - tk, err := b.creds.GetToken(req.Context(), b.options) - // update shared state - b.cond.L.Lock() - // to avoid a deadlock if GetToken() fails we MUST reset b.renewing to false before returning - b.renewing = false + er.cond.L.Unlock() // Release the lock so no threads/goroutines are blocked + + var err error + if acquire { + // This thread/goroutine has been selected to acquire/update the resource + var expiration time.Time + resource, expiration, err = er.acquireResource(state) + + // Atomically, update the shared resource's new value & expiration. + er.cond.L.Lock() + if err == nil { + // No error, update resource & expiration + er.resource, er.expiration = resource, expiration + } + er.acquiring = false // Indicate that no thread/goroutine is currently acquiring the resrouce + + // Wake up any waiting threads/goroutines since there is a resource they can ALL use + er.cond.L.Unlock() + er.cond.Broadcast() + } + return resource, err // Return the resource this thread/goroutine can use +} + +func newBearerTokenPolicy(creds azcore.TokenCredential, opts azcore.AuthenticationOptions) *bearerTokenPolicy { + p := &bearerTokenPolicy{ + creds: creds, + options: opts.TokenRequest, + mainResource: newExpiringResource(acquire), + } + if len(opts.AuxiliaryTenants) > 0 { + p.auxResources = map[string]*expiringResource{} + } + for _, t := range opts.AuxiliaryTenants { + p.auxResources[t] = newExpiringResource(acquire) + + } + return p +} + +func (b *bearerTokenPolicy) Do(req *azcore.Request) (*azcore.Response, error) { + as := acquiringResourceState{ + p: *b, + req: req, + } + tk, err := b.mainResource.GetResource(as) + if err != nil { + return nil, err + } + if token, ok := tk.(*azcore.AccessToken); ok { + req.Request.Header.Set(headerXmsDate, time.Now().UTC().Format(http.TimeFormat)) + req.Request.Header.Set(headerAuthorization, fmt.Sprintf("Bearer %s", token.Token)) + } + auxTokens := []string{} + for tenant, er := range b.auxResources { + bCopy := *b + bCopy.options.TenantID = tenant + auxAS := acquiringResourceState{ + p: bCopy, + req: req, + } + auxTk, err := er.GetResource(auxAS) if err != nil { - b.unlock() return nil, err } - header = bearerTokenPrefix + tk.Token - b.header = header - b.expiresOn = tk.ExpiresOn - b.unlock() + auxTokens = append(auxTokens, fmt.Sprintf("%s%s", bearerTokenPrefix, auxTk.(*azcore.AccessToken).Token)) + } + if len(auxTokens) > 0 { + req.Request.Header.Set(headerAuxiliaryAuthorization, strings.Join(auxTokens, ", ")) } - req.Request.Header.Set(azcore.HeaderXmsDate, time.Now().UTC().Format(http.TimeFormat)) - req.Request.Header.Set(azcore.HeaderAuthorization, header) return req.Next() } - -// signal any waiters that the token has been refreshed -func (b *bearerTokenPolicy) unlock() { - b.cond.Broadcast() - b.cond.L.Unlock() -} diff --git a/sdk/azidentity/bearer_token_policy_test.go b/sdk/azidentity/bearer_token_policy_test.go index c3f2b0991af2..0ce05cd47b0d 100644 --- a/sdk/azidentity/bearer_token_policy_test.go +++ b/sdk/azidentity/bearer_token_policy_test.go @@ -29,7 +29,7 @@ func defaultTestPipeline(srv azcore.Transport, cred azcore.Credential, scope str return azcore.NewPipeline( srv, azcore.NewRetryPolicy(&retryOpts), - cred.AuthenticationPolicy(azcore.AuthenticationPolicyOptions{Options: azcore.TokenRequestOptions{Scopes: []string{scope}}}), + cred.NewAuthenticationPolicy(azcore.AuthenticationOptions{TokenRequest: azcore.TokenRequestOptions{Scopes: []string{scope}}}), azcore.NewLogPolicy(nil)) } @@ -55,7 +55,7 @@ func TestBearerPolicy_SuccessGetToken(t *testing.T) { t.Fatalf("Expected nil error but received one") } const expectedToken = bearerTokenPrefix + tokenValue - if token := resp.Request.Header.Get(azcore.HeaderAuthorization); token != expectedToken { + if token := resp.Request.Header.Get(headerAuthorization); token != expectedToken { t.Fatalf("expected token '%s', got '%s'", expectedToken, token) } } @@ -197,3 +197,52 @@ func TestBearerPolicy_GetTokenFailsNoDeadlock(t *testing.T) { t.Fatal("expected nil response") } } + +func TestBearerTokenWithAuxiliaryTenants(t *testing.T) { + srv, close := mock.NewTLSServer() + defer close() + headerResult := "Bearer new_token, Bearer new_token, Bearer new_token" + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess))) + srv.AppendResponse() + options := ClientSecretCredentialOptions{ + AuthorityHost: srv.URL(), + HTTPClient: srv, + } + cred, err := NewClientSecretCredential(tenantID, clientID, secret, &options) + if err != nil { + t.Fatalf("Unable to create credential. Received: %v", err) + } + retryOpts := azcore.RetryOptions{ + MaxRetryDelay: 500 * time.Millisecond, + RetryDelay: 50 * time.Millisecond, + } + pipeline := azcore.NewPipeline( + srv, + azcore.NewRetryPolicy(&retryOpts), + cred.NewAuthenticationPolicy( + azcore.AuthenticationOptions{ + TokenRequest: azcore.TokenRequestOptions{ + Scopes: []string{scope}, + }, + AuxiliaryTenants: []string{"tenant1", "tenant2", "tenant3"}, + }), + azcore.NewLogPolicy(nil)) + + req, err := azcore.NewRequest(context.Background(), http.MethodGet, srv.URL()) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + resp, err := pipeline.Do(req) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status code: %d", resp.StatusCode) + } + if auxH := resp.Request.Header.Get(headerAuxiliaryAuthorization); auxH != headerResult { + t.Fatalf("unexpected auxiliary authorization header %s", auxH) + } +}