Skip to content

Commit

Permalink
Implement AWS SSO Credential Provider
Browse files Browse the repository at this point in the history
  • Loading branch information
skmcgrail committed Jan 25, 2021
1 parent 409c761 commit 317d8b1
Show file tree
Hide file tree
Showing 10 changed files with 446 additions and 0 deletions.
3 changes: 3 additions & 0 deletions credentials/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ go 1.15
require (
github.com/aws/aws-sdk-go-v2 v1.0.1-0.20210122214637-6cf9ad2f8e2f
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.0.0
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.0.0
github.com/aws/aws-sdk-go-v2/service/sso v1.0.0
github.com/aws/aws-sdk-go-v2/service/sts v1.0.0
github.com/aws/smithy-go v1.0.0
github.com/google/go-cmp v0.5.4
)

replace (
Expand Down
2 changes: 2 additions & 0 deletions credentials/go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/aws/aws-sdk-go-v2/service/sso v1.0.0 h1:eNwZL0deLt9ehrTpPAO/pvztJxa4RT6+E7sbDpgMGUQ=
github.com/aws/aws-sdk-go-v2/service/sso v1.0.0/go.mod h1:qNdDupP6xoM//zL1JmPl2XGbyPL5kKrlsoYVh8XZxzQ=
github.com/aws/smithy-go v1.0.0 h1:hkhcRKG9rJ4Fn+RbfXY7Tz7b3ITLDyolBnLLBhwbg/c=
github.com/aws/smithy-go v1.0.0/go.mod h1:EzMw8dbp/YJL4A5/sbhGddag+NPT7q084agLbB9LgIw=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
Expand Down
57 changes: 57 additions & 0 deletions credentials/ssocreds/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Package provides a credential provider for retrieving temporary AWS credentials using an SSO access token.
//
// IMPORTANT: The provider in this package does not initiate or perform the AWS SSO login flow. The SDK provider
// expects that you have already performed the SSO login flow using AWS CLI using the "aws sso login" command, or by
// some other mechanism. The provider must find a valid non-expired access token for the AWS SSO user portal URL in
// ~/.aws/sso/cache. If a cached token is not found, it is expired, or the file is malformed an error will be returned.
//
// Loading AWS SSO credentials with the AWS shared configuration file
//
// You can use configure AWS SSO credentials from the AWS shared configuration file by
// providing the specifying the required keys in the profile:
//
// sso_account_id
// sso_region
// sso_role_name
// sso_start_url
//
// For example, the following defines a profile "devsso" and specifies the AWS SSO parameters that defines the target
// account, role, sign-on portal, and the region where the user portal is located. Note: all SSO arguments must be
// provided, or an error will be returned.
//
// [profile devsso]
// sso_start_url = https://my-sso-portal.awsapps.com/start
// sso_role_name = SSOReadOnlyRole
// sso_region = us-east-1
// sso_account_id = 123456789012
//
// Using the config module, you can load the AWS SDK shared configuration, and specify that this profile be used to
// retrieve credentials. For example:
//
// config, err := config.LoadDefaultConfig(context.TODO(), config.WithSharedConfigProfile("devsso"))
// if err != nil {
// return err
// }
//
// Programmatically loading AWS SSO credentials directly
//
// You can programmatically construct the AWS SSO Provider in your application, and provide the necessary information
// to load and retrieve temporary credentials using an access token from ~/.aws/sso/cache.
//
// client := sso.NewFromConfig(cfg)
//
// var provider aws.CredentialsProvider
// provider = ssocreds.New(client, "123456789012", "SSOReadOnlyRole", "us-east-1", "https://my-sso-portal.awsapps.com/start")
//
// // Wrap the provider with aws.CredentialsCache to cache the credentials until their expire time
// provider = aws.NewCredentialsCache(provider)
//
// credentials, err := provider.Retrieve(context.TODO())
// if err != nil {
// return err
// }
//
// It is important that you wrap the Provider with aws.CredentialsCache if you are programmatically constructing the
// provider directly. This prevents your application from accessing the cached access token and requesting new
// credentials each time the credentials are used.
package ssocreds
9 changes: 9 additions & 0 deletions credentials/ssocreds/os.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// +build !windows

package ssocreds

import "os"

func getHomeDirectory() string {
return os.Getenv("HOME")
}
7 changes: 7 additions & 0 deletions credentials/ssocreds/os_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package ssocreds

import "os"

func getHomeDirectory() string {
return os.Getenv("USERPROFILE")
}
176 changes: 176 additions & 0 deletions credentials/ssocreds/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package ssocreds

import (
"context"
"crypto/sha1"
"encoding/hex"
"encoding/json"
"fmt"
"io/ioutil"
"path/filepath"
"strings"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/internal/sdk"
"github.com/aws/aws-sdk-go-v2/service/sso"
)

const ProviderName = "SSOProvider"

var defaultCacheLocation = filepath.Join(getHomeDirectory(), ".aws", "sso", "cache")

// GetRoleCredentialsAPIClient is a API client that implements the GetRoleCredentials operation.
type GetRoleCredentialsAPIClient interface {
GetRoleCredentials(ctx context.Context, params *sso.GetRoleCredentialsInput, optFns ...func(*sso.Options)) (*sso.GetRoleCredentialsOutput, error)
}

// Options is the Provider options structure.
type Options struct {
Client GetRoleCredentialsAPIClient

// The AWS account that is assigned to the user.
AccountID string

// The region where the AWS Signle Sign-On (AWS SSO) user portal is hosted.
Region string

// The role name that is assigned to the user.
RoleName string

// The URL that points to the organization's AWS Single Sign-On (AWS SSO) user portal.
StartURL string
}

// Provider is an AWS credential provider that retrieves temporary AWS credentials by exchanging an SSO login token.
type Provider struct {
options Options
}

// New returns a new AWS Signle Sign-On (AWS SSO) credential proivder.
func New(client GetRoleCredentialsAPIClient, accountID, region, roleName, startURL string, optFns ...func(options *Options)) *Provider {
options := Options{
Client: client,
AccountID: accountID,
Region: region,
RoleName: roleName,
StartURL: startURL,
}

for _, fn := range optFns {
fn(&options)
}

return &Provider{
options: options,
}
}

// Retrieve retrieves temporary AWS credentials from the configured Amazon Single Sign-On (AWS SSO) user portal
// by exchanging the accessToken present in ~/.aws/sso/cache.
func (p *Provider) Retrieve(ctx context.Context) (aws.Credentials, error) {
tokenFile, err := loadTokenFile(p.options.StartURL)
if err != nil {
return aws.Credentials{}, err
}

output, err := p.options.Client.GetRoleCredentials(ctx, &sso.GetRoleCredentialsInput{
AccessToken: &tokenFile.AccessToken,
AccountId: &p.options.AccountID,
RoleName: &p.options.RoleName,
}, p.configureClientOptions)
if err != nil {
return aws.Credentials{}, err
}

return aws.Credentials{
AccessKeyID: aws.ToString(output.RoleCredentials.AccessKeyId),
SecretAccessKey: aws.ToString(output.RoleCredentials.SecretAccessKey),
SessionToken: aws.ToString(output.RoleCredentials.SessionToken),
Expires: time.Unix(output.RoleCredentials.Expiration, 0).UTC(),
CanExpire: true,
Source: ProviderName,
}, nil
}

func (p *Provider) configureClientOptions(options *sso.Options) {
options.Region = p.options.Region
}

func getCacheFileName(url string) (string, error) {
hash := sha1.New()
_, err := hash.Write([]byte(url))
if err != nil {
return "", err
}
return strings.ToLower(hex.EncodeToString(hash.Sum(nil))) + ".json", nil
}

type rfc3339 time.Time

func (r *rfc3339) UnmarshalJSON(bytes []byte) error {
var value string

if err := json.Unmarshal(bytes, &value); err != nil {
return err
}

parse, err := time.Parse(time.RFC3339, value)
if err != nil {
return err
}

*r = rfc3339(parse)

return nil
}

type token struct {
AccessToken string `json:"accessToken"`
ExpiresAt rfc3339 `json:"expiresAt"`
Region string `json:"region,omitempty"`
StartURL string `json:"startUrl,omitempty"`
}

func (t token) Expired() bool {
return sdk.NowTime().Round(0).After(time.Time(t.ExpiresAt))
}

// InvalidTokenError is the error type that is returned if aloaded token
type InvalidTokenError struct {
Err error
}

func (i *InvalidTokenError) Unwrap() error {
return i.Err
}

func (i *InvalidTokenError) Error() string {
return "the SSO session associated with this profile has expired or is otherwise invalid. To refresh this SSO session run aws sso login with the corresponding profile."
}

func loadTokenFile(startURL string) (t token, err error) {
key, err := getCacheFileName(startURL)
if err != nil {
return token{}, &InvalidTokenError{Err: err}
}

fileBytes, err := ioutil.ReadFile(filepath.Join(defaultCacheLocation, key))
if err != nil {
return token{}, &InvalidTokenError{Err: err}
}

if err := json.Unmarshal(fileBytes, &t); err != nil {
return token{}, &InvalidTokenError{Err: err}
}

if len(t.AccessToken) == 0 {
return token{}, &InvalidTokenError{}
}

if t.Expired() {
return token{}, &InvalidTokenError{Err: fmt.Errorf("access token is expired")}
}

return t, nil
}
Loading

0 comments on commit 317d8b1

Please sign in to comment.