Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

auth: manual token expiration, better auxiliary token handling, improve tests #665

Merged
merged 15 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions sdk/auth/cached_authorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ import (
"fmt"
"net/http"
"sync"
"time"

"golang.org/x/oauth2"
)

var _ Authorizer = &CachedAuthorizer{}
var _ CachingAuthorizer = &CachedAuthorizer{}

// CachedAuthorizer caches a token until it expires, then acquires a new token from Source
type CachedAuthorizer struct {
Expand Down Expand Up @@ -54,7 +55,7 @@ func (c *CachedAuthorizer) AuxiliaryTokens(ctx context.Context, req *http.Reques
}
c.mutex.RUnlock()

if !dueForRenewal {
if dueForRenewal || len(c.auxTokens) == 0 {
c.mutex.Lock()
defer c.mutex.Unlock()
var err error
Expand All @@ -67,9 +68,22 @@ func (c *CachedAuthorizer) AuxiliaryTokens(ctx context.Context, req *http.Reques
return c.auxTokens, nil
}

// InvalidateCachedTokens expires the currently cached token and auxTokens, forcing new
// tokens to be acquired when Token() or AuxiliaryTokens() are next called
func (c *CachedAuthorizer) InvalidateCachedTokens() error {
if c.token == nil {
return nil
}
c.token.Expiry = time.Now()
manicminer marked this conversation as resolved.
Show resolved Hide resolved
for i := range c.auxTokens {
c.auxTokens[i].Expiry = time.Now()
}
return nil
}

// NewCachedAuthorizer returns an Authorizer that caches an access token for the duration of its validity.
// If the cached token expires, a new one is acquired and cached.
func NewCachedAuthorizer(src Authorizer) (Authorizer, error) {
func NewCachedAuthorizer(src Authorizer) (CachingAuthorizer, error) {
if _, ok := src.(*SharedKeyAuthorizer); ok {
return nil, fmt.Errorf("internal-error: SharedKeyAuthorizer cannot be cached")
}
Expand Down
161 changes: 161 additions & 0 deletions sdk/auth/cached_authorizer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package auth_test

import (
"context"
"net/http"
"regexp"
"testing"
"time"

"github.com/hashicorp/go-azure-sdk/sdk/auth"
"github.com/hashicorp/go-azure-sdk/sdk/claims"
"github.com/hashicorp/go-azure-sdk/sdk/internal/test"
)

func TestCachedAuthorizer(t *testing.T) {
tokenPattern := regexp.MustCompile("^[a-zA-Z0-9_-]+[.][a-zA-Z0-9_-]+[.][a-zA-Z0-9_-]+")
req := &http.Request{}

authorizer, err := auth.NewCachedAuthorizer(&test.TestAuthorizer{})
if err != nil {
t.Fatalf("received error for NewCachedAuthorizer(): %+v", err)
}

// Retrieve the first access tokens
token, err := authorizer.Token(context.Background(), req)
if err != nil {
t.Fatalf("received error for CachedAuthorizer.Token(): %+v", err)
}
if !tokenPattern.MatchString(token.AccessToken) {
t.Fatalf("unexpected access token received: %q", token.AccessToken)
}
auxTokens, err := authorizer.AuxiliaryTokens(context.Background(), req)
if err != nil {
t.Fatalf("received error for CachedAuthorizer.AuxiliaryTokens(): %+v", err)
}
for i, auxToken := range auxTokens {
if !tokenPattern.MatchString(auxToken.AccessToken) {
t.Fatalf("unexpected auxiliary access token received at %d: %q", i, token.AccessToken)
}
}

// Parse the claims and compare the IssuedAt and Expiry times
tokenClaims, err := claims.ParseClaims(token)
if err != nil {
t.Fatalf("received error for claims.ParseClaims(): %+v", err)
}
if tokenClaims.IssuedAt != test.TestTokenIssued.Unix() {
t.Fatalf("unexpected `iat` claim for access token, expected: %d, received: %d", test.TestTokenIssued.Unix(), tokenClaims.IssuedAt)
}
if tokenClaims.Expires != test.TestTokenExpiry.Unix() {
t.Fatalf("unexpected `exp` claim for access token, expected: %d, received: %d", test.TestTokenExpiry.Unix(), tokenClaims.Expires)
}
for i, auxToken := range auxTokens {
auxTokenClaims, err := claims.ParseClaims(auxToken)
if err != nil {
t.Fatalf("received error for claims.ParseClaims(): %+v", err)
}
if auxTokenClaims.IssuedAt != test.TestTokenIssued.Unix() {
t.Fatalf("unexpected `iat` claim for auxiliary access token at %d, expected: %d, received: %d", i, test.TestTokenIssued.Unix(), auxTokenClaims.IssuedAt)
}
if auxTokenClaims.Expires != test.TestTokenExpiry.Unix() {
t.Fatalf("unexpected `exp` claim for auxiliary access token at %d, expected: %d, received: %d", i, test.TestTokenExpiry.Unix(), auxTokenClaims.Expires)
}
}

// Wait for 5 seconds and advance the issued/expiry times for the testAuthorizer
time.Sleep(5 * time.Second)
earlierTestTokenIssued := test.TestTokenIssued
earlierTestTokenExpiry := test.TestTokenExpiry
test.TestTokenIssued = time.Now()
test.TestTokenExpiry = time.Now().Add(3599 * time.Second)

// Retrieve a second token, this should be retrieved from the cache
token, err = authorizer.Token(context.Background(), req)
if err != nil {
t.Fatalf("received error for CachedAuthorizer.Token(): %+v", err)
}
if !tokenPattern.MatchString(token.AccessToken) {
t.Fatalf("unexpected access token received: %q", token.AccessToken)
}
auxTokens, err = authorizer.AuxiliaryTokens(context.Background(), req)
if err != nil {
t.Fatalf("received error for CachedAuthorizer.AuxiliaryTokens(): %+v", err)
}
for i, auxToken := range auxTokens {
if !tokenPattern.MatchString(auxToken.AccessToken) {
t.Fatalf("unexpected auxiliary access token received at %d: %q", i, token.AccessToken)
}
}

// Parse the claims for the second token, ensure the IssuedAt and Expiry times _have not_ changed
tokenClaims, err = claims.ParseClaims(token)
if err != nil {
t.Fatalf("received error for claims.ParseClaims(): %+v", err)
}
if tokenClaims.IssuedAt != earlierTestTokenIssued.Unix() {
t.Fatalf("unexpected `iat` claim for access token, expected: %d, received: %d", earlierTestTokenIssued.Unix(), tokenClaims.IssuedAt)
}
if tokenClaims.Expires != earlierTestTokenExpiry.Unix() {
t.Fatalf("unexpected `exp` claim for access token, expected: %d, received: %d", earlierTestTokenExpiry.Unix(), tokenClaims.Expires)
}
for i, auxToken := range auxTokens {
auxTokenClaims, err := claims.ParseClaims(auxToken)
if err != nil {
t.Fatalf("received error for claims.ParseClaims(): %+v", err)
}
if auxTokenClaims.IssuedAt != earlierTestTokenIssued.Unix() {
t.Fatalf("unexpected `iat` claim for auxiliary access token at %d, expected: %d, received: %d", i, earlierTestTokenIssued.Unix(), auxTokenClaims.IssuedAt)
}
if auxTokenClaims.Expires != earlierTestTokenExpiry.Unix() {
t.Fatalf("unexpected `exp` claim for auxiliary access token at %d, expected: %d, received: %d", i, earlierTestTokenExpiry.Unix(), auxTokenClaims.Expires)
}
}

// Invalidate the access tokens
if err = authorizer.InvalidateCachedTokens(); err != nil {
t.Fatalf("received error for CachedAuthorizer.ExpireTokens(): %+v", err)
}

// Retrieve a third token, which should be re-acquired from the testAuthorizer
token, err = authorizer.Token(context.Background(), req)
if err != nil {
t.Fatalf("received error for CachedAuthorizer.Token(): %+v", err)
}
if !tokenPattern.MatchString(token.AccessToken) {
t.Fatalf("unexpected access token received: %q", token.AccessToken)
}
auxTokens, err = authorizer.AuxiliaryTokens(context.Background(), req)
if err != nil {
t.Fatalf("received error for CachedAuthorizer.AuxiliaryTokens(): %+v", err)
}
for i, auxToken := range auxTokens {
if !tokenPattern.MatchString(auxToken.AccessToken) {
t.Fatalf("unexpected auxiliary access token received at %d: %q", i, token.AccessToken)
}
}

// Parse the claims for the third token, ensure the IssuedAt and Expiry times _have_ changed
tokenClaims, err = claims.ParseClaims(token)
if err != nil {
t.Fatalf("received error for claims.ParseClaims(): %+v", err)
}
if tokenClaims.IssuedAt != test.TestTokenIssued.Unix() {
t.Fatalf("unexpected `iat` claim for access token, expected: %d, received: %d", test.TestTokenIssued.Unix(), tokenClaims.IssuedAt)
}
if tokenClaims.Expires != test.TestTokenExpiry.Unix() {
t.Fatalf("unexpected `exp` claim for access token, expected: %d, received: %d", test.TestTokenExpiry.Unix(), tokenClaims.Expires)
}
for i, auxToken := range auxTokens {
auxTokenClaims, err := claims.ParseClaims(auxToken)
if err != nil {
t.Fatalf("received error for claims.ParseClaims(): %+v", err)
}
if auxTokenClaims.IssuedAt != test.TestTokenIssued.Unix() {
t.Fatalf("unexpected `iat` claim for auxiliary access token at %d, expected: %d, received: %d", i, test.TestTokenIssued.Unix(), auxTokenClaims.IssuedAt)
}
if auxTokenClaims.Expires != test.TestTokenExpiry.Unix() {
t.Fatalf("unexpected `exp` claim for auxiliary access token at %d, expected: %d, received: %d", i, test.TestTokenExpiry.Unix(), auxTokenClaims.Expires)
}
}
}
11 changes: 11 additions & 0 deletions sdk/auth/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,22 @@ import (

// Authorizer is anything that can return an access token for authorizing API connections
type Authorizer interface {
// Token obtains a new access token for the configured tenant
Token(ctx context.Context, request *http.Request) (*oauth2.Token, error)

// AuxiliaryTokens obtains new access tokens for the configured auxiliary tenants
AuxiliaryTokens(ctx context.Context, request *http.Request) ([]*oauth2.Token, error)
}

// CachingAuthorizer implements Authorizer whilst caching access tokens and offering a way to intentionally invalidate them
type CachingAuthorizer interface {
Authorizer

// InvalidateCachedTokens invalidates any cached access tokens, so that new tokens are automatically
// retrieved from the authorization service on the next call to Token or AuxiliaryTokens.
InvalidateCachedTokens() error
}

// HTTPClient is an HTTP client used for sending authentication requests and obtaining tokens
type HTTPClient interface {
Do(req *http.Request) (*http.Response, error)
Expand Down
2 changes: 0 additions & 2 deletions sdk/auth/shared_key_authorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ func (s *SharedKeyAuthorizer) AuxiliaryTokens(_ context.Context, _ *http.Request
return []*oauth2.Token{}, nil
}

// ---

const (
storageEmulatorAccountName string = "devstoreaccount1"

Expand Down
39 changes: 35 additions & 4 deletions sdk/auth/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,42 @@
package auth

import (
"golang.org/x/oauth2"
"context"
"fmt"
"net/http"
"time"

"github.com/hashicorp/go-azure-sdk/sdk/claims"
"golang.org/x/oauth2"
)

// SetAuthHeader decorates a *http.Request with the Authorization header using a bearer token obtained from the Token
// method of the supplied Authorizer.
func SetAuthHeader(ctx context.Context, req *http.Request, authorizer Authorizer) error {
if req == nil {
return fmt.Errorf("request was nil")
}
if authorizer == nil {
return fmt.Errorf("authorizer was nil")
}

token, err := authorizer.Token(ctx, req)
if err != nil {
return err
}

if req.Header == nil {
req.Header = make(http.Header)
}

req.Header.Set("Authorization", fmt.Sprintf("%s %s", token.Type(), token.AccessToken))

return nil
}

const tokenExpiryDelta = 20 * time.Minute

// tokenExpiresSoon returns true if the token expires within 10 minutes, or if more than 50% of its validity period has elapsed (if this can be determined), whichever is later
// tokenDueForRenewal returns true if the token expires within 10 minutes, or if more than 50% of its validity period has elapsed (if this can be determined), whichever is later
func tokenDueForRenewal(token *oauth2.Token) bool {
if token == nil {
return true
Expand All @@ -26,7 +53,11 @@ func tokenDueForRenewal(token *oauth2.Token) bool {
expiry := token.Expiry.Round(0)
delta := tokenExpiryDelta
now := time.Now()
expiresWithinTenMinutes := expiry.Add(-delta).Before(now)

// Always return early if the token validity doesn't extend past the expiry delta
if expiry.Add(-delta).Before(now) {
tombuildsstuff marked this conversation as resolved.
Show resolved Hide resolved
return true
}

// Try to parse the token claims to retrieve the issuedAt time
if claims, err := claims.ParseClaims(token); err == nil {
Expand All @@ -43,5 +74,5 @@ func tokenDueForRenewal(token *oauth2.Token) bool {
}
}

return expiresWithinTenMinutes
return false
}
26 changes: 26 additions & 0 deletions sdk/auth/token_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package auth_test

import (
"context"
"net/http"
"regexp"
"testing"

"github.com/hashicorp/go-azure-sdk/sdk/auth"
"github.com/hashicorp/go-azure-sdk/sdk/internal/test"
)

func TestSetAuthHeader(t *testing.T) {
req := &http.Request{}
authorizer := &test.TestAuthorizer{}

err := auth.SetAuthHeader(context.Background(), req, authorizer)
if err != nil {
t.Fatalf("received error: %+v", err)
}

expected := regexp.MustCompile("^Bearer [a-zA-Z0-9_-]+[.][a-zA-Z0-9_-]+[.][a-zA-Z0-9_-]+")
if val := req.Header.Get("Authorization"); !expected.MatchString(val) {
t.Fatalf("Authorization header mismatch, received: %q", val)
}
}
1 change: 1 addition & 0 deletions sdk/claims/claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
// Claims is used to unmarshall the claims from a JWT issued by the Microsoft Identity Platform.
type Claims struct {
Audience string `json:"aud"`
Expires int64 `json:"exp"`
IssuedAt int64 `json:"iat"`
Issuer string `json:"iss"`
IdentityProvider string `json:"idp"`
Expand Down
23 changes: 15 additions & 8 deletions sdk/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ func RequestRetryAll(retryFuncs ...RequestRetryFunc) func(resp *http.Response, o
}
}

// RetryableErrorHandler simply returns the resp and err, this is needed to makes the retryablehttp client's Do() return early with the response body not drained.
// RetryableErrorHandler simply returns the resp and err, this is needed to make the Do() method
// of retryablehttp client return early with the response body not drained.
func RetryableErrorHandler(resp *http.Response, err error, _ int) (*http.Response, error) {
return resp, err
}
Expand Down Expand Up @@ -260,6 +261,11 @@ type Client struct {
// Authorizer is anything that can provide an access token with which to authorize requests.
Authorizer auth.Authorizer

// AuthorizeRequest is an optional function to decorate a Request for authorization prior to being sent.
// When nil, a standard Authorization header will be added using a bearer token as returned by the Token method
// of the configured Authorizer. Define this function in order to customize the request authorization.
AuthorizeRequest func(context.Context, *http.Request, auth.Authorizer) error

// DisableRetries prevents the client from reattempting failed requests (which it does to work around eventual consistency issues).
// This does not impact handling of retries related to rate limiting, which are always performed.
DisableRetries bool
Expand Down Expand Up @@ -327,14 +333,15 @@ func (c *Client) Execute(ctx context.Context, req *Request) (*Response, error) {
return nil, fmt.Errorf("req.Request was nil")
}

// at this point we're ready to send the HTTP Request, as such let's get the Authorization token
// and add that to the request
if c.Authorizer != nil {
token, err := c.Authorizer.Token(ctx, req.Request)
if err != nil {
return nil, err
// Authorize the request
if c.AuthorizeRequest != nil {
if err := c.AuthorizeRequest(ctx, req.Request, c.Authorizer); err != nil {
return nil, fmt.Errorf("authorizing request: %+v", err)
}
} else if c.Authorizer != nil {
if err := auth.SetAuthHeader(ctx, req.Request, c.Authorizer); err != nil {
return nil, fmt.Errorf("authorizing request: %+v", err)
}
token.SetAuthHeader(req.Request)
}

var err error
Expand Down
1 change: 1 addition & 0 deletions sdk/client/dataplane/storage/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func NewBaseClient(baseUri string, componentName, apiVersion string) (*BaseClien
}

func (c *BaseClient) NewRequest(ctx context.Context, input client.RequestOptions) (*client.Request, error) {
// TODO move these validations to base client method
if _, ok := ctx.Deadline(); !ok {
return nil, fmt.Errorf("the context used must have a deadline attached for polling purposes, but got no deadline")
}
Expand Down
Loading