Skip to content

Commit

Permalink
Add ExpiryWindow and ExpiryWindowJitterFrac to CredentialsCache (#946)
Browse files Browse the repository at this point in the history
  • Loading branch information
skmcgrail authored Dec 7, 2020
1 parent 8baa5f4 commit 734d12f
Show file tree
Hide file tree
Showing 14 changed files with 253 additions and 179 deletions.
22 changes: 22 additions & 0 deletions aws/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package aws

import (
"context"
"time"
)

type suppressedContext struct {
context.Context
}

func (s *suppressedContext) Deadline() (deadline time.Time, ok bool) {
return time.Time{}, false
}

func (s *suppressedContext) Done() <-chan struct{} {
return nil
}

func (s *suppressedContext) Err() error {
return nil
}
81 changes: 75 additions & 6 deletions aws/credential_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,75 @@ package aws
import (
"context"
"sync/atomic"
"time"

sdkrand "github.com/aws/aws-sdk-go-v2/internal/rand"
"github.com/aws/aws-sdk-go-v2/internal/sync/singleflight"
)

// CredentialsCacheOptions are the options
type CredentialsCacheOptions struct {

// ExpiryWindow will allow the credentials to trigger refreshing prior to
// the credentials actually expiring. This is beneficial so race conditions
// with expiring credentials do not cause request to fail unexpectedly
// due to ExpiredTokenException exceptions.
//
// An ExpiryWindow of 10s would cause calls to IsExpired() to return true
// 10 seconds before the credentials are actually expired. This can cause an
// increased number of requests to refresh the credentials to occur.
//
// If ExpiryWindow is 0 or less it will be ignored.
ExpiryWindow time.Duration

// ExpiryWindowJitterFrac provides a mechanism for randomizing the expiration of credentials
// within the configured ExpiryWindow by a random percentage. Valid values are between 0.0 and 1.0.
//
// As an example if ExpiryWindow is 60 seconds and ExpiryWindowJitterFrac is 0.5 then credentials will be set to
// expire between 30 to 60 seconds prior to their actual expiration time.
//
// If ExpiryWindow is 0 or less then ExpiryWindowJitterFrac is ignored.
// If ExpiryWindowJitterFrac is 0 then no randomization will be applied to the window.
// If ExpiryWindowJitterFrac < 0 the value will be treated as 0.
// If ExpiryWindowJitterFrac > 1 the value will be treated as 1.
ExpiryWindowJitterFrac float64
}

// CredentialsCache provides caching and concurrency safe credentials retrieval
// via the provider's retrieve method.
type CredentialsCache struct {
Provider CredentialsProvider
// provider is the CredentialProvider implementation to be wrapped by the CredentialCache.
provider CredentialsProvider

creds atomic.Value
sf singleflight.Group
options CredentialsCacheOptions
creds atomic.Value
sf singleflight.Group
}

// NewCredentialsCache returns a CredentialsCache that wraps provider. Provider is expected to not be nil. A variadic
// list of one or more functions can be provided to modify the CredentialsCache configuration. This allows for
// configuration of credential expiry window and jitter.
func NewCredentialsCache(provider CredentialsProvider, optFns ...func(options *CredentialsCacheOptions)) *CredentialsCache {
options := CredentialsCacheOptions{}

for _, fn := range optFns {
fn(&options)
}

if options.ExpiryWindow < 0 {
options.ExpiryWindow = 0
}

if options.ExpiryWindowJitterFrac < 0 {
options.ExpiryWindowJitterFrac = 0
} else if options.ExpiryWindowJitterFrac > 1 {
options.ExpiryWindowJitterFrac = 1
}

return &CredentialsCache{
provider: provider,
options: options,
}
}

// Retrieve returns the credentials. If the credentials have already been
Expand All @@ -27,7 +85,9 @@ func (p *CredentialsCache) Retrieve(ctx context.Context) (Credentials, error) {
return *creds, nil
}

resCh := p.sf.DoChan("", p.singleRetrieve)
resCh := p.sf.DoChan("", func() (interface{}, error) {
return p.singleRetrieve(&suppressedContext{ctx})
})
select {
case res := <-resCh:
return res.Val.(Credentials), res.Err
Expand All @@ -36,13 +96,22 @@ func (p *CredentialsCache) Retrieve(ctx context.Context) (Credentials, error) {
}
}

func (p *CredentialsCache) singleRetrieve() (interface{}, error) {
func (p *CredentialsCache) singleRetrieve(ctx context.Context) (interface{}, error) {
if creds := p.getCreds(); creds != nil {
return *creds, nil
}

creds, err := p.Provider.Retrieve(context.TODO())
creds, err := p.provider.Retrieve(ctx)
if err == nil {
if creds.CanExpire {
randFloat64, err := sdkrand.CryptoRandFloat64()
if err != nil {
return Credentials{}, err
}
jitter := time.Duration(randFloat64 * p.options.ExpiryWindowJitterFrac * float64(p.options.ExpiryWindow))
creds.Expires = creds.Expires.Add(-(p.options.ExpiryWindow - jitter))
}

p.creds.Store(&creds)
}

Expand Down
8 changes: 2 additions & 6 deletions aws/credential_cache_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ func BenchmarkCredentialsCache_Retrieve(b *testing.B) {
cases := []int{1, 10, 100, 500, 1000, 10000}
for _, c := range cases {
b.Run(strconv.Itoa(c), func(b *testing.B) {
p := CredentialsCache{
Provider: provider,
}
p := NewCredentialsCache(provider)
var wg sync.WaitGroup
wg.Add(c)
for i := 0; i < c; i++ {
Expand Down Expand Up @@ -59,9 +57,7 @@ func BenchmarkCredentialsCache_Retrieve_Invalidate(b *testing.B) {
for _, expRate := range expRates {
for _, c := range cases {
b.Run(fmt.Sprintf("%d-%d", expRate, c), func(b *testing.B) {
p := CredentialsCache{
Provider: provider,
}
p := NewCredentialsCache(provider)
var wg sync.WaitGroup
wg.Add(c)
for i := 0; i < c; i++ {
Expand Down
137 changes: 104 additions & 33 deletions aws/credential_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,13 @@ func TestCredentialsCache_Cache(t *testing.T) {
}

var called bool
p := &CredentialsCache{
Provider: CredentialsProviderFunc(func(ctx context.Context) (Credentials, error) {
if called {
t.Fatalf("expect provider.Retrieve to only be called once")
}
called = true
return expect, nil
}),
}
p := NewCredentialsCache(CredentialsProviderFunc(func(ctx context.Context) (Credentials, error) {
if called {
t.Fatalf("expect provider.Retrieve to only be called once")
}
called = true
return expect, nil
}))

for i := 0; i < 2; i++ {
creds, err := p.Retrieve(context.Background())
Expand Down Expand Up @@ -108,12 +106,10 @@ func TestCredentialsCache_Expires(t *testing.T) {

for _, c := range cases {
var called int
p := &CredentialsCache{
Provider: CredentialsProviderFunc(func(ctx context.Context) (Credentials, error) {
called++
return c.Creds(), nil
}),
}
p := NewCredentialsCache(CredentialsProviderFunc(func(ctx context.Context) (Credentials, error) {
called++
return c.Creds(), nil
}))

p.Retrieve(context.Background())
p.Retrieve(context.Background())
Expand All @@ -131,13 +127,92 @@ func TestCredentialsCache_Expires(t *testing.T) {
}
}

func TestCredentialsCache_Error(t *testing.T) {
p := &CredentialsCache{
Provider: CredentialsProviderFunc(func(ctx context.Context) (Credentials, error) {
return Credentials{}, fmt.Errorf("failed")
}),
func TestCredentialsCache_ExpireTime(t *testing.T) {
orig := sdk.NowTime
defer func() { sdk.NowTime = orig }()
var mockTime time.Time
sdk.NowTime = func() time.Time { return mockTime }

cases := map[string]struct {
ExpireTime time.Time
ExpiryWindow time.Duration
JitterFrac float64
Validate func(t *testing.T, v time.Time)
}{
"no expire window": {
Validate: func(t *testing.T, v time.Time) {
t.Helper()
if e, a := mockTime, v; !e.Equal(a) {
t.Errorf("expect %v, got %v", e, a)
}
},
},
"expire window": {
ExpireTime: mockTime.Add(100),
ExpiryWindow: 50,
Validate: func(t *testing.T, v time.Time) {
t.Helper()
if e, a := mockTime.Add(50), v; !e.Equal(a) {
t.Errorf("expect %v, got %v", e, a)
}
},
},
"expire window with jitter": {
ExpireTime: mockTime.Add(100),
JitterFrac: 0.5,
ExpiryWindow: 50,
Validate: func(t *testing.T, v time.Time) {
t.Helper()
max := mockTime.Add(75)
min := mockTime.Add(50)
if v.Before(min) {
t.Errorf("expect %v to be before %s", v, min)
}
if v.After(max) {
t.Errorf("expect %v to be after %s", v, max)
}
},
},
"no expire window with jitter": {
ExpireTime: mockTime,
JitterFrac: 0.5,
Validate: func(t *testing.T, v time.Time) {
t.Helper()
if e, a := mockTime, v; !e.Equal(a) {
t.Errorf("expect %v, got %v", e, a)
}
},
},
}

for name, tt := range cases {
t.Run(name, func(t *testing.T) {
p := NewCredentialsCache(CredentialsProviderFunc(func(ctx context.Context) (Credentials, error) {
return Credentials{
AccessKeyID: "accessKey",
SecretAccessKey: "secretKey",
CanExpire: true,
Expires: tt.ExpireTime,
}, nil
}), func(options *CredentialsCacheOptions) {
options.ExpiryWindow = tt.ExpiryWindow
options.ExpiryWindowJitterFrac = tt.JitterFrac
})

credentials, err := p.Retrieve(context.Background())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
tt.Validate(t, credentials.Expires)
})
}
}

func TestCredentialsCache_Error(t *testing.T) {
p := NewCredentialsCache(CredentialsProviderFunc(func(ctx context.Context) (Credentials, error) {
return Credentials{}, fmt.Errorf("failed")
}))

creds, err := p.Retrieve(context.Background())
if err == nil {
t.Fatalf("expect error, not none")
Expand All @@ -156,16 +231,14 @@ func TestCredentialsCache_Race(t *testing.T) {
SecretAccessKey: "secret",
}
var called bool
p := &CredentialsCache{
Provider: CredentialsProviderFunc(func(ctx context.Context) (Credentials, error) {
time.Sleep(time.Duration(rand.Intn(10)) * time.Millisecond)
if called {
t.Fatalf("expect provider.Retrieve only called once")
}
called = true
return expect, nil
}),
}
p := NewCredentialsCache(CredentialsProviderFunc(func(ctx context.Context) (Credentials, error) {
time.Sleep(time.Duration(rand.Intn(10)) * time.Millisecond)
if called {
t.Fatalf("expect provider.Retrieve only called once")
}
called = true
return expect, nil
}))

var wg sync.WaitGroup
wg.Add(100)
Expand Down Expand Up @@ -206,9 +279,7 @@ func TestCredentialsCache_RetrieveConcurrent(t *testing.T) {
stub := &stubConcurrentProvider{
done: make(chan struct{}),
}
provider := CredentialsCache{
Provider: stub,
}
provider := NewCredentialsCache(stub)

var wg sync.WaitGroup
wg.Add(2)
Expand Down
Loading

0 comments on commit 734d12f

Please sign in to comment.