Skip to content

Commit

Permalink
Merge pull request #35926 from hashicorp/f-token_bucket_rate_limiter_…
Browse files Browse the repository at this point in the history
…capacity

Add provider `token_bucket_rate_limiter_capacity` parameter
  • Loading branch information
ewbankkit authored Feb 21, 2024
2 parents 789a197 + d5dc0b8 commit 6a9a929
Show file tree
Hide file tree
Showing 46 changed files with 593 additions and 567 deletions.
3 changes: 3 additions & 0 deletions .changelog/35926.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:enhancement
provider: Add `token_bucket_rate_limiter_capacity` parameter
```
5 changes: 3 additions & 2 deletions internal/acctest/vcr.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ func vcrProviderConfigureContextFunc(provider *schema.Provider, configureContext
} else {
meta = new(conns.AWSClient)
}
meta.SetHTTPClient(httpClient)
meta.SetHTTPClient(ctx, httpClient)
provider.SetMeta(meta)

if v, ds := configureContextFunc(ctx, d); ds.HasError() {
Expand Down Expand Up @@ -391,14 +391,15 @@ func closeVCRRecorder(t *testing.T) {
panic(p)
}

ctx := context.TODO() // nosemgrep:ci.semgrep.migrate.context-todo
testName := t.Name()
providerMetas.Lock()
meta, ok := providerMetas[testName]
defer providerMetas.Unlock()

if ok {
if !t.Failed() {
if v, ok := meta.HTTPClient().Transport.(*recorder.Recorder); ok {
if v, ok := meta.HTTPClient(ctx).Transport.(*recorder.Recorder); ok {
t.Log("stopping VCR recorder")
if err := v.Stop(); err != nil {
t.Error(err)
Expand Down
52 changes: 30 additions & 22 deletions internal/conns/awsclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
aws_sdkv2 "github.com/aws/aws-sdk-go-v2/aws"
config_sdkv2 "github.com/aws/aws-sdk-go-v2/config"
s3_sdkv2 "github.com/aws/aws-sdk-go-v2/service/s3"
endpoints_sdkv1 "github.com/aws/aws-sdk-go/aws/endpoints"
session_sdkv1 "github.com/aws/aws-sdk-go/aws/session"
apigatewayv2_sdkv1 "github.com/aws/aws-sdk-go/service/apigatewayv2"
baselogging "github.com/hashicorp/aws-sdk-go-base/v2/logging"
Expand All @@ -26,18 +25,17 @@ import (
type AWSClient struct {
AccountID string
DefaultTagsConfig *tftags.DefaultConfig
DNSSuffix string
IgnoreTagsConfig *tftags.IgnoreConfig
Partition string
Region string
ReverseDNSPrefix string
ServicePackages map[string]ServicePackage
Session *session_sdkv1.Session
TerraformVersion string

awsConfig *aws_sdkv2.Config
clients map[string]any
conns map[string]any
dnsSuffix string
endpoints map[string]string // From provider configuration.
httpClient *http.Client
lock sync.Mutex
Expand All @@ -49,29 +47,29 @@ type AWSClient struct {
}

// CredentialsProvider returns the AWS SDK for Go v2 credentials provider.
func (c *AWSClient) CredentialsProvider() aws_sdkv2.CredentialsProvider {
func (c *AWSClient) CredentialsProvider(context.Context) aws_sdkv2.CredentialsProvider {
if c.awsConfig == nil {
return nil
}
return c.awsConfig.Credentials
}

func (c *AWSClient) AwsConfig() aws_sdkv2.Config { // nosemgrep:ci.aws-in-func-name
func (c *AWSClient) AwsConfig(context.Context) aws_sdkv2.Config { // nosemgrep:ci.aws-in-func-name
return c.awsConfig.Copy()
}

// PartitionHostname returns a hostname with the provider domain suffix for the partition
// e.g. PREFIX.amazonaws.com
// The prefix should not contain a trailing period.
func (c *AWSClient) PartitionHostname(prefix string) string {
return fmt.Sprintf("%s.%s", prefix, c.DNSSuffix)
func (c *AWSClient) PartitionHostname(ctx context.Context, prefix string) string {
return fmt.Sprintf("%s.%s", prefix, c.DNSSuffix(ctx))
}

// RegionalHostname returns a hostname with the provider domain suffix for the region and partition
// e.g. PREFIX.us-west-2.amazonaws.com
// The prefix should not contain a trailing period.
func (c *AWSClient) RegionalHostname(prefix string) string {
return fmt.Sprintf("%s.%s.%s", prefix, c.Region, c.DNSSuffix)
func (c *AWSClient) RegionalHostname(ctx context.Context, prefix string) string {
return fmt.Sprintf("%s.%s.%s", prefix, c.Region, c.DNSSuffix(ctx))
}

// S3ExpressClient returns an S3 API client suitable for use with S3 Express (directory buckets).
Expand All @@ -97,20 +95,20 @@ func (c *AWSClient) S3ExpressClient(ctx context.Context) *s3_sdkv2.Client {
}

// S3UsePathStyle returns the s3_force_path_style provider configuration value.
func (c *AWSClient) S3UsePathStyle() bool {
func (c *AWSClient) S3UsePathStyle(context.Context) bool {
return c.s3UsePathStyle
}

// SetHTTPClient sets the http.Client used for AWS API calls.
// To have effect it must be called before the AWS SDK v1 Session is created.
func (c *AWSClient) SetHTTPClient(httpClient *http.Client) {
func (c *AWSClient) SetHTTPClient(_ context.Context, httpClient *http.Client) {
if c.Session == nil {
c.httpClient = httpClient
}
}

// HTTPClient returns the http.Client used for AWS API calls.
func (c *AWSClient) HTTPClient() *http.Client {
func (c *AWSClient) HTTPClient(context.Context) *http.Client {
return c.httpClient
}

Expand All @@ -121,36 +119,36 @@ func (c *AWSClient) RegisterLogger(ctx context.Context) context.Context {

// APIGatewayInvokeURL returns the Amazon API Gateway (REST APIs) invoke URL for the configured AWS Region.
// See https://docs.aws.amazon.com/apigateway/latest/developerguide/how-to-call-api.html.
func (c *AWSClient) APIGatewayInvokeURL(restAPIID, stageName string) string {
return fmt.Sprintf("https://%s/%s", c.RegionalHostname(fmt.Sprintf("%s.execute-api", restAPIID)), stageName)
func (c *AWSClient) APIGatewayInvokeURL(ctx context.Context, restAPIID, stageName string) string {
return fmt.Sprintf("https://%s/%s", c.RegionalHostname(ctx, fmt.Sprintf("%s.execute-api", restAPIID)), stageName)
}

// APIGatewayV2InvokeURL returns the Amazon API Gateway v2 (WebSocket & HTTP APIs) invoke URL for the configured AWS Region.
// See https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-publish.html and
// https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-set-up-websocket-deployment.html.
func (c *AWSClient) APIGatewayV2InvokeURL(protocolType, apiID, stageName string) string {
func (c *AWSClient) APIGatewayV2InvokeURL(ctx context.Context, protocolType, apiID, stageName string) string {
if protocolType == apigatewayv2_sdkv1.ProtocolTypeWebsocket {
return fmt.Sprintf("wss://%s/%s", c.RegionalHostname(fmt.Sprintf("%s.execute-api", apiID)), stageName)
return fmt.Sprintf("wss://%s/%s", c.RegionalHostname(ctx, fmt.Sprintf("%s.execute-api", apiID)), stageName)
}

if stageName == "$default" {
return fmt.Sprintf("https://%s/", c.RegionalHostname(fmt.Sprintf("%s.execute-api", apiID)))
return fmt.Sprintf("https://%s/", c.RegionalHostname(ctx, fmt.Sprintf("%s.execute-api", apiID)))
}

return fmt.Sprintf("https://%s/%s", c.RegionalHostname(fmt.Sprintf("%s.execute-api", apiID)), stageName)
return fmt.Sprintf("https://%s/%s", c.RegionalHostname(ctx, fmt.Sprintf("%s.execute-api", apiID)), stageName)
}

// CloudFrontDistributionHostedZoneID returns the Route 53 hosted zone ID
// for Amazon CloudFront distributions in the configured AWS partition.
func (c *AWSClient) CloudFrontDistributionHostedZoneID() string {
if c.Partition == endpoints_sdkv1.AwsCnPartitionID {
func (c *AWSClient) CloudFrontDistributionHostedZoneID(context.Context) string {
if c.Partition == names.ChinaPartitionID {
return "Z3RFFRIM2A3IF5" // See https://docs.amazonaws.cn/en_us/aws/latest/userguide/route53.html
}
return "Z2FDTNDATAQYW2" // See https://docs.aws.amazon.com/Route53/latest/APIReference/API_AliasTarget.html#Route53-Type-AliasTarget-HostedZoneId
}

// DefaultKMSKeyPolicy returns the default policy for KMS keys in the configured AWS partition.
func (c *AWSClient) DefaultKMSKeyPolicy() string {
func (c *AWSClient) DefaultKMSKeyPolicy(context.Context) string {
return fmt.Sprintf(`
{
"Id": "default",
Expand All @@ -172,10 +170,20 @@ func (c *AWSClient) DefaultKMSKeyPolicy() string {

// GlobalAcceleratorHostedZoneID returns the Route 53 hosted zone ID
// for AWS Global Accelerator accelerators in the configured AWS partition.
func (c *AWSClient) GlobalAcceleratorHostedZoneID() string {
func (c *AWSClient) GlobalAcceleratorHostedZoneID(context.Context) string {
return "Z2BJ6XQ5FK7U4H" // See https://docs.aws.amazon.com/general/latest/gr/global_accelerator.html#global_accelerator_region
}

// DNSSuffix returns the domain suffix for the configured AWS partition.
func (c *AWSClient) DNSSuffix(context.Context) string {
return c.dnsSuffix
}

// ReverseDNSPrefix returns the reverse DNS prefix for the configured AWS partition.
func (c *AWSClient) ReverseDNSPrefix(ctx context.Context) string {
return names.ReverseDNS(c.DNSSuffix(ctx))
}

// apiClientConfig returns the AWS API client configuration parameters for the specified service.
func (c *AWSClient) apiClientConfig(ctx context.Context, servicePackageName string) map[string]any {
m := map[string]any{
Expand Down
15 changes: 9 additions & 6 deletions internal/conns/awsclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
package conns

import (
"context"
"testing"
)

func TestAWSClientPartitionHostname(t *testing.T) { // nosemgrep:ci.aws-in-func-name
t.Parallel()

ctx := context.TODO()
testCases := []struct {
Name string
AWSClient *AWSClient
Expand All @@ -19,15 +21,15 @@ func TestAWSClientPartitionHostname(t *testing.T) { // nosemgrep:ci.aws-in-func-
{
Name: "AWS Commercial",
AWSClient: &AWSClient{
DNSSuffix: "amazonaws.com",
dnsSuffix: "amazonaws.com",
},
Prefix: "test",
Expected: "test.amazonaws.com",
},
{
Name: "AWS China",
AWSClient: &AWSClient{
DNSSuffix: "amazonaws.com.cn",
dnsSuffix: "amazonaws.com.cn",
},
Prefix: "test",
Expected: "test.amazonaws.com.cn",
Expand All @@ -39,7 +41,7 @@ func TestAWSClientPartitionHostname(t *testing.T) { // nosemgrep:ci.aws-in-func-
t.Run(testCase.Name, func(t *testing.T) {
t.Parallel()

got := testCase.AWSClient.PartitionHostname(testCase.Prefix)
got := testCase.AWSClient.PartitionHostname(ctx, testCase.Prefix)

if got != testCase.Expected {
t.Errorf("got %s, expected %s", got, testCase.Expected)
Expand All @@ -51,6 +53,7 @@ func TestAWSClientPartitionHostname(t *testing.T) { // nosemgrep:ci.aws-in-func-
func TestAWSClientRegionalHostname(t *testing.T) { // nosemgrep:ci.aws-in-func-name
t.Parallel()

ctx := context.TODO()
testCases := []struct {
Name string
AWSClient *AWSClient
Expand All @@ -60,7 +63,7 @@ func TestAWSClientRegionalHostname(t *testing.T) { // nosemgrep:ci.aws-in-func-n
{
Name: "AWS Commercial",
AWSClient: &AWSClient{
DNSSuffix: "amazonaws.com",
dnsSuffix: "amazonaws.com",
Region: "us-west-2", //lintignore:AWSAT003
},
Prefix: "test",
Expand All @@ -69,7 +72,7 @@ func TestAWSClientRegionalHostname(t *testing.T) { // nosemgrep:ci.aws-in-func-n
{
Name: "AWS China",
AWSClient: &AWSClient{
DNSSuffix: "amazonaws.com.cn",
dnsSuffix: "amazonaws.com.cn",
Region: "cn-northwest-1", //lintignore:AWSAT003
},
Prefix: "test",
Expand All @@ -82,7 +85,7 @@ func TestAWSClientRegionalHostname(t *testing.T) { // nosemgrep:ci.aws-in-func-n
t.Run(testCase.Name, func(t *testing.T) {
t.Parallel()

got := testCase.AWSClient.RegionalHostname(testCase.Prefix)
got := testCase.AWSClient.RegionalHostname(ctx, testCase.Prefix)

if got != testCase.Expected {
t.Errorf("got %s, expected %s", got, testCase.Expected)
Expand Down
69 changes: 35 additions & 34 deletions internal/conns/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ type Config struct {
SuppressDebugLog bool
TerraformVersion string
Token string
TokenBucketRateLimiterCapacity int
UseDualStackEndpoint bool
UseFIPSEndpoint bool
}
Expand All @@ -68,35 +69,36 @@ func (c *Config) ConfigureProvider(ctx context.Context, client *AWSClient) (*AWS
ctx, logger := logging.NewTfLogger(ctx)

awsbaseConfig := awsbase.Config{
AccessKey: c.AccessKey,
AllowedAccountIds: c.AllowedAccountIds,
APNInfo: StdUserAgentProducts(c.TerraformVersion),
AssumeRoleWithWebIdentity: c.AssumeRoleWithWebIdentity,
CallerDocumentationURL: "https://registry.terraform.io/providers/hashicorp/aws",
CallerName: "Terraform AWS Provider",
EC2MetadataServiceEnableState: c.EC2MetadataServiceEnableState,
ForbiddenAccountIds: c.ForbiddenAccountIds,
IamEndpoint: c.Endpoints[names.IAM],
Insecure: c.Insecure,
HTTPClient: client.HTTPClient(),
HTTPProxy: c.HTTPProxy,
HTTPSProxy: c.HTTPSProxy,
HTTPProxyMode: awsbase.HTTPProxyModeLegacy,
Logger: logger,
MaxRetries: c.MaxRetries,
NoProxy: c.NoProxy,
Profile: c.Profile,
Region: c.Region,
RetryMode: c.RetryMode,
SecretKey: c.SecretKey,
SkipCredsValidation: c.SkipCredsValidation,
SkipRequestingAccountId: c.SkipRequestingAccountId,
SsoEndpoint: c.Endpoints[names.SSO],
StsEndpoint: c.Endpoints[names.STS],
SuppressDebugLog: c.SuppressDebugLog,
Token: c.Token,
UseDualStackEndpoint: c.UseDualStackEndpoint,
UseFIPSEndpoint: c.UseFIPSEndpoint,
AccessKey: c.AccessKey,
AllowedAccountIds: c.AllowedAccountIds,
APNInfo: StdUserAgentProducts(c.TerraformVersion),
AssumeRoleWithWebIdentity: c.AssumeRoleWithWebIdentity,
CallerDocumentationURL: "https://registry.terraform.io/providers/hashicorp/aws",
CallerName: "Terraform AWS Provider",
EC2MetadataServiceEnableState: c.EC2MetadataServiceEnableState,
ForbiddenAccountIds: c.ForbiddenAccountIds,
IamEndpoint: c.Endpoints[names.IAM],
Insecure: c.Insecure,
HTTPClient: client.HTTPClient(ctx),
HTTPProxy: c.HTTPProxy,
HTTPSProxy: c.HTTPSProxy,
HTTPProxyMode: awsbase.HTTPProxyModeLegacy,
Logger: logger,
MaxRetries: c.MaxRetries,
NoProxy: c.NoProxy,
Profile: c.Profile,
Region: c.Region,
RetryMode: c.RetryMode,
SecretKey: c.SecretKey,
SkipCredsValidation: c.SkipCredsValidation,
SkipRequestingAccountId: c.SkipRequestingAccountId,
SsoEndpoint: c.Endpoints[names.SSO],
StsEndpoint: c.Endpoints[names.STS],
SuppressDebugLog: c.SuppressDebugLog,
Token: c.Token,
TokenBucketRateLimiterCapacity: c.TokenBucketRateLimiterCapacity,
UseDualStackEndpoint: c.UseDualStackEndpoint,
UseFIPSEndpoint: c.UseFIPSEndpoint,
}

if c.AssumeRole != nil && c.AssumeRole.RoleARN != "" {
Expand Down Expand Up @@ -189,19 +191,18 @@ func (c *Config) ConfigureProvider(ctx context.Context, client *AWSClient) (*AWS
return nil, sdkdiag.AppendErrorf(diags, err.Error())
}

DNSSuffix := "amazonaws.com"
dnsSuffix := "amazonaws.com"
if p, ok := endpoints_sdkv1.PartitionForRegion(endpoints_sdkv1.DefaultPartitions(), c.Region); ok {
DNSSuffix = p.DNSSuffix()
dnsSuffix = p.DNSSuffix()
}

client.AccountID = accountID
client.DefaultTagsConfig = c.DefaultTagsConfig
client.DNSSuffix = DNSSuffix
client.dnsSuffix = dnsSuffix
client.IgnoreTagsConfig = c.IgnoreTagsConfig
client.Partition = partition
client.Region = c.Region
client.ReverseDNSPrefix = names.ReverseDNS(DNSSuffix)
client.SetHTTPClient(sess.Config.HTTPClient) // Must be called while client.Session is nil.
client.SetHTTPClient(ctx, sess.Config.HTTPClient) // Must be called while client.Session is nil.
client.Session = sess
client.TerraformVersion = c.TerraformVersion

Expand Down
2 changes: 1 addition & 1 deletion internal/conns/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ func TestProxyConfig(t *testing.T) {

meta := p.Meta().(*conns.AWSClient)

client := meta.AwsConfig().HTTPClient
client := meta.AwsConfig(ctx).HTTPClient
bClient, ok := client.(*awshttp.BuildableClient)
if !ok {
t.Fatalf("expected awshttp.BuildableClient, got %T", client)
Expand Down
4 changes: 4 additions & 0 deletions internal/provider/fwprovider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ func (p *fwprovider) Schema(ctx context.Context, req provider.SchemaRequest, res
Optional: true,
Description: "session token. A session token is only required if you are\nusing temporary security credentials.",
},
"token_bucket_rate_limiter_capacity": schema.Int64Attribute{
Optional: true,
Description: "The capacity of the AWS SDK's token bucket rate limiter.",
},
"use_dualstack_endpoint": schema.BoolAttribute{
Optional: true,
Description: "Resolve an endpoint with DualStack capability",
Expand Down
Loading

0 comments on commit 6a9a929

Please sign in to comment.