Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[azidentity] Expiring resource rework w/ multi-tenant support #15138

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions sdk/azidentity/azidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@ const (
)

const (
headerXmsDate = "x-ms-date"
headerUserAgent = "User-Agent"
headerURLEncoded = "application/x-www-form-urlencoded"
headerAuthorization = "Authorization"
headerMetadata = "Metadata"
headerContentType = "Content-Type"
headerXmsDate = "x-ms-date"
headerUserAgent = "User-Agent"
headerURLEncoded = "application/x-www-form-urlencoded"
headerAuthorization = "Authorization"
headerAuxiliaryAuthorization = "x-ms-authorization-auxiliary"
headerMetadata = "Metadata"
headerContentType = "Content-Type"
)

const tenantIDValidationErr = "Invalid tenantID provided. You can locate your tenantID by following the instructions listed here: https://docs.microsoft.com/partner-center/find-ids-and-domain-names."
Expand Down
197 changes: 130 additions & 67 deletions sdk/azidentity/bearer_token_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
package azidentity

import (
"fmt"
"net/http"
"strings"
"sync"
"time"

Expand All @@ -16,97 +18,158 @@ const (
)

type bearerTokenPolicy struct {
// cond is used to synchronize token refresh. the locker
// must be locked when updating the following shared state.
// mainResource is the resource to be retreived using the tenant specified in the credential
mainResource *expiringResource
// auxResources are additional resources that are required for cross-tenant applications
auxResources map[string]*expiringResource
// the following fields are read-only
creds azcore.TokenCredential
options azcore.TokenRequestOptions
}

type expiringResource struct {
// cond is used to synchronize access to the shared resource embodied by the remaining fields
cond *sync.Cond

// renewing indicates that the token is in the process of being refreshed
renewing bool
// acquiring indicates that some thread/goroutine is in the process of acquiring/updating the resource
acquiring bool

// header contains the authorization header value
header string
// resource contains the value of the shared resource
resource interface{}

// expiresOn is when the token will expire
expiresOn time.Time
// expiration indicates when the shared resource expires; it is 0 if the resource was never acquired
expiration time.Time

// the following fields are read-only
creds azcore.TokenCredential
options azcore.TokenRequestOptions
// acquireResource is the callback function that actually acquires the resource
acquireResource acquireResource
}

func newBearerTokenPolicy(creds azcore.TokenCredential, opts azcore.AuthenticationOptions) *bearerTokenPolicy {
return &bearerTokenPolicy{
cond: sync.NewCond(&sync.Mutex{}),
creds: creds,
options: opts.TokenRequest,
}
type acquireResource func(state interface{}) (newResource interface{}, newExpiration time.Time, err error)

type acquiringResourceState struct {
req *azcore.Request
p bearerTokenPolicy
}

func (b *bearerTokenPolicy) Do(req *azcore.Request) (*azcore.Response, error) {
if req.URL.Scheme != "https" {
// HTTPS must be used, otherwise the tokens are at the risk of being exposed
return nil, &AuthenticationFailedError{msg: "token credentials require a URL using the HTTPS protocol scheme"}
// acquire acquires or updates the resource; only one
// thread/goroutine at a time ever calls this function
func acquire(state interface{}) (newResource interface{}, newExpiration time.Time, err error) {
s := state.(acquiringResourceState)
tk, err := s.p.creds.GetToken(s.req.Context(), s.p.options)
if err != nil {
return nil, time.Time{}, err
}
// create a "refresh window" before the token's real expiration date.
// this allows callers to continue to use the old token while the
// refresh is in progress.
const window = 2 * time.Minute
now, getToken, header := time.Now(), false, ""
return tk, tk.ExpiresOn, nil
}

func newExpiringResource(ar acquireResource) *expiringResource {
return &expiringResource{cond: sync.NewCond(&sync.Mutex{}), acquireResource: ar}
}

func (er *expiringResource) GetResource(state interface{}) (interface{}, error) {
// If the resource is expiring within this time window, update it eagerly.
// This allows other threads/goroutines to keep running by using the not-yet-expired
// resource value while one thread/goroutine updates the resource.
const window = 2 * time.Minute // This example updates the resource 2 minutes prior to expiration

now, acquire, resource := time.Now(), false, er.resource
// acquire exclusive lock
b.cond.L.Lock()
er.cond.L.Lock()
for {
if b.expiresOn.IsZero() || b.expiresOn.Before(now) {
// token was never obtained or has expired
if !b.renewing {
// another go routine isn't refreshing the token so this one will
b.renewing = true
getToken = true
if er.expiration.IsZero() || er.expiration.Before(now) {
// The resource was never acquired or has expired
if !er.acquiring {
// If another thread/goroutine is not acquiring/updating the resource, this thread/goroutine will do it
er.acquiring, acquire = true, true
break
}
// getting here means this go routine will wait for the token to refresh
} else if b.expiresOn.Add(-window).Before(now) {
// token is within the expiration window
if !b.renewing {
// another go routine isn't refreshing the token so this one will
b.renewing = true
getToken = true
// Getting here means that this thread/goroutine will wait for the updated resource
} else if er.expiration.Add(-window).Before(now) {
// The resource is valid but is expiring within the time window
if !er.acquiring {
// If another thread/goroutine is not acquiring/renewing the resource, this thread/goroutine will do it
er.acquiring, acquire = true, true
break
}
// this go routine will use the existing token while another refreshes it
header = b.header
// This thread/goroutine will use the existing resource value while another updates it
resource = er.resource
break
} else {
// token is not expiring yet so use it as-is
header = b.header
// The resource is not close to expiring, this thread/goroutine should use its current value
resource = er.resource
break
}
// wait for the token to refresh
b.cond.Wait()
// If we get here, wait for the new resource value to be acquired/updated
er.cond.Wait()
}
b.cond.L.Unlock()
if getToken {
// this go routine has been elected to refresh the token
tk, err := b.creds.GetToken(req.Context(), b.options)
// update shared state
b.cond.L.Lock()
// to avoid a deadlock if GetToken() fails we MUST reset b.renewing to false before returning
b.renewing = false
er.cond.L.Unlock() // Release the lock so no threads/goroutines are blocked

var err error
if acquire {
// This thread/goroutine has been selected to acquire/update the resource
var expiration time.Time
resource, expiration, err = er.acquireResource(state)

// Atomically, update the shared resource's new value & expiration.
er.cond.L.Lock()
if err == nil {
// No error, update resource & expiration
er.resource, er.expiration = resource, expiration
}
er.acquiring = false // Indicate that no thread/goroutine is currently acquiring the resrouce

// Wake up any waiting threads/goroutines since there is a resource they can ALL use
er.cond.L.Unlock()
er.cond.Broadcast()
}
return resource, err // Return the resource this thread/goroutine can use
}

func newBearerTokenPolicy(creds azcore.TokenCredential, opts azcore.AuthenticationOptions) *bearerTokenPolicy {
p := &bearerTokenPolicy{
creds: creds,
options: opts.TokenRequest,
mainResource: newExpiringResource(acquire),
}
if len(opts.AuxiliaryTenants) > 0 {
p.auxResources = map[string]*expiringResource{}
}
for _, t := range opts.AuxiliaryTenants {
p.auxResources[t] = newExpiringResource(acquire)

}
return p
}

func (b *bearerTokenPolicy) Do(req *azcore.Request) (*azcore.Response, error) {
as := acquiringResourceState{
p: *b,
req: req,
}
tk, err := b.mainResource.GetResource(as)
if err != nil {
return nil, err
}
if token, ok := tk.(*azcore.AccessToken); ok {
req.Request.Header.Set(headerXmsDate, time.Now().UTC().Format(http.TimeFormat))
req.Request.Header.Set(headerAuthorization, fmt.Sprintf("Bearer %s", token.Token))
}
auxTokens := []string{}
for tenant, er := range b.auxResources {
bCopy := *b
bCopy.options.TenantID = tenant
auxAS := acquiringResourceState{
p: bCopy,
req: req,
}
auxTk, err := er.GetResource(auxAS)
if err != nil {
b.unlock()
return nil, err
}
header = bearerTokenPrefix + tk.Token
b.header = header
b.expiresOn = tk.ExpiresOn
b.unlock()
auxTokens = append(auxTokens, fmt.Sprintf("%s%s", bearerTokenPrefix, auxTk.(*azcore.AccessToken).Token))
}
if len(auxTokens) > 0 {
req.Request.Header.Set(headerAuxiliaryAuthorization, strings.Join(auxTokens, ", "))
}
req.Request.Header.Set(headerXmsDate, time.Now().UTC().Format(http.TimeFormat))
req.Request.Header.Set(headerAuthorization, header)
return req.Next()
}

// signal any waiters that the token has been refreshed
func (b *bearerTokenPolicy) unlock() {
b.cond.Broadcast()
b.cond.L.Unlock()
}
49 changes: 49 additions & 0 deletions sdk/azidentity/bearer_token_policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,52 @@ func TestBearerPolicy_GetTokenFailsNoDeadlock(t *testing.T) {
t.Fatal("expected nil response")
}
}

func TestBearerTokenWithAuxiliaryTenants(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
headerResult := "Bearer new_token, Bearer new_token, Bearer new_token"
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
srv.AppendResponse()
options := ClientSecretCredentialOptions{
AuthorityHost: srv.URL(),
HTTPClient: srv,
}
cred, err := NewClientSecretCredential(tenantID, clientID, secret, &options)
if err != nil {
t.Fatalf("Unable to create credential. Received: %v", err)
}
retryOpts := azcore.RetryOptions{
MaxRetryDelay: 500 * time.Millisecond,
RetryDelay: 50 * time.Millisecond,
}
pipeline := azcore.NewPipeline(
srv,
azcore.NewRetryPolicy(&retryOpts),
cred.NewAuthenticationPolicy(
azcore.AuthenticationOptions{
TokenRequest: azcore.TokenRequestOptions{
Scopes: []string{scope},
},
AuxiliaryTenants: []string{"tenant1", "tenant2", "tenant3"},
}),
azcore.NewLogPolicy(nil))

req, err := azcore.NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
resp, err := pipeline.Do(req)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status code: %d", resp.StatusCode)
}
if auxH := resp.Request.Header.Get(headerAuxiliaryAuthorization); auxH != headerResult {
t.Fatalf("unexpected auxiliary authorization header %s", auxH)
}
}