-
Notifications
You must be signed in to change notification settings - Fork 634
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement AWS SSO Credential Provider
- Loading branch information
Showing
10 changed files
with
446 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
// Package provides a credential provider for retrieving temporary AWS credentials using an SSO access token. | ||
// | ||
// IMPORTANT: The provider in this package does not initiate or perform the AWS SSO login flow. The SDK provider | ||
// expects that you have already performed the SSO login flow using AWS CLI using the "aws sso login" command, or by | ||
// some other mechanism. The provider must find a valid non-expired access token for the AWS SSO user portal URL in | ||
// ~/.aws/sso/cache. If a cached token is not found, it is expired, or the file is malformed an error will be returned. | ||
// | ||
// Loading AWS SSO credentials with the AWS shared configuration file | ||
// | ||
// You can use configure AWS SSO credentials from the AWS shared configuration file by | ||
// providing the specifying the required keys in the profile: | ||
// | ||
// sso_account_id | ||
// sso_region | ||
// sso_role_name | ||
// sso_start_url | ||
// | ||
// For example, the following defines a profile "devsso" and specifies the AWS SSO parameters that defines the target | ||
// account, role, sign-on portal, and the region where the user portal is located. Note: all SSO arguments must be | ||
// provided, or an error will be returned. | ||
// | ||
// [profile devsso] | ||
// sso_start_url = https://my-sso-portal.awsapps.com/start | ||
// sso_role_name = SSOReadOnlyRole | ||
// sso_region = us-east-1 | ||
// sso_account_id = 123456789012 | ||
// | ||
// Using the config module, you can load the AWS SDK shared configuration, and specify that this profile be used to | ||
// retrieve credentials. For example: | ||
// | ||
// config, err := config.LoadDefaultConfig(context.TODO(), config.WithSharedConfigProfile("devsso")) | ||
// if err != nil { | ||
// return err | ||
// } | ||
// | ||
// Programmatically loading AWS SSO credentials directly | ||
// | ||
// You can programmatically construct the AWS SSO Provider in your application, and provide the necessary information | ||
// to load and retrieve temporary credentials using an access token from ~/.aws/sso/cache. | ||
// | ||
// client := sso.NewFromConfig(cfg) | ||
// | ||
// var provider aws.CredentialsProvider | ||
// provider = ssocreds.New(client, "123456789012", "SSOReadOnlyRole", "us-east-1", "https://my-sso-portal.awsapps.com/start") | ||
// | ||
// // Wrap the provider with aws.CredentialsCache to cache the credentials until their expire time | ||
// provider = aws.NewCredentialsCache(provider) | ||
// | ||
// credentials, err := provider.Retrieve(context.TODO()) | ||
// if err != nil { | ||
// return err | ||
// } | ||
// | ||
// It is important that you wrap the Provider with aws.CredentialsCache if you are programmatically constructing the | ||
// provider directly. This prevents your application from accessing the cached access token and requesting new | ||
// credentials each time the credentials are used. | ||
package ssocreds |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
// +build !windows | ||
|
||
package ssocreds | ||
|
||
import "os" | ||
|
||
func getHomeDirectory() string { | ||
return os.Getenv("HOME") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
package ssocreds | ||
|
||
import "os" | ||
|
||
func getHomeDirectory() string { | ||
return os.Getenv("USERPROFILE") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
package ssocreds | ||
|
||
import ( | ||
"context" | ||
"crypto/sha1" | ||
"encoding/hex" | ||
"encoding/json" | ||
"fmt" | ||
"io/ioutil" | ||
"path/filepath" | ||
"strings" | ||
"time" | ||
|
||
"github.com/aws/aws-sdk-go-v2/aws" | ||
"github.com/aws/aws-sdk-go-v2/internal/sdk" | ||
"github.com/aws/aws-sdk-go-v2/service/sso" | ||
) | ||
|
||
const ProviderName = "SSOProvider" | ||
|
||
var defaultCacheLocation = filepath.Join(getHomeDirectory(), ".aws", "sso", "cache") | ||
|
||
// GetRoleCredentialsAPIClient is a API client that implements the GetRoleCredentials operation. | ||
type GetRoleCredentialsAPIClient interface { | ||
GetRoleCredentials(ctx context.Context, params *sso.GetRoleCredentialsInput, optFns ...func(*sso.Options)) (*sso.GetRoleCredentialsOutput, error) | ||
} | ||
|
||
// Options is the Provider options structure. | ||
type Options struct { | ||
Client GetRoleCredentialsAPIClient | ||
|
||
// The AWS account that is assigned to the user. | ||
AccountID string | ||
|
||
// The region where the AWS Signle Sign-On (AWS SSO) user portal is hosted. | ||
Region string | ||
|
||
// The role name that is assigned to the user. | ||
RoleName string | ||
|
||
// The URL that points to the organization's AWS Single Sign-On (AWS SSO) user portal. | ||
StartURL string | ||
} | ||
|
||
// Provider is an AWS credential provider that retrieves temporary AWS credentials by exchanging an SSO login token. | ||
type Provider struct { | ||
options Options | ||
} | ||
|
||
// New returns a new AWS Signle Sign-On (AWS SSO) credential proivder. | ||
func New(client GetRoleCredentialsAPIClient, accountID, region, roleName, startURL string, optFns ...func(options *Options)) *Provider { | ||
options := Options{ | ||
Client: client, | ||
AccountID: accountID, | ||
Region: region, | ||
RoleName: roleName, | ||
StartURL: startURL, | ||
} | ||
|
||
for _, fn := range optFns { | ||
fn(&options) | ||
} | ||
|
||
return &Provider{ | ||
options: options, | ||
} | ||
} | ||
|
||
// Retrieve retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal | ||
// by exchanging the accessToken present in ~/.aws/sso/cache. | ||
func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) { | ||
tokenFile, err := loadTokenFile(p.options.StartURL) | ||
if err != nil { | ||
return aws.Credentials{}, err | ||
} | ||
|
||
output, err := p.options.Client.GetRoleCredentials(ctx, &sso.GetRoleCredentialsInput{ | ||
AccessToken: &tokenFile.AccessToken, | ||
AccountId: &p.options.AccountID, | ||
RoleName: &p.options.RoleName, | ||
}, p.configureClientOptions) | ||
if err != nil { | ||
return aws.Credentials{}, err | ||
} | ||
|
||
return aws.Credentials{ | ||
AccessKeyID: aws.ToString(output.RoleCredentials.AccessKeyId), | ||
SecretAccessKey: aws.ToString(output.RoleCredentials.SecretAccessKey), | ||
SessionToken: aws.ToString(output.RoleCredentials.SessionToken), | ||
Expires: time.Unix(output.RoleCredentials.Expiration, 0).UTC(), | ||
CanExpire: true, | ||
Source: ProviderName, | ||
}, nil | ||
} | ||
|
||
func (p *Provider) configureClientOptions(options *sso.Options) { | ||
options.Region = p.options.Region | ||
} | ||
|
||
func getCacheFileName(url string) (string, error) { | ||
hash := sha1.New() | ||
_, err := hash.Write([]byte(url)) | ||
if err != nil { | ||
return "", err | ||
} | ||
return strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json", nil | ||
} | ||
|
||
type rfc3339 time.Time | ||
|
||
func (r *rfc3339) UnmarshalJSON(bytes []byte) error { | ||
var value string | ||
|
||
if err := json.Unmarshal(bytes, &value); err != nil { | ||
return err | ||
} | ||
|
||
parse, err := time.Parse(time.RFC3339, value) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
*r = rfc3339(parse) | ||
|
||
return nil | ||
} | ||
|
||
type token struct { | ||
AccessToken string `json:"accessToken"` | ||
ExpiresAt rfc3339 `json:"expiresAt"` | ||
Region string `json:"region,omitempty"` | ||
StartURL string `json:"startUrl,omitempty"` | ||
} | ||
|
||
func (t token) Expired() bool { | ||
return sdk.NowTime().Round(0).After(time.Time(t.ExpiresAt)) | ||
} | ||
|
||
// InvalidTokenError is the error type that is returned if aloaded token | ||
type InvalidTokenError struct { | ||
Err error | ||
} | ||
|
||
func (i *InvalidTokenError) Unwrap() error { | ||
return i.Err | ||
} | ||
|
||
func (i *InvalidTokenError) Error() string { | ||
return "the SSO session associated with this profile has expired or is otherwise invalid. To refresh this SSO session run aws sso login with the corresponding profile." | ||
} | ||
|
||
func loadTokenFile(startURL string) (t token, err error) { | ||
key, err := getCacheFileName(startURL) | ||
if err != nil { | ||
return token{}, &InvalidTokenError{Err: err} | ||
} | ||
|
||
fileBytes, err := ioutil.ReadFile(filepath.Join(defaultCacheLocation, key)) | ||
if err != nil { | ||
return token{}, &InvalidTokenError{Err: err} | ||
} | ||
|
||
if err := json.Unmarshal(fileBytes, &t); err != nil { | ||
return token{}, &InvalidTokenError{Err: err} | ||
} | ||
|
||
if len(t.AccessToken) == 0 { | ||
return token{}, &InvalidTokenError{} | ||
} | ||
|
||
if t.Expired() { | ||
return token{}, &InvalidTokenError{Err: fmt.Errorf("access token is expired")} | ||
} | ||
|
||
return t, nil | ||
} |
Oops, something went wrong.