Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

config: refactor LoadDefaultConfig to take in context and concrete options #951

Merged
merged 4 commits into from
Dec 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func main() {
// Using the SDK's default configuration, loading additional config
// and credentials values from the environment variables, shared
// credentials, and shared configuration files
cfg, err := config.LoadDefaultConfig(config.WithRegion("us-west-2"))
cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion("us-west-2"))
if err != nil {
log.Fatalf("unable to load SDK config, %v", err)
}
Expand Down
7 changes: 4 additions & 3 deletions aws/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ import (
// The following example demonstrates using the AnonymousCredentials to prevent
// SDK's external config loading attempt to resolve credentials.
//
// cfg, err := config.LoadDefaultConfig(
// config.WithCredentialsProvider(aws.AnonymousCredentials{}))
// cfg, err := config.LoadDefaultConfig(context.TODO(),
// config.WithCredentialsProvider(aws.AnonymousCredentials{}),
// )
// if err != nil {
// log.Fatalf("failed to load config, %v", err)
// }
Expand All @@ -42,7 +43,7 @@ import (
//
// This can also be configured for specific operations calls too.
//
// cfg, err := config.LoadDefaultConfig()
// cfg, err := config.LoadDefaultConfig(context.TODO())
// if err != nil {
// log.Fatalf("failed to load config, %v", err)
// }
Expand Down
39 changes: 23 additions & 16 deletions config/codegen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,32 @@ import (
)

const (
sharedConfigType = "&SharedConfig{}"
envConfigType = "&EnvConfig{}"
sharedConfigType = "&SharedConfig{}"
envConfigType = "&EnvConfig{}"
awsConfigType = "&awsConfig{}"
ec2IMDSRegionType = "&UseEC2IMDSRegion{}"
loadOptionsType = "&LoadOptions{}"
)

var implAsserts = map[string][]string{
"SharedConfigProfileProvider": {envConfigType, `WithSharedConfigProfile("")`},
"SharedConfigFilesProvider": {envConfigType, `WithSharedConfigFiles(nil)`},
"CustomCABundleProvider": {envConfigType, `WithCustomCABundle(nil)`},
"RegionProvider": {envConfigType, sharedConfigType, `WithRegion("")`, `WithEC2IMDSRegion{}`},
"CredentialsProviderProvider": {`WithCredentialsProvider(nil)`},
"DefaultRegionProvider": {`WithDefaultRegion("")`},
"EC2RoleCredentialOptionsProvider": {`WithEC2RoleCredentialOptions(nil)`},
"EndpointCredentialOptionsProvider": {`WithEndpointCredentialOptions(nil)`},
"EndpointResolverProvider": {`WithEndpointResolver(nil)`},
"APIOptionsProvider": {`WithAPIOptions(nil)`},
"HTTPClientProvider": {`WithHTTPClient(nil)`},
"AssumeRoleCredentialOptionsProvider": {`WithAssumeRoleCredentialOptions(nil)`},
"WebIdentityRoleCredentialOptionsProvider": {`WithWebIdentityRoleCredentialOptions(nil)`},
"RetryProvider": {`WithRetryer(nil)`},
"sharedConfigProfileProvider": {envConfigType, loadOptionsType},
"sharedConfigFilesProvider": {envConfigType, loadOptionsType},
"customCABundleProvider": {envConfigType, loadOptionsType},
"regionProvider": {envConfigType, sharedConfigType, loadOptionsType, ec2IMDSRegionType},
"credentialsProviderProvider": {loadOptionsType},
"defaultRegionProvider": {loadOptionsType},
"ec2RoleCredentialOptionsProvider": {loadOptionsType},
"endpointCredentialOptionsProvider": {loadOptionsType},
"assumeRoleCredentialOptionsProvider": {loadOptionsType},
"webIdentityRoleCredentialOptionsProvider": {loadOptionsType},
"httpClientProvider": {loadOptionsType},
"apiOptionsProvider": {loadOptionsType},
"retryProvider": {loadOptionsType},
"endpointResolverProvider": {loadOptionsType},
"loggerProvider": {loadOptionsType},
"clientLogModeProvider": {loadOptionsType},
"logConfigurationWarningsProvider": {loadOptionsType},
"ec2IMDSRegionProvider": {loadOptionsType},
}

var tplProviderTests = template.Must(template.New("tplProviderTests").Funcs(map[string]interface{}{
Expand Down
37 changes: 24 additions & 13 deletions config/config.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package config

import (
"context"
"github.com/aws/aws-sdk-go-v2/aws"
)

Expand Down Expand Up @@ -43,7 +44,7 @@ var defaultAWSConfigResolvers = []awsConfigResolver{

// Sets the region the API Clients should use for making requests to.
resolveRegion,
// TODO: Add back EC2 Region Resolver Support
resolveEC2IMDSRegion,
resolveDefaultRegion,

// Sets the additional set of middleware stack mutators that will custom
Expand Down Expand Up @@ -72,15 +73,15 @@ type Config interface{}
//
// The loader should return an error if it fails to load the external configuration
// or the configuration data is malformed, or required components missing.
type loader func(configs) (Config, error)
type loader func(context.Context, configs) (Config, error)

// An awsConfigResolver will extract configuration data from the configs slice
// using the provider interfaces to extract specific functionality. The extracted
// configuration values will be written to the AWS Config value.
//
// The resolver should return an error if it it fails to extract the data, the
// data is malformed, or incomplete.
type awsConfigResolver func(cfg *aws.Config, configs configs) error
type awsConfigResolver func(ctx context.Context, cfg *aws.Config, configs configs) error

// configs is a slice of Config values. These values will be used by the
// AWSConfigResolvers to extract external configuration values to populate the
Expand All @@ -99,9 +100,9 @@ type configs []Config
//
// If a loader returns an error this method will stop iterating and return
// that error.
func (cs configs) AppendFromLoaders(loaders []loader) (configs, error) {
func (cs configs) AppendFromLoaders(ctx context.Context, loaders []loader) (configs, error) {
for _, fn := range loaders {
cfg, err := fn(cs)
cfg, err := fn(ctx, cs)
if err != nil {
return nil, err
}
Expand All @@ -118,11 +119,11 @@ func (cs configs) AppendFromLoaders(loaders []loader) (configs, error) {
//
// If an resolver returns an error this method will return that error, and stop
// iterating over the resolvers.
func (cs configs) ResolveAWSConfig(resolvers []awsConfigResolver) (aws.Config, error) {
func (cs configs) ResolveAWSConfig(ctx context.Context, resolvers []awsConfigResolver) (aws.Config, error) {
var cfg aws.Config

for _, fn := range resolvers {
if err := fn(&cfg, cs); err != nil {
if err := fn(ctx, &cfg, cs); err != nil {
// TODO provide better error?
return aws.Config{}, err
}
Expand Down Expand Up @@ -155,7 +156,7 @@ func (cs configs) ResolveConfig(f func(configs []interface{}) error) error {
// The custom configurations must satisfy the respective providers for their data
// or the custom data will be ignored by the resolvers and config loaders.
//
// cfg, err := config.LoadDefaultConfig(
// cfg, err := config.LoadDefaultConfig( context.TODO(),
// WithSharedConfigProfile("test-profile"),
// )
// if err != nil {
Expand All @@ -166,14 +167,24 @@ func (cs configs) ResolveConfig(f func(configs []interface{}) error) error {
// The default configuration sources are:
// * Environment Variables
// * Shared Configuration and Shared Credentials files.
func LoadDefaultConfig(cfgs ...Config) (aws.Config, error) {
var cfgCpy configs
cfgCpy = append(cfgCpy, cfgs...)
func LoadDefaultConfig(ctx context.Context, optFns ...func(*LoadOptions) error) (cfg aws.Config, err error) {
var options LoadOptions
for _, optFn := range optFns {
optFn(&options)
}

// assign Load Options to configs
var cfgCpy = configs{options}

cfgCpy, err := cfgCpy.AppendFromLoaders(defaultLoaders)
cfgCpy, err = cfgCpy.AppendFromLoaders(ctx, defaultLoaders)
if err != nil {
return aws.Config{}, err
}

return cfgCpy.ResolveAWSConfig(defaultAWSConfigResolvers)
cfg, err = cfgCpy.ResolveAWSConfig(ctx, defaultAWSConfigResolvers)
if err != nil {
return aws.Config{}, err
}

return cfg, nil
}
77 changes: 48 additions & 29 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,51 @@ package config

import (
"context"
"reflect"
"github.com/google/go-cmp/cmp"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
)

func TestConfigs_SharedConfigOptions(t *testing.T) {
_, err := configs{
var options LoadOptions
optFns := []func(*LoadOptions) error{
WithSharedConfigProfile("profile-name"),
WithSharedConfigFiles([]string{"creds-file"}),
}.AppendFromLoaders([]loader{
func(configs configs) (Config, error) {
}

for _, optFn := range optFns {
optFn(&options)
}

_, err := configs{options}.AppendFromLoaders(context.TODO(), []loader{
func(ctx context.Context, configs configs) (Config, error) {
var profile string
var found bool
var files []string
var err error

for _, cfg := range configs {
if p, ok := cfg.(SharedConfigProfileProvider); ok {
profile, err = p.GetSharedConfigProfile()
if err != nil {
if p, ok := cfg.(sharedConfigProfileProvider); ok {
profile, found, err = p.getSharedConfigProfile(ctx)
if err != nil || !found {
return nil, err
}
}
if p, ok := cfg.(SharedConfigFilesProvider); ok {
files, err = p.GetSharedConfigFiles()
if err != nil {
if p, ok := cfg.(sharedConfigFilesProvider); ok {
files, found, err = p.getSharedConfigFiles(ctx)
if err != nil || !found {
return nil, err
}
}

}

if e, a := "profile-name", profile; e != a {
t.Errorf("expect %v profile, got %v", e, a)
}
if e, a := []string{"creds-file"}, files; !reflect.DeepEqual(e, a) {
t.Errorf("expect %v files, got %v", e, a)
if diff := cmp.Diff([]string{"creds-file"}, files); len(diff) != 0 {
t.Errorf("expect resolved shared config match, got diff: \n %s", diff)
}

return nil, nil
Expand All @@ -52,16 +59,21 @@ func TestConfigs_SharedConfigOptions(t *testing.T) {
}

func TestConfigs_AppendFromLoaders(t *testing.T) {
expectCfg := WithRegion("mock-region")
var options LoadOptions
err := WithRegion("mock-region")(&options)
if err != nil {
t.Fatalf("expect not error, got %v", err)
}

cfgs, err := configs{}.AppendFromLoaders([]loader{
func(configs configs) (Config, error) {
if e, a := 0, len(configs); e != a {
t.Errorf("expect %v configs, got %v", e, a)
}
return expectCfg, nil
},
})
cfgs, err := configs{}.AppendFromLoaders(
context.TODO(), []loader{
func(ctx context.Context, configs configs) (Config, error) {
if e, a := 0, len(configs); e != a {
t.Errorf("expect %v configs, got %v", e, a)
}
return options, nil
},
})

if err != nil {
t.Fatalf("expect no error, got %v", err)
Expand All @@ -71,13 +83,14 @@ func TestConfigs_AppendFromLoaders(t *testing.T) {
t.Errorf("expect %v configs, got %v", e, a)
}

if e, a := expectCfg, cfgs[0]; e != a {
t.Errorf("expect %v config, got %v", e, a)
if diff := cmp.Diff(options, cfgs[0]); len(diff) != 0 {
t.Errorf("expect config match, got diff: \n %s", diff)
}
}

func TestConfigs_ResolveAWSConfig(t *testing.T) {
configSources := configs{
var options LoadOptions
optFns := []func(*LoadOptions) error{
WithRegion("mock-region"),
WithCredentialsProvider(credentials.StaticCredentialsProvider{
Value: aws.Credentials{
Expand All @@ -87,7 +100,13 @@ func TestConfigs_ResolveAWSConfig(t *testing.T) {
}),
}

cfg, err := configSources.ResolveAWSConfig([]awsConfigResolver{
for _, optFn := range optFns {
optFn(&options)
}

config := configs{options}

cfg, err := config.ResolveAWSConfig(context.TODO(), []awsConfigResolver{
resolveRegion,
resolveCredentials,
})
Expand All @@ -99,7 +118,7 @@ func TestConfigs_ResolveAWSConfig(t *testing.T) {
t.Errorf("expect %v region, got %v", e, a)
}

creds, err := cfg.Credentials.Retrieve(context.Background())
creds, err := cfg.Credentials.Retrieve(context.TODO())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
Expand All @@ -112,7 +131,7 @@ func TestConfigs_ResolveAWSConfig(t *testing.T) {
expectedSources = append(expectedSources, s)
}

if e, a := expectedSources, cfg.ConfigSources; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, got %v", e, a)
if diff := cmp.Diff(expectedSources, cfg.ConfigSources); len(diff) != 0 {
t.Errorf("expect config sources match, got diff: \n %s", diff)
}
}
2 changes: 1 addition & 1 deletion config/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@
// implement the same provider interface, priority will be handled by the order in which the sources were passed in.
//
// A number of helpers (prefixed by ``With``) are provided in this package that implement their respective provider
// interface. These helpers should be used for overriding configuration programatically at runtime.
// interface. These helpers should be used for overriding configuration programmatically at runtime.
package config
13 changes: 7 additions & 6 deletions config/doc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ import (
)

func Example() {
cfg, err := config.LoadDefaultConfig()
ctx := context.TODO()
cfg, err := config.LoadDefaultConfig(ctx)
if err != nil {
log.Fatal(err)
}

client := sts.NewFromConfig(cfg)

identity, err := client.GetCallerIdentity(context.Background(), &sts.GetCallerIdentityInput{})
identity, err := client.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
if err != nil {
log.Fatal(err)
}
Expand All @@ -27,9 +27,11 @@ func Example() {
}

func Example_custom_config() {
ctx := context.TODO()

// Config sources can be passed to LoadDefaultConfig, these sources can implement one or more
// provider interfaces. These sources take priority over the standard environment and shared configuration values.
cfg, err := config.LoadDefaultConfig(
cfg, err := config.LoadDefaultConfig(ctx,
config.WithRegion("us-west-2"),
config.WithSharedConfigProfile("customProfile"),
)
Expand All @@ -38,8 +40,7 @@ func Example_custom_config() {
}

client := sts.NewFromConfig(cfg)

identity, err := client.GetCallerIdentity(context.Background(), &sts.GetCallerIdentityInput{})
identity, err := client.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
if err != nil {
log.Fatal(err)
}
Expand Down
Loading