diff --git a/credentials/go.mod b/credentials/go.mod index 7f1d84595bc..6121e797ce2 100644 --- a/credentials/go.mod +++ b/credentials/go.mod @@ -5,8 +5,11 @@ go 1.15 require ( github.com/aws/aws-sdk-go-v2 v1.0.1-0.20210122214637-6cf9ad2f8e2f github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.0.0 + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.0.0 + github.com/aws/aws-sdk-go-v2/service/sso v1.0.0 github.com/aws/aws-sdk-go-v2/service/sts v1.0.0 github.com/aws/smithy-go v1.0.0 + github.com/google/go-cmp v0.5.4 ) replace ( diff --git a/credentials/go.sum b/credentials/go.sum index 8f7f803c058..fe76f6955f8 100644 --- a/credentials/go.sum +++ b/credentials/go.sum @@ -1,3 +1,5 @@ +github.com/aws/aws-sdk-go-v2/service/sso v1.0.0 h1:eNwZL0deLt9ehrTpPAO/pvztJxa4RT6+E7sbDpgMGUQ= +github.com/aws/aws-sdk-go-v2/service/sso v1.0.0/go.mod h1:qNdDupP6xoM//zL1JmPl2XGbyPL5kKrlsoYVh8XZxzQ= github.com/aws/smithy-go v1.0.0 h1:hkhcRKG9rJ4Fn+RbfXY7Tz7b3ITLDyolBnLLBhwbg/c= github.com/aws/smithy-go v1.0.0/go.mod h1:EzMw8dbp/YJL4A5/sbhGddag+NPT7q084agLbB9LgIw= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= diff --git a/credentials/ssocreds/doc.go b/credentials/ssocreds/doc.go new file mode 100644 index 00000000000..6e20416de87 --- /dev/null +++ b/credentials/ssocreds/doc.go @@ -0,0 +1,57 @@ +// Package provides a credential provider for retrieving temporary AWS credentials using an SSO access token. +// +// IMPORTANT: The provider in this package does not initiate or perform the AWS SSO login flow. The SDK provider +// expects that you have already performed the SSO login flow using AWS CLI using the "aws sso login" command, or by +// some other mechanism. The provider must find a valid non-expired access token for the AWS SSO user portal URL in +// ~/.aws/sso/cache. If a cached token is not found, it is expired, or the file is malformed an error will be returned. +// +// Loading AWS SSO credentials with the AWS shared configuration file +// +// You can use configure AWS SSO credentials from the AWS shared configuration file by +// providing the specifying the required keys in the profile: +// +// sso_account_id +// sso_region +// sso_role_name +// sso_start_url +// +// For example, the following defines a profile "devsso" and specifies the AWS SSO parameters that defines the target +// account, role, sign-on portal, and the region where the user portal is located. Note: all SSO arguments must be +// provided, or an error will be returned. +// +// [profile devsso] +// sso_start_url = https://my-sso-portal.awsapps.com/start +// sso_role_name = SSOReadOnlyRole +// sso_region = us-east-1 +// sso_account_id = 123456789012 +// +// Using the config module, you can load the AWS SDK shared configuration, and specify that this profile be used to +// retrieve credentials. For example: +// +// config, err := config.LoadDefaultConfig(context.TODO(), config.WithSharedConfigProfile("devsso")) +// if err != nil { +// return err +// } +// +// Programmatically loading AWS SSO credentials directly +// +// You can programmatically construct the AWS SSO Provider in your application, and provide the necessary information +// to load and retrieve temporary credentials using an access token from ~/.aws/sso/cache. +// +// client := sso.NewFromConfig(cfg) +// +// var provider aws.CredentialsProvider +// provider = ssocreds.New(client, "123456789012", "SSOReadOnlyRole", "us-east-1", "https://my-sso-portal.awsapps.com/start") +// +// // Wrap the provider with aws.CredentialsCache to cache the credentials until their expire time +// provider = aws.NewCredentialsCache(provider) +// +// credentials, err := provider.Retrieve(context.TODO()) +// if err != nil { +// return err +// } +// +// It is important that you wrap the Provider with aws.CredentialsCache if you are programmatically constructing the +// provider directly. This prevents your application from accessing the cached access token and requesting new +// credentials each time the credentials are used. +package ssocreds diff --git a/credentials/ssocreds/os.go b/credentials/ssocreds/os.go new file mode 100644 index 00000000000..ceca7dceecb --- /dev/null +++ b/credentials/ssocreds/os.go @@ -0,0 +1,9 @@ +// +build !windows + +package ssocreds + +import "os" + +func getHomeDirectory() string { + return os.Getenv("HOME") +} diff --git a/credentials/ssocreds/os_windows.go b/credentials/ssocreds/os_windows.go new file mode 100644 index 00000000000..eb48f61e5bc --- /dev/null +++ b/credentials/ssocreds/os_windows.go @@ -0,0 +1,7 @@ +package ssocreds + +import "os" + +func getHomeDirectory() string { + return os.Getenv("USERPROFILE") +} diff --git a/credentials/ssocreds/provider.go b/credentials/ssocreds/provider.go new file mode 100644 index 00000000000..2c672d79a77 --- /dev/null +++ b/credentials/ssocreds/provider.go @@ -0,0 +1,176 @@ +package ssocreds + +import ( + "context" + "crypto/sha1" + "encoding/hex" + "encoding/json" + "fmt" + "io/ioutil" + "path/filepath" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/internal/sdk" + "github.com/aws/aws-sdk-go-v2/service/sso" +) + +const ProviderName = "SSOProvider" + +var defaultCacheLocation = filepath.Join(getHomeDirectory(), ".aws", "sso", "cache") + +// GetRoleCredentialsAPIClient is a API client that implements the GetRoleCredentials operation. +type GetRoleCredentialsAPIClient interface { + GetRoleCredentials(ctx context.Context, params *sso.GetRoleCredentialsInput, optFns ...func(*sso.Options)) (*sso.GetRoleCredentialsOutput, error) +} + +// Options is the Provider options structure. +type Options struct { + Client GetRoleCredentialsAPIClient + + // The AWS account that is assigned to the user. + AccountID string + + // The region where the AWS Signle Sign-On (AWS SSO) user portal is hosted. + Region string + + // The role name that is assigned to the user. + RoleName string + + // The URL that points to the organization's AWS Single Sign-On (AWS SSO) user portal. + StartURL string +} + +// Provider is an AWS credential provider that retrieves temporary AWS credentials by exchanging an SSO login token. +type Provider struct { + options Options +} + +// New returns a new AWS Signle Sign-On (AWS SSO) credential proivder. +func New(client GetRoleCredentialsAPIClient, accountID, region, roleName, startURL string, optFns ...func(options *Options)) *Provider { + options := Options{ + Client: client, + AccountID: accountID, + Region: region, + RoleName: roleName, + StartURL: startURL, + } + + for _, fn := range optFns { + fn(&options) + } + + return &Provider{ + options: options, + } +} + +// Retrieve retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal +// by exchanging the accessToken present in ~/.aws/sso/cache. +func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) { + tokenFile, err := loadTokenFile(p.options.StartURL) + if err != nil { + return aws.Credentials{}, err + } + + output, err := p.options.Client.GetRoleCredentials(ctx, &sso.GetRoleCredentialsInput{ + AccessToken: &tokenFile.AccessToken, + AccountId: &p.options.AccountID, + RoleName: &p.options.RoleName, + }, p.configureClientOptions) + if err != nil { + return aws.Credentials{}, err + } + + return aws.Credentials{ + AccessKeyID: aws.ToString(output.RoleCredentials.AccessKeyId), + SecretAccessKey: aws.ToString(output.RoleCredentials.SecretAccessKey), + SessionToken: aws.ToString(output.RoleCredentials.SessionToken), + Expires: time.Unix(output.RoleCredentials.Expiration, 0).UTC(), + CanExpire: true, + Source: ProviderName, + }, nil +} + +func (p *Provider) configureClientOptions(options *sso.Options) { + options.Region = p.options.Region +} + +func getCacheFileName(url string) (string, error) { + hash := sha1.New() + _, err := hash.Write([]byte(url)) + if err != nil { + return "", err + } + return strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json", nil +} + +type rfc3339 time.Time + +func (r *rfc3339) UnmarshalJSON(bytes []byte) error { + var value string + + if err := json.Unmarshal(bytes, &value); err != nil { + return err + } + + parse, err := time.Parse(time.RFC3339, value) + if err != nil { + return err + } + + *r = rfc3339(parse) + + return nil +} + +type token struct { + AccessToken string `json:"accessToken"` + ExpiresAt rfc3339 `json:"expiresAt"` + Region string `json:"region,omitempty"` + StartURL string `json:"startUrl,omitempty"` +} + +func (t token) Expired() bool { + return sdk.NowTime().Round(0).After(time.Time(t.ExpiresAt)) +} + +// InvalidTokenError is the error type that is returned if aloaded token +type InvalidTokenError struct { + Err error +} + +func (i *InvalidTokenError) Unwrap() error { + return i.Err +} + +func (i *InvalidTokenError) Error() string { + return "the SSO session associated with this profile has expired or is otherwise invalid. To refresh this SSO session run aws sso login with the corresponding profile." +} + +func loadTokenFile(startURL string) (t token, err error) { + key, err := getCacheFileName(startURL) + if err != nil { + return token{}, &InvalidTokenError{Err: err} + } + + fileBytes, err := ioutil.ReadFile(filepath.Join(defaultCacheLocation, key)) + if err != nil { + return token{}, &InvalidTokenError{Err: err} + } + + if err := json.Unmarshal(fileBytes, &t); err != nil { + return token{}, &InvalidTokenError{Err: err} + } + + if len(t.AccessToken) == 0 { + return token{}, &InvalidTokenError{} + } + + if t.Expired() { + return token{}, &InvalidTokenError{Err: fmt.Errorf("access token is expired")} + } + + return t, nil +} diff --git a/credentials/ssocreds/provider_test.go b/credentials/ssocreds/provider_test.go new file mode 100644 index 00000000000..aa0962e5e87 --- /dev/null +++ b/credentials/ssocreds/provider_test.go @@ -0,0 +1,178 @@ +package ssocreds + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/internal/sdk" + "github.com/aws/aws-sdk-go-v2/service/sso" + "github.com/aws/aws-sdk-go-v2/service/sso/types" + "github.com/google/go-cmp/cmp" +) + +type mockClient struct { + t *testing.T + + Output *sso.GetRoleCredentialsOutput + Err error + + ExpectedAccountID string + ExpectedAccessToken string + ExpectedRoleName string + ExpectedClientRegion string + + Response func(mockClient) (*sso.GetRoleCredentialsOutput, error) +} + +func (m mockClient) GetRoleCredentials(ctx context.Context, params *sso.GetRoleCredentialsInput, optFns ...func(options *sso.Options)) (out *sso.GetRoleCredentialsOutput, err error) { + m.t.Helper() + + if len(m.ExpectedAccountID) > 0 { + if diff := cmp.Diff(m.ExpectedAccountID, aws.ToString(params.AccountId)); len(diff) > 0 { + m.t.Error(diff) + } + } + + if len(m.ExpectedAccessToken) > 0 { + if diff := cmp.Diff(m.ExpectedAccessToken, aws.ToString(params.AccessToken)); len(diff) > 0 { + m.t.Error(diff) + } + } + + if len(m.ExpectedRoleName) > 0 { + if diff := cmp.Diff(m.ExpectedRoleName, aws.ToString(params.RoleName)); len(diff) > 0 { + m.t.Error(diff) + } + } + + o := sso.Options{ + Region: "client-region", + } + for _, fn := range optFns { + fn(&o) + } + + if len(m.ExpectedClientRegion) > 0 { + if diff := cmp.Diff(m.ExpectedClientRegion, o.Region); len(diff) > 0 { + m.t.Error(diff) + } + } + + if m.Response == nil { + return out, err + } + return m.Response(m) +} + +func swapCacheLocation(dir string) func() { + original := defaultCacheLocation + defaultCacheLocation = dir + return func() { + defaultCacheLocation = original + } +} + +func swapNowTime(referenceTime time.Time) func() { + original := sdk.NowTime + sdk.NowTime = func() time.Time { + return referenceTime + } + return func() { + sdk.NowTime = original + } +} + +func TestProvider(t *testing.T) { + restoreCache := swapCacheLocation("testdata") + defer restoreCache() + + restoreTime := swapNowTime(time.Date(2021, 01, 19, 19, 50, 0, 0, time.UTC)) + defer restoreTime() + + cases := map[string]struct { + Client mockClient + AccountID string + Region string + RoleName string + StartURL string + Options []func(*Options) + + ExpectedErr bool + ExpectedCredentials aws.Credentials + }{ + "missing required parameter values": { + StartURL: "https://invalid-required", + ExpectedErr: true, + }, + "valid required parameter values": { + Client: mockClient{ + ExpectedAccountID: "012345678901", + ExpectedRoleName: "TestRole", + ExpectedClientRegion: "us-west-2", + ExpectedAccessToken: "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl", + Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) { + return &sso.GetRoleCredentialsOutput{ + RoleCredentials: &types.RoleCredentials{ + AccessKeyId: aws.String("AccessKey"), + SecretAccessKey: aws.String("SecretKey"), + SessionToken: aws.String("SessionToken"), + Expiration: time.Date(2021, 01, 20, 00, 00, 0, 0, time.UTC).Unix(), + }, + }, nil + }, + }, + AccountID: "012345678901", + Region: "us-west-2", + RoleName: "TestRole", + StartURL: "https://valid-required-only", + ExpectedCredentials: aws.Credentials{ + AccessKeyID: "AccessKey", + SecretAccessKey: "SecretKey", + SessionToken: "SessionToken", + CanExpire: true, + Expires: time.Date(2021, 01, 20, 00, 00, 0, 0, time.UTC), + Source: ProviderName, + }, + }, + "expired access token": { + StartURL: "https://expired", + ExpectedErr: true, + }, + "api error": { + Client: mockClient{ + ExpectedAccountID: "012345678901", + ExpectedRoleName: "TestRole", + ExpectedClientRegion: "us-west-2", + ExpectedAccessToken: "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl", + Response: func(mock mockClient) (*sso.GetRoleCredentialsOutput, error) { + return nil, fmt.Errorf("api error") + }, + }, + AccountID: "012345678901", + Region: "us-west-2", + RoleName: "TestRole", + StartURL: "https://valid-required-only", + ExpectedErr: true, + }, + } + + for name, tt := range cases { + t.Run(name, func(t *testing.T) { + tt.Client.t = t + + provider := New(tt.Client, tt.AccountID, tt.Region, tt.RoleName, tt.StartURL, tt.Options...) + + credentials, err := provider.Retrieve(context.Background()) + if (err != nil) != tt.ExpectedErr { + t.Errorf("expect error: %v", tt.ExpectedErr) + } + + if diff := cmp.Diff(tt.ExpectedCredentials, credentials); len(diff) > 0 { + t.Errorf(diff) + } + }) + } +} diff --git a/credentials/ssocreds/testdata/00126f0eb29dc1310529dcc8fc178693e1493135.json b/credentials/ssocreds/testdata/00126f0eb29dc1310529dcc8fc178693e1493135.json new file mode 100644 index 00000000000..42bf135ff78 --- /dev/null +++ b/credentials/ssocreds/testdata/00126f0eb29dc1310529dcc8fc178693e1493135.json @@ -0,0 +1,4 @@ +{ + "accessToken": "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl", + "expiresAt": "2021-01-19T23:00:00Z" +} diff --git a/credentials/ssocreds/testdata/b5f90cb535abf87a12eb4c57db2b1e837e229ea0.json b/credentials/ssocreds/testdata/b5f90cb535abf87a12eb4c57db2b1e837e229ea0.json new file mode 100644 index 00000000000..7d5bd2d53be --- /dev/null +++ b/credentials/ssocreds/testdata/b5f90cb535abf87a12eb4c57db2b1e837e229ea0.json @@ -0,0 +1,4 @@ +{ + "accessToken": "dGhpcyBpcyBub3QgYSByZWFsIHZhbHVl", + "expiresAt": "2021-01-19T18:00:00Z" +} diff --git a/credentials/ssocreds/testdata/f7f7ff326478d8c33d47eeb3408cf1c783cb611e.json b/credentials/ssocreds/testdata/f7f7ff326478d8c33d47eeb3408cf1c783cb611e.json new file mode 100644 index 00000000000..a103267e449 --- /dev/null +++ b/credentials/ssocreds/testdata/f7f7ff326478d8c33d47eeb3408cf1c783cb611e.json @@ -0,0 +1,6 @@ +{ + "accessToken": "", + "expiresAt": "", + "region": "", + "startUrl": "" +}