Skip to content

Commit

Permalink
expiring resource rework
Browse files Browse the repository at this point in the history
  • Loading branch information
catalinaperalta committed Jul 23, 2021
1 parent 1af91d4 commit 610397a
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 69 deletions.
197 changes: 130 additions & 67 deletions sdk/azidentity/bearer_token_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
package azidentity

import (
"fmt"
"net/http"
"strings"
"sync"
"time"

Expand All @@ -16,97 +18,158 @@ const (
)

type bearerTokenPolicy struct {
// cond is used to synchronize token refresh. the locker
// must be locked when updating the following shared state.
// mainResource is the resource to be retreived using the tenant specified in the credential
mainResource *expiringResource
// auxResources are additional resources that are required for cross-tenant applications
auxResources map[string]*expiringResource
// the following fields are read-only
creds azcore.TokenCredential
options azcore.TokenRequestOptions
}

type expiringResource struct {
// cond is used to synchronize access to the shared resource embodied by the remaining fields
cond *sync.Cond

// renewing indicates that the token is in the process of being refreshed
renewing bool
// acquiring indicates that some thread/goroutine is in the process of acquiring/updating the resource
acquiring bool

// header contains the authorization header value
header string
// resource contains the value of the shared resource
resource interface{}

// expiresOn is when the token will expire
expiresOn time.Time
// expiration indicates when the shared resource expires; it is 0 if the resource was never acquired
expiration time.Time

// the following fields are read-only
creds azcore.TokenCredential
options azcore.TokenRequestOptions
// acquireResource is the callback function that actually acquires the resource
acquireResource acquireResource
}

func newBearerTokenPolicy(creds azcore.TokenCredential, opts azcore.AuthenticationPolicyOptions) *bearerTokenPolicy {
return &bearerTokenPolicy{
cond: sync.NewCond(&sync.Mutex{}),
creds: creds,
options: opts.Options,
}
type acquireResource func(state interface{}) (newResource interface{}, newExpiration time.Time, err error)

type acquiringResourceState struct {
req *azcore.Request
p bearerTokenPolicy
}

func (b *bearerTokenPolicy) Do(req *azcore.Request) (*azcore.Response, error) {
if req.URL.Scheme != "https" {
// HTTPS must be used, otherwise the tokens are at the risk of being exposed
return nil, &AuthenticationFailedError{msg: "token credentials require a URL using the HTTPS protocol scheme"}
// acquire acquires or updates the resource; only one
// thread/goroutine at a time ever calls this function
func acquire(state interface{}) (newResource interface{}, newExpiration time.Time, err error) {
s := state.(acquiringResourceState)
tk, err := s.p.creds.GetToken(s.req.Context(), s.p.options)
if err != nil {
return nil, time.Time{}, err
}
// create a "refresh window" before the token's real expiration date.
// this allows callers to continue to use the old token while the
// refresh is in progress.
const window = 2 * time.Minute
now, getToken, header := time.Now(), false, ""
return tk, tk.ExpiresOn, nil
}

func newExpiringResource(ar acquireResource) *expiringResource {
return &expiringResource{cond: sync.NewCond(&sync.Mutex{}), acquireResource: ar}
}

func (er *expiringResource) GetResource(state interface{}) (interface{}, error) {
// If the resource is expiring within this time window, update it eagerly.
// This allows other threads/goroutines to keep running by using the not-yet-expired
// resource value while one thread/goroutine updates the resource.
const window = 2 * time.Minute // This example updates the resource 2 minutes prior to expiration

now, acquire, resource := time.Now(), false, er.resource
// acquire exclusive lock
b.cond.L.Lock()
er.cond.L.Lock()
for {
if b.expiresOn.IsZero() || b.expiresOn.Before(now) {
// token was never obtained or has expired
if !b.renewing {
// another go routine isn't refreshing the token so this one will
b.renewing = true
getToken = true
if er.expiration.IsZero() || er.expiration.Before(now) {
// The resource was never acquired or has expired
if !er.acquiring {
// If another thread/goroutine is not acquiring/updating the resource, this thread/goroutine will do it
er.acquiring, acquire = true, true
break
}
// getting here means this go routine will wait for the token to refresh
} else if b.expiresOn.Add(-window).Before(now) {
// token is within the expiration window
if !b.renewing {
// another go routine isn't refreshing the token so this one will
b.renewing = true
getToken = true
// Getting here means that this thread/goroutine will wait for the updated resource
} else if er.expiration.Add(-window).Before(now) {
// The resource is valid but is expiring within the time window
if !er.acquiring {
// If another thread/goroutine is not acquiring/renewing the resource, this thread/goroutine will do it
er.acquiring, acquire = true, true
break
}
// this go routine will use the existing token while another refreshes it
header = b.header
// This thread/goroutine will use the existing resource value while another updates it
resource = er.resource
break
} else {
// token is not expiring yet so use it as-is
header = b.header
// The resource is not close to expiring, this thread/goroutine should use its current value
resource = er.resource
break
}
// wait for the token to refresh
b.cond.Wait()
// If we get here, wait for the new resource value to be acquired/updated
er.cond.Wait()
}
b.cond.L.Unlock()
if getToken {
// this go routine has been elected to refresh the token
tk, err := b.creds.GetToken(req.Context(), b.options)
// update shared state
b.cond.L.Lock()
// to avoid a deadlock if GetToken() fails we MUST reset b.renewing to false before returning
b.renewing = false
er.cond.L.Unlock() // Release the lock so no threads/goroutines are blocked

var err error
if acquire {
// This thread/goroutine has been selected to acquire/update the resource
var expiration time.Time
resource, expiration, err = er.acquireResource(state)

// Atomically, update the shared resource's new value & expiration.
er.cond.L.Lock()
if err == nil {
// No error, update resource & expiration
er.resource, er.expiration = resource, expiration
}
er.acquiring = false // Indicate that no thread/goroutine is currently acquiring the resrouce

// Wake up any waiting threads/goroutines since there is a resource they can ALL use
er.cond.L.Unlock()
er.cond.Broadcast()
}
return resource, err // Return the resource this thread/goroutine can use
}

func newBearerTokenPolicy(creds azcore.TokenCredential, opts azcore.AuthenticationOptions) *bearerTokenPolicy {
p := &bearerTokenPolicy{
creds: creds,
options: opts.TokenRequest,
mainResource: newExpiringResource(acquire),
}
if len(opts.AuxiliaryTenants) > 0 {
p.auxResources = map[string]*expiringResource{}
}
for _, t := range opts.AuxiliaryTenants {
p.auxResources[t] = newExpiringResource(acquire)

}
return p
}

func (b *bearerTokenPolicy) Do(req *azcore.Request) (*azcore.Response, error) {
as := acquiringResourceState{
p: *b,
req: req,
}
tk, err := b.mainResource.GetResource(as)
if err != nil {
return nil, err
}
if token, ok := tk.(*azcore.AccessToken); ok {
req.Request.Header.Set(headerXmsDate, time.Now().UTC().Format(http.TimeFormat))
req.Request.Header.Set(headerAuthorization, fmt.Sprintf("Bearer %s", token.Token))
}
auxTokens := []string{}
for tenant, er := range b.auxResources {
bCopy := *b
bCopy.options.TenantID = tenant
auxAS := acquiringResourceState{
p: bCopy,
req: req,
}
auxTk, err := er.GetResource(auxAS)
if err != nil {
b.unlock()
return nil, err
}
header = bearerTokenPrefix + tk.Token
b.header = header
b.expiresOn = tk.ExpiresOn
b.unlock()
auxTokens = append(auxTokens, fmt.Sprintf("%s%s", bearerTokenPrefix, auxTk.(*azcore.AccessToken).Token))
}
if len(auxTokens) > 0 {
req.Request.Header.Set(headerAuxiliaryAuthorization, strings.Join(auxTokens, ", "))
}
req.Request.Header.Set(azcore.HeaderXmsDate, time.Now().UTC().Format(http.TimeFormat))
req.Request.Header.Set(azcore.HeaderAuthorization, header)
return req.Next()
}

// signal any waiters that the token has been refreshed
func (b *bearerTokenPolicy) unlock() {
b.cond.Broadcast()
b.cond.L.Unlock()
}
53 changes: 51 additions & 2 deletions sdk/azidentity/bearer_token_policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func defaultTestPipeline(srv azcore.Transport, cred azcore.Credential, scope str
return azcore.NewPipeline(
srv,
azcore.NewRetryPolicy(&retryOpts),
cred.AuthenticationPolicy(azcore.AuthenticationPolicyOptions{Options: azcore.TokenRequestOptions{Scopes: []string{scope}}}),
cred.NewAuthenticationPolicy(azcore.AuthenticationOptions{TokenRequest: azcore.TokenRequestOptions{Scopes: []string{scope}}}),
azcore.NewLogPolicy(nil))
}

Expand All @@ -55,7 +55,7 @@ func TestBearerPolicy_SuccessGetToken(t *testing.T) {
t.Fatalf("Expected nil error but received one")
}
const expectedToken = bearerTokenPrefix + tokenValue
if token := resp.Request.Header.Get(azcore.HeaderAuthorization); token != expectedToken {
if token := resp.Request.Header.Get(headerAuthorization); token != expectedToken {
t.Fatalf("expected token '%s', got '%s'", expectedToken, token)
}
}
Expand Down Expand Up @@ -197,3 +197,52 @@ func TestBearerPolicy_GetTokenFailsNoDeadlock(t *testing.T) {
t.Fatal("expected nil response")
}
}

func TestBearerTokenWithAuxiliaryTenants(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
headerResult := "Bearer new_token, Bearer new_token, Bearer new_token"
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
srv.AppendResponse(mock.WithBody([]byte(accessTokenRespSuccess)))
srv.AppendResponse()
options := ClientSecretCredentialOptions{
AuthorityHost: srv.URL(),
HTTPClient: srv,
}
cred, err := NewClientSecretCredential(tenantID, clientID, secret, &options)
if err != nil {
t.Fatalf("Unable to create credential. Received: %v", err)
}
retryOpts := azcore.RetryOptions{
MaxRetryDelay: 500 * time.Millisecond,
RetryDelay: 50 * time.Millisecond,
}
pipeline := azcore.NewPipeline(
srv,
azcore.NewRetryPolicy(&retryOpts),
cred.NewAuthenticationPolicy(
azcore.AuthenticationOptions{
TokenRequest: azcore.TokenRequestOptions{
Scopes: []string{scope},
},
AuxiliaryTenants: []string{"tenant1", "tenant2", "tenant3"},
}),
azcore.NewLogPolicy(nil))

req, err := azcore.NewRequest(context.Background(), http.MethodGet, srv.URL())
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
resp, err := pipeline.Do(req)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status code: %d", resp.StatusCode)
}
if auxH := resp.Request.Header.Get(headerAuxiliaryAuthorization); auxH != headerResult {
t.Fatalf("unexpected auxiliary authorization header %s", auxH)
}
}

0 comments on commit 610397a

Please sign in to comment.