Skip to content

Commit

Permalink
Adds Policy and Duration parameters to `stscreds.WebIdentityRoleO…
Browse files Browse the repository at this point in the history
…ptions` (#1670)

Adds the Policy and Duration parameters from sts.AssumeRoleWithWebIdentityInput to stscreds.WebIdentityRoleOptions.

Closes #1662
  • Loading branch information
gdavison authored Apr 25, 2022
1 parent da333d3 commit b0a3c24
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 25 deletions.
8 changes: 8 additions & 0 deletions .changelog/1be705cb9be94061bd94a4dbded36da5.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "1be705cb-9be9-4061-bd94-a4dbded36da5",
"type": "feature",
"description": "Adds Duration and Policy options that can be used when creating stscreds.WebIdentityRoleProvider credentials provider.",
"modules": [
"credentials"
]
}
9 changes: 7 additions & 2 deletions credentials/stscreds/assume_role_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,13 @@ type AssumeRoleAPIClient interface {
AssumeRole(ctx context.Context, params *sts.AssumeRoleInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleOutput, error)
}

// DefaultDuration is the default amount of time in minutes that the credentials
// will be valid for.
// DefaultDuration is the default amount of time in minutes that the
// credentials will be valid for. This value is only used by AssumeRoleProvider
// for specifying the default expiry duration of an assume role.
//
// Other providers such as WebIdentityRoleProvider do not use this value, and
// instead rely on STS API's default parameter handing to assign a default
// value.
var DefaultDuration = time.Duration(15) * time.Minute

// AssumeRoleProvider retrieves temporary credentials from the STS service, and
Expand Down
27 changes: 25 additions & 2 deletions credentials/stscreds/web_identity_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"io/ioutil"
"strconv"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/retry"
Expand Down Expand Up @@ -45,6 +46,19 @@ type WebIdentityRoleOptions struct {
// Session name, if you wish to uniquely identify this session.
RoleSessionName string

// Expiry duration of the STS credentials. STS will assign a default expiry
// duration if this value is unset. This is different from the Duration
// option of AssumeRoleProvider, which automatically assigns 15 minutes if
// Duration is unset.
//
// See the STS AssumeRoleWithWebIdentity API reference guide for more
// information on defaults.
// https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
Duration time.Duration

// An IAM policy in JSON format that you want to use as an inline session policy.
Policy *string

// The Amazon Resource Names (ARNs) of the IAM managed policies that you
// want to use as managed session policies. The policies must exist in the
// same account as the role.
Expand Down Expand Up @@ -100,12 +114,21 @@ func (p *WebIdentityRoleProvider) Retrieve(ctx context.Context) (aws.Credentials
// uses unix time in nanoseconds to uniquely identify sessions.
sessionName = strconv.FormatInt(sdk.NowTime().UnixNano(), 10)
}
resp, err := p.options.Client.AssumeRoleWithWebIdentity(ctx, &sts.AssumeRoleWithWebIdentityInput{
input := &sts.AssumeRoleWithWebIdentityInput{
PolicyArns: p.options.PolicyARNs,
RoleArn: &p.options.RoleARN,
RoleSessionName: &sessionName,
WebIdentityToken: aws.String(string(b)),
}, func(options *sts.Options) {
}
if p.options.Duration != 0 {
// If set use the value, otherwise STS will assign a default expiration duration.
input.DurationSeconds = aws.Int32(int32(p.options.Duration / time.Second))
}
if p.options.Policy != nil {
input.Policy = p.options.Policy
}

resp, err := p.options.Client.AssumeRoleWithWebIdentity(ctx, input, func(options *sts.Options) {
options.Retryer = retry.AddWithErrorCodes(options.Retryer, invalidIdentityTokenExceptionCode)
})
if err != nil {
Expand Down
101 changes: 80 additions & 21 deletions credentials/stscreds/web_identity_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,32 +31,79 @@ func (m mockErrorCode) Error() string {
}

func TestWebIdentityProviderRetrieve(t *testing.T) {
defer func() func() {
o := sdk.NowTime
sdk.NowTime = func() time.Time {
return time.Time{}
}
return func() {
sdk.NowTime = o
}
}()()
restorTime := sdk.TestingUseReferenceTime(time.Time{})
defer restorTime()

cases := map[string]struct {
mockClient mockAssumeRoleWithWebIdentity
roleARN string
tokenFilepath string
sessionName string
expectedError error
options func(*stscreds.WebIdentityRoleOptions)
expectedCredValue aws.Credentials
}{
"session name case": {
"success": {
roleARN: "arn01234567890123456789",
tokenFilepath: "testdata/token.jwt",
sessionName: "foo",
mockClient: func(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) {
options: func(o *stscreds.WebIdentityRoleOptions) {
o.RoleSessionName = "foo"
},
mockClient: func(
ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options),
) (
*sts.AssumeRoleWithWebIdentityOutput, error,
) {
if e, a := "foo", *params.RoleSessionName; e != a {
return nil, fmt.Errorf("expected %v, but received %v", e, a)
}
if params.DurationSeconds != nil {
return nil, fmt.Errorf("expect no duration seconds, got %v",
*params.DurationSeconds)
}
if params.Policy != nil {
return nil, fmt.Errorf("expect no policy, got %v",
*params.Policy)
}
return &sts.AssumeRoleWithWebIdentityOutput{
Credentials: &types.Credentials{
Expiration: aws.Time(sdk.NowTime()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
SessionToken: aws.String("session-token"),
},
}, nil
},
expectedCredValue: aws.Credentials{
AccessKeyID: "access-key-id",
SecretAccessKey: "secret-access-key",
SessionToken: "session-token",
Source: stscreds.WebIdentityProviderName,
CanExpire: true,
Expires: sdk.NowTime(),
},
},
"success with duration and policy": {
roleARN: "arn01234567890123456789",
tokenFilepath: "testdata/token.jwt",
options: func(o *stscreds.WebIdentityRoleOptions) {
o.Duration = 42 * time.Second
o.Policy = aws.String("super secret policy")
o.RoleSessionName = "foo"
},
mockClient: func(
ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options),
) (
*sts.AssumeRoleWithWebIdentityOutput, error,
) {
if e, a := "foo", *params.RoleSessionName; e != a {
return nil, fmt.Errorf("expected %v, but received %v", e, a)
}
if e, a := int32(42), aws.ToInt32(params.DurationSeconds); e != a {
return nil, fmt.Errorf("expect %v duration seconds, got %v", e, a)
}
if e, a := "super secret policy", aws.ToString(params.Policy); e != a {
return nil, fmt.Errorf("expect %v policy, got %v", e, a)
}
return &sts.AssumeRoleWithWebIdentityOutput{
Credentials: &types.Credentials{
Expiration: aws.Time(sdk.NowTime()),
Expand All @@ -78,8 +125,14 @@ func TestWebIdentityProviderRetrieve(t *testing.T) {
"configures token retry": {
roleARN: "arn01234567890123456789",
tokenFilepath: "testdata/token.jwt",
sessionName: "foo",
mockClient: func(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) {
options: func(o *stscreds.WebIdentityRoleOptions) {
o.RoleSessionName = "foo"
},
mockClient: func(
ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options),
) (
*sts.AssumeRoleWithWebIdentityOutput, error,
) {
o := sts.Options{}
for _, fn := range optFns {
fn(&o)
Expand Down Expand Up @@ -112,13 +165,19 @@ func TestWebIdentityProviderRetrieve(t *testing.T) {

for name, c := range cases {
t.Run(name, func(t *testing.T) {
p := stscreds.NewWebIdentityRoleProvider(c.mockClient, c.roleARN, stscreds.IdentityTokenFile(c.tokenFilepath),
func(o *stscreds.WebIdentityRoleOptions) {
o.RoleSessionName = c.sessionName
})
var optFns []func(*stscreds.WebIdentityRoleOptions)
if c.options != nil {
optFns = append(optFns, c.options)
}
p := stscreds.NewWebIdentityRoleProvider(
c.mockClient,
c.roleARN,
stscreds.IdentityTokenFile(c.tokenFilepath),
optFns...,
)
credValue, err := p.Retrieve(context.Background())
if e, a := c.expectedError, err; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}

if e, a := c.expectedCredValue, credValue; !reflect.DeepEqual(e, a) {
Expand Down

0 comments on commit b0a3c24

Please sign in to comment.