diff --git a/.changelog/9ebe24c4791541e0840da49eab6f9d97.json b/.changelog/9ebe24c4791541e0840da49eab6f9d97.json new file mode 100644 index 00000000000..f733e75674d --- /dev/null +++ b/.changelog/9ebe24c4791541e0840da49eab6f9d97.json @@ -0,0 +1,11 @@ +{ + "id": "9ebe24c4-7915-41e0-840d-a49eab6f9d97", + "type": "feature", + "description": "Enable HTTP checksums in supported services by default. New config fields, RequestChecksumCalculation and ResponseChecksumValidation, allow the caller to opt-out of this new default behavior. This feature also replaces the default MD5 checksum with CRC32.", + "modules": [ + ".", + "config", + "service/internal/checksum", + "service/s3" + ] +} \ No newline at end of file diff --git a/aws/checksum.go b/aws/checksum.go new file mode 100644 index 00000000000..4152caade10 --- /dev/null +++ b/aws/checksum.go @@ -0,0 +1,33 @@ +package aws + +// RequestChecksumCalculation controls request checksum calculation workflow +type RequestChecksumCalculation int + +const ( + // RequestChecksumCalculationUnset is the unset value for RequestChecksumCalculation + RequestChecksumCalculationUnset RequestChecksumCalculation = iota + + // RequestChecksumCalculationWhenSupported indicates request checksum will be calculated + // if the operation supports input checksums + RequestChecksumCalculationWhenSupported + + // RequestChecksumCalculationWhenRequired indicates request checksum will be calculated + // if required by the operation or if user elects to set a checksum algorithm in request + RequestChecksumCalculationWhenRequired +) + +// ResponseChecksumValidation controls response checksum validation workflow +type ResponseChecksumValidation int + +const ( + // ResponseChecksumValidationUnset is the unset value for ResponseChecksumValidation + ResponseChecksumValidationUnset ResponseChecksumValidation = iota + + // ResponseChecksumValidationWhenSupported indicates response checksum will be validated + // if the operation supports output checksums + ResponseChecksumValidationWhenSupported + + // ResponseChecksumValidationWhenRequired indicates response checksum will only + // be validated if the operation requires output checksum validation + ResponseChecksumValidationWhenRequired +) diff --git a/aws/config.go b/aws/config.go index 16000d79279..a015cc5b20c 100644 --- a/aws/config.go +++ b/aws/config.go @@ -165,6 +165,33 @@ type Config struct { // Controls how a resolved AWS account ID is handled for endpoint routing. AccountIDEndpointMode AccountIDEndpointMode + + // RequestChecksumCalculation determines when request checksum calculation is performed. + // + // There are two possible values for this setting: + // + // 1. RequestChecksumCalculationWhenSupported (default): The checksum is always calculated + // if the operation supports it, regardless of whether the user sets an algorithm in the request. + // + // 2. RequestChecksumCalculationWhenRequired: The checksum is only calculated if the user + // explicitly sets a checksum algorithm in the request. + // + // This setting is sourced from the environment variable AWS_REQUEST_CHECKSUM_CALCULATION + // or the shared config profile attribute "request_checksum_calculation". + RequestChecksumCalculation RequestChecksumCalculation + + // ResponseChecksumValidation determines when response checksum validation is performed + // + // There are two possible values for this setting: + // + // 1. ResponseChecksumValidationWhenSupported (default): The checksum is always validated + // if the operation supports it, regardless of whether the user sets the validation mode to ENABLED in request. + // + // 2. ResponseChecksumValidationWhenRequired: The checksum is only validated if the user + // explicitly sets the validation mode to ENABLED in the request + // This variable is sourced from environment variable AWS_RESPONSE_CHECKSUM_VALIDATION or + // the shared config profile attribute "response_checksum_validation". + ResponseChecksumValidation ResponseChecksumValidation } // NewConfig returns a new Config pointer that can be chained with builder diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AddAwsConfigFields.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AddAwsConfigFields.java index b73fcaa2c46..dd52460d517 100644 --- a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AddAwsConfigFields.java +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AddAwsConfigFields.java @@ -84,6 +84,10 @@ public class AddAwsConfigFields implements GoIntegration { private static final String SDK_ACCOUNTID_ENDPOINT_MODE = "AccountIDEndpointMode"; + private static final String REQUEST_CHECKSUM_CALCULATION = "RequestChecksumCalculation"; + + private static final String RESPONSE_CHECKSUM_VALIDATION = "ResponseChecksumValidation"; + private static final List AWS_CONFIG_FIELDS = ListUtils.of( AwsConfigField.builder() .name(REGION_CONFIG_NAME) @@ -244,6 +248,18 @@ public class AddAwsConfigFields implements GoIntegration { .type(SdkGoTypes.Aws.AccountIDEndpointMode) .documentation("Indicates how aws account ID is applied in endpoint2.0 routing") .servicePredicate(AccountIDEndpointRouting::hasAccountIdEndpoints) + .build(), + AwsConfigField.builder() + .name(REQUEST_CHECKSUM_CALCULATION) + .type(SdkGoTypes.Aws.RequestChecksumCalculation) + .documentation("Indicates how user opt-in/out request checksum calculation") + .servicePredicate(AwsHttpChecksumGenerator::hasInputChecksumTrait) + .build(), + AwsConfigField.builder() + .name(RESPONSE_CHECKSUM_VALIDATION) + .type(SdkGoTypes.Aws.ResponseChecksumValidation) + .documentation("Indicates how user opt-in/out response checksum validation") + .servicePredicate(AwsHttpChecksumGenerator::hasOutputChecksumTrait) .build() ); diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsHttpChecksumGenerator.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsHttpChecksumGenerator.java index a23b71cdbc6..f21c39e665f 100644 --- a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsHttpChecksumGenerator.java +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/AwsHttpChecksumGenerator.java @@ -22,6 +22,7 @@ import software.amazon.smithy.go.codegen.integration.MiddlewareRegistrar; import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin; import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.knowledge.TopDownIndex; import software.amazon.smithy.model.shapes.MemberShape; import software.amazon.smithy.model.shapes.OperationShape; import software.amazon.smithy.model.shapes.ServiceShape; @@ -73,9 +74,7 @@ public byte getOrder() { @Override public void processFinalizedModel(GoSettings settings, Model model) { ServiceShape service = settings.getService(model); - for (ShapeId operationId : service.getAllOperations()) { - final OperationShape operation = model.expectShape(operationId, OperationShape.class); - + for (OperationShape operation : TopDownIndex.of(model).getContainedOperations(service)) { // Create a symbol provider because one is not available in this call. SymbolProvider symbolProvider = GoCodegenPlugin.createSymbolProvider(model, settings); @@ -128,8 +127,7 @@ public void writeAdditionalFiles( boolean supportsComputeInputChecksumsWorkflow = false; boolean supportsChecksumValidationWorkflow = false; - for (ShapeId operationID : service.getAllOperations()) { - OperationShape operation = model.expectShape(operationID, OperationShape.class); + for (OperationShape operation : TopDownIndex.of(model).getContainedOperations(service)) { if (!hasChecksumTrait(model, service, operation)) { continue; } @@ -178,11 +176,11 @@ public List getClientPlugins() { } // return true if operation shape is decorated with `httpChecksum` trait. - private boolean hasChecksumTrait(Model model, ServiceShape service, OperationShape operation) { + private static boolean hasChecksumTrait(Model model, ServiceShape service, OperationShape operation) { return operation.hasTrait(HttpChecksumTrait.class); } - private boolean hasInputChecksumTrait(Model model, ServiceShape service, OperationShape operation) { + private static boolean hasInputChecksumTrait(Model model, ServiceShape service, OperationShape operation) { if (!hasChecksumTrait(model, service, operation)) { return false; } @@ -190,7 +188,16 @@ private boolean hasInputChecksumTrait(Model model, ServiceShape service, Operati return trait.isRequestChecksumRequired() || trait.getRequestAlgorithmMember().isPresent(); } - private boolean hasOutputChecksumTrait(Model model, ServiceShape service, OperationShape operation) { + public static boolean hasInputChecksumTrait(Model model, ServiceShape service) { + for (OperationShape operation : TopDownIndex.of(model).getContainedOperations(service)) { + if (hasInputChecksumTrait(model, service, operation)) { + return true; + } + } + return false; + } + + private static boolean hasOutputChecksumTrait(Model model, ServiceShape service, OperationShape operation) { if (!hasChecksumTrait(model, service, operation)) { return false; } @@ -198,6 +205,15 @@ private boolean hasOutputChecksumTrait(Model model, ServiceShape service, Operat return trait.getRequestValidationModeMember().isPresent() && !trait.getResponseAlgorithms().isEmpty(); } + public static boolean hasOutputChecksumTrait(Model model, ServiceShape service) { + for (OperationShape operation : TopDownIndex.of(model).getContainedOperations(service)) { + if (hasOutputChecksumTrait(model, service, operation)) { + return true; + } + } + return false; + } + private boolean isS3ServiceShape(Model model, ServiceShape service) { String serviceId = service.expectTrait(ServiceTrait.class).getSdkId(); return serviceId.equalsIgnoreCase("S3"); @@ -244,6 +260,7 @@ private void writeInputMiddlewareHelper( return $T(stack, $T{ GetAlgorithm: $L, RequireChecksum: $L, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: $L, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: $L, @@ -284,6 +301,7 @@ private void writeOutputMiddlewareHelper( writer.write(""" return $T(stack, $T{ GetValidationMode: $L, + ResponseChecksumValidation: options.ResponseChecksumValidation, ValidationAlgorithms: $L, IgnoreMultipartValidation: $L, LogValidationSkipped: true, @@ -293,7 +311,6 @@ private void writeOutputMiddlewareHelper( AwsGoDependency.SERVICE_INTERNAL_CHECKSUM).build(), SymbolUtils.createValueSymbolBuilder("OutputMiddlewareOptions", AwsGoDependency.SERVICE_INTERNAL_CHECKSUM).build(), - getRequestValidationModeAccessorFuncName(operationName), convertToGoStringList(responseAlgorithms), ignoreMultipartChecksumValidationMap.getOrDefault( diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/SdkGoTypes.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/SdkGoTypes.java index c91a3f4c2aa..e0158ba62af 100644 --- a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/SdkGoTypes.java +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/SdkGoTypes.java @@ -39,6 +39,8 @@ public static final class Aws { public static final Symbol AccountIDEndpointModeRequired = AwsGoDependency.AWS_CORE.valueSymbol("AccountIDEndpointModeRequired"); public static final Symbol AccountIDEndpointModeDisabled = AwsGoDependency.AWS_CORE.valueSymbol("AccountIDEndpointModeDisabled"); + public static final Symbol RequestChecksumCalculation = AwsGoDependency.AWS_CORE.valueSymbol("RequestChecksumCalculation"); + public static final Symbol ResponseChecksumValidation = AwsGoDependency.AWS_CORE.valueSymbol("ResponseChecksumValidation"); public static final class Middleware { public static final Symbol GetRequiresLegacyEndpoints = AwsGoDependency.AWS_MIDDLEWARE.valueSymbol("GetRequiresLegacyEndpoints"); diff --git a/config/config.go b/config/config.go index d5226cb0437..0577c61869e 100644 --- a/config/config.go +++ b/config/config.go @@ -83,6 +83,12 @@ var defaultAWSConfigResolvers = []awsConfigResolver{ // Sets the AccountIDEndpointMode if present in env var or shared config profile resolveAccountIDEndpointMode, + + // Sets the RequestChecksumCalculation if present in env var or shared config profile + resolveRequestChecksumCalculation, + + // Sets the ResponseChecksumValidation if present in env var or shared config profile + resolveResponseChecksumValidation, } // A Config represents a generic configuration value or set of values. This type diff --git a/config/env_config.go b/config/env_config.go index 3a06f1412a7..a672850f53e 100644 --- a/config/env_config.go +++ b/config/env_config.go @@ -83,6 +83,9 @@ const ( awsAccountIDEnv = "AWS_ACCOUNT_ID" awsAccountIDEndpointModeEnv = "AWS_ACCOUNT_ID_ENDPOINT_MODE" + + awsRequestChecksumCalculation = "AWS_REQUEST_CHECKSUM_CALCULATION" + awsResponseChecksumValidation = "AWS_RESPONSE_CHECKSUM_VALIDATION" ) var ( @@ -296,6 +299,12 @@ type EnvConfig struct { // Indicates whether account ID will be required/ignored in endpoint2.0 routing AccountIDEndpointMode aws.AccountIDEndpointMode + + // Indicates whether request checksum should be calculated + RequestChecksumCalculation aws.RequestChecksumCalculation + + // Indicates whether response checksum should be validated + ResponseChecksumValidation aws.ResponseChecksumValidation } // loadEnvConfig reads configuration values from the OS's environment variables. @@ -400,6 +409,13 @@ func NewEnvConfig() (EnvConfig, error) { return cfg, err } + if err := setRequestChecksumCalculationFromEnvVal(&cfg.RequestChecksumCalculation, []string{awsRequestChecksumCalculation}); err != nil { + return cfg, err + } + if err := setResponseChecksumValidationFromEnvVal(&cfg.ResponseChecksumValidation, []string{awsResponseChecksumValidation}); err != nil { + return cfg, err + } + return cfg, nil } @@ -432,6 +448,14 @@ func (c EnvConfig) getAccountIDEndpointMode(context.Context) (aws.AccountIDEndpo return c.AccountIDEndpointMode, len(c.AccountIDEndpointMode) > 0, nil } +func (c EnvConfig) getRequestChecksumCalculation(context.Context) (aws.RequestChecksumCalculation, bool, error) { + return c.RequestChecksumCalculation, c.RequestChecksumCalculation > 0, nil +} + +func (c EnvConfig) getResponseChecksumValidation(context.Context) (aws.ResponseChecksumValidation, bool, error) { + return c.ResponseChecksumValidation, c.ResponseChecksumValidation > 0, nil +} + // GetRetryMaxAttempts returns the value of AWS_MAX_ATTEMPTS if was specified, // and not 0. func (c EnvConfig) GetRetryMaxAttempts(ctx context.Context) (int, bool, error) { @@ -528,6 +552,45 @@ func setAIDEndPointModeFromEnvVal(m *aws.AccountIDEndpointMode, keys []string) e return nil } +func setRequestChecksumCalculationFromEnvVal(m *aws.RequestChecksumCalculation, keys []string) error { + for _, k := range keys { + value := os.Getenv(k) + if len(value) == 0 { + continue + } + + switch strings.ToLower(value) { + case checksumWhenSupported: + *m = aws.RequestChecksumCalculationWhenSupported + case checksumWhenRequired: + *m = aws.RequestChecksumCalculationWhenRequired + default: + return fmt.Errorf("invalid value for environment variable, %s=%s, must be when_supported/when_required", k, value) + } + } + return nil +} + +func setResponseChecksumValidationFromEnvVal(m *aws.ResponseChecksumValidation, keys []string) error { + for _, k := range keys { + value := os.Getenv(k) + if len(value) == 0 { + continue + } + + switch strings.ToLower(value) { + case checksumWhenSupported: + *m = aws.ResponseChecksumValidationWhenSupported + case checksumWhenRequired: + *m = aws.ResponseChecksumValidationWhenRequired + default: + return fmt.Errorf("invalid value for environment variable, %s=%s, must be when_supported/when_required", k, value) + } + + } + return nil +} + // GetRegion returns the AWS Region if set in the environment. Returns an empty // string if not set. func (c EnvConfig) getRegion(ctx context.Context) (string, bool, error) { diff --git a/config/env_config_test.go b/config/env_config_test.go index 02c00d37aa7..870c46509bc 100644 --- a/config/env_config_test.go +++ b/config/env_config_test.go @@ -514,7 +514,6 @@ func TestNewEnvConfig(t *testing.T) { Config: EnvConfig{ AccountIDEndpointMode: aws.AccountIDEndpointModeRequired, }, - WantErr: false, }, 47: { Env: map[string]string{ @@ -523,6 +522,52 @@ func TestNewEnvConfig(t *testing.T) { Config: EnvConfig{}, WantErr: true, }, + 48: { + Env: map[string]string{ + "AWS_REQUEST_CHECKSUM_CALCULATION": "WHEN_SUPPORTED", + }, + Config: EnvConfig{ + RequestChecksumCalculation: aws.RequestChecksumCalculationWhenSupported, + }, + }, + 49: { + Env: map[string]string{ + "AWS_REQUEST_CHECKSUM_CALCULATION": "when_required", + }, + Config: EnvConfig{ + RequestChecksumCalculation: aws.RequestChecksumCalculationWhenRequired, + }, + }, + 50: { + Env: map[string]string{ + "AWS_REQUEST_CHECKSUM_CALCULATION": "blabla", + }, + Config: EnvConfig{}, + WantErr: true, + }, + 51: { + Env: map[string]string{ + "AWS_RESPONSE_CHECKSUM_VALIDATION": "WHEN_SUPPORTED", + }, + Config: EnvConfig{ + ResponseChecksumValidation: aws.ResponseChecksumValidationWhenSupported, + }, + }, + 52: { + Env: map[string]string{ + "AWS_RESPONSE_CHECKSUM_VALIDATION": "when_Required", + }, + Config: EnvConfig{ + ResponseChecksumValidation: aws.ResponseChecksumValidationWhenRequired, + }, + }, + 53: { + Env: map[string]string{ + "AWS_RESPONSE_CHECKSUM_VALIDATION": "blabla", + }, + Config: EnvConfig{}, + WantErr: true, + }, } for i, c := range cases { diff --git a/config/load_options.go b/config/load_options.go index 5f643977b00..a8465d6b849 100644 --- a/config/load_options.go +++ b/config/load_options.go @@ -216,7 +216,14 @@ type LoadOptions struct { // Whether S3 Express auth is disabled. S3DisableExpressAuth *bool + // Whether account id should be built into endpoint resolution AccountIDEndpointMode aws.AccountIDEndpointMode + + // Specify if request checksum should be calculated + RequestChecksumCalculation aws.RequestChecksumCalculation + + // Specifies if response checksum should be validated + ResponseChecksumValidation aws.ResponseChecksumValidation } func (o LoadOptions) getDefaultsMode(ctx context.Context) (aws.DefaultsMode, bool, error) { @@ -284,6 +291,14 @@ func (o LoadOptions) getAccountIDEndpointMode(ctx context.Context) (aws.AccountI return o.AccountIDEndpointMode, len(o.AccountIDEndpointMode) > 0, nil } +func (o LoadOptions) getRequestChecksumCalculation(ctx context.Context) (aws.RequestChecksumCalculation, bool, error) { + return o.RequestChecksumCalculation, o.RequestChecksumCalculation > 0, nil +} + +func (o LoadOptions) getResponseChecksumValidation(ctx context.Context) (aws.ResponseChecksumValidation, bool, error) { + return o.ResponseChecksumValidation, o.ResponseChecksumValidation > 0, nil +} + // WithRegion is a helper function to construct functional options // that sets Region on config's LoadOptions. Setting the region to // an empty string, will result in the region value being ignored. @@ -340,6 +355,26 @@ func WithAccountIDEndpointMode(m aws.AccountIDEndpointMode) LoadOptionsFunc { } } +// WithRequestChecksumCalculation is a helper function to construct functional options +// that sets RequestChecksumCalculation on config's LoadOptions +func WithRequestChecksumCalculation(c aws.RequestChecksumCalculation) LoadOptionsFunc { + return func(o *LoadOptions) error { + if c > 0 { + o.RequestChecksumCalculation = c + } + return nil + } +} + +// WithResponseChecksumValidation is a helper function to construct functional options +// that sets ResponseChecksumValidation on config's LoadOptions +func WithResponseChecksumValidation(v aws.ResponseChecksumValidation) LoadOptionsFunc { + return func(o *LoadOptions) error { + o.ResponseChecksumValidation = v + return nil + } +} + // getDefaultRegion returns DefaultRegion from config's LoadOptions func (o LoadOptions) getDefaultRegion(ctx context.Context) (string, bool, error) { if len(o.DefaultRegion) == 0 { diff --git a/config/provider.go b/config/provider.go index 043781f1f77..a8ff40d846b 100644 --- a/config/provider.go +++ b/config/provider.go @@ -242,6 +242,40 @@ func getAccountIDEndpointMode(ctx context.Context, configs configs) (value aws.A return } +// requestChecksumCalculationProvider provides access to the RequestChecksumCalculation +type requestChecksumCalculationProvider interface { + getRequestChecksumCalculation(context.Context) (aws.RequestChecksumCalculation, bool, error) +} + +func getRequestChecksumCalculation(ctx context.Context, configs configs) (value aws.RequestChecksumCalculation, found bool, err error) { + for _, cfg := range configs { + if p, ok := cfg.(requestChecksumCalculationProvider); ok { + value, found, err = p.getRequestChecksumCalculation(ctx) + if err != nil || found { + break + } + } + } + return +} + +// responseChecksumValidationProvider provides access to the ResponseChecksumValidation +type responseChecksumValidationProvider interface { + getResponseChecksumValidation(context.Context) (aws.ResponseChecksumValidation, bool, error) +} + +func getResponseChecksumValidation(ctx context.Context, configs configs) (value aws.ResponseChecksumValidation, found bool, err error) { + for _, cfg := range configs { + if p, ok := cfg.(responseChecksumValidationProvider); ok { + value, found, err = p.getResponseChecksumValidation(ctx) + if err != nil || found { + break + } + } + } + return +} + // ec2IMDSRegionProvider provides access to the ec2 imds region // configuration value type ec2IMDSRegionProvider interface { diff --git a/config/resolve.go b/config/resolve.go index 41009c7da06..a68bd0993f7 100644 --- a/config/resolve.go +++ b/config/resolve.go @@ -182,6 +182,36 @@ func resolveAccountIDEndpointMode(ctx context.Context, cfg *aws.Config, configs return nil } +// resolveRequestChecksumCalculation extracts the RequestChecksumCalculation from the configs slice's +// SharedConfig or EnvConfig +func resolveRequestChecksumCalculation(ctx context.Context, cfg *aws.Config, configs configs) error { + c, found, err := getRequestChecksumCalculation(ctx, configs) + if err != nil { + return err + } + + if !found { + c = aws.RequestChecksumCalculationWhenSupported + } + cfg.RequestChecksumCalculation = c + return nil +} + +// resolveResponseValidation extracts the ResponseChecksumValidation from the configs slice's +// SharedConfig or EnvConfig +func resolveResponseChecksumValidation(ctx context.Context, cfg *aws.Config, configs configs) error { + c, found, err := getResponseChecksumValidation(ctx, configs) + if err != nil { + return err + } + + if !found { + c = aws.ResponseChecksumValidationWhenSupported + } + cfg.ResponseChecksumValidation = c + return nil +} + // resolveDefaultRegion extracts the first instance of a default region and sets `aws.Config.Region` to the default // region if region had not been resolved from other sources. func resolveDefaultRegion(ctx context.Context, cfg *aws.Config, configs configs) error { diff --git a/config/resolve_test.go b/config/resolve_test.go index 839378f99e4..69e59dc6e02 100644 --- a/config/resolve_test.go +++ b/config/resolve_test.go @@ -269,6 +269,80 @@ func TestResolveAccountIDEndpointMode(t *testing.T) { } } +func TestResolveRequestChecksumCalculation(t *testing.T) { + cases := map[string]struct { + RequestChecksumCalculation aws.RequestChecksumCalculation + ExpectCalculation aws.RequestChecksumCalculation + }{ + "checksum calculation when required": { + RequestChecksumCalculation: aws.RequestChecksumCalculationWhenRequired, + ExpectCalculation: aws.RequestChecksumCalculationWhenRequired, + }, + "default when unset": { + ExpectCalculation: aws.RequestChecksumCalculationWhenSupported, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + var options LoadOptions + optFns := []func(*LoadOptions) error{ + WithRequestChecksumCalculation(c.RequestChecksumCalculation), + } + + for _, optFn := range optFns { + optFn(&options) + } + + configs := configs{options} + var cfg aws.Config + if err := resolveRequestChecksumCalculation(context.Background(), &cfg, configs); err != nil { + t.Fatalf("expect no error, got %v", err) + } + if e, a := c.ExpectCalculation, cfg.RequestChecksumCalculation; e != a { + t.Errorf("expect RequestChecksumCalculation to be %v, got %v", e, a) + } + }) + } +} + +func TestResolveResponseChecksumValidation(t *testing.T) { + cases := map[string]struct { + ResponseChecksumValidation aws.ResponseChecksumValidation + ExpectValidation aws.ResponseChecksumValidation + }{ + "checksum validation when required": { + ResponseChecksumValidation: aws.ResponseChecksumValidationWhenRequired, + ExpectValidation: aws.ResponseChecksumValidationWhenRequired, + }, + "default when unset": { + ExpectValidation: aws.ResponseChecksumValidationWhenSupported, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + var options LoadOptions + optFns := []func(*LoadOptions) error{ + WithResponseChecksumValidation(c.ResponseChecksumValidation), + } + + for _, optFn := range optFns { + optFn(&options) + } + + configs := configs{options} + var cfg aws.Config + if err := resolveResponseChecksumValidation(context.Background(), &cfg, configs); err != nil { + t.Fatalf("expect no error, got %v", err) + } + if e, a := c.ExpectValidation, cfg.ResponseChecksumValidation; e != a { + t.Errorf("expect ResponseChecksumValidation to be %v, got %v", e, a) + } + }) + } +} + func TestResolveCredentialsProvider(t *testing.T) { var options LoadOptions optFns := []func(options *LoadOptions) error{ diff --git a/config/shared_config.go b/config/shared_config.go index d7a2b5307ea..00b071fe6f1 100644 --- a/config/shared_config.go +++ b/config/shared_config.go @@ -118,6 +118,11 @@ const ( accountIDKey = "aws_account_id" accountIDEndpointMode = "account_id_endpoint_mode" + + requestChecksumCalculationKey = "request_checksum_calculation" + responseChecksumValidationKey = "response_checksum_validation" + checksumWhenSupported = "when_supported" + checksumWhenRequired = "when_required" ) // defaultSharedConfigProfile allows for swapping the default profile for testing @@ -346,6 +351,12 @@ type SharedConfig struct { S3DisableExpressAuth *bool AccountIDEndpointMode aws.AccountIDEndpointMode + + // RequestChecksumCalculation indicates if the request checksum should be calculated + RequestChecksumCalculation aws.RequestChecksumCalculation + + // ResponseChecksumValidation indicates if the response checksum should be validated + ResponseChecksumValidation aws.ResponseChecksumValidation } func (c SharedConfig) getDefaultsMode(ctx context.Context) (value aws.DefaultsMode, ok bool, err error) { @@ -1133,6 +1144,13 @@ func (c *SharedConfig) setFromIniSection(profile string, section ini.Section) er return fmt.Errorf("failed to load %s from shared config, %w", accountIDEndpointMode, err) } + if err := updateRequestChecksumCalculation(&c.RequestChecksumCalculation, section, requestChecksumCalculationKey); err != nil { + return fmt.Errorf("failed to load %s from shared config, %w", requestChecksumCalculationKey, err) + } + if err := updateResponseChecksumValidation(&c.ResponseChecksumValidation, section, responseChecksumValidationKey); err != nil { + return fmt.Errorf("failed to load %s from shared config, %w", responseChecksumValidationKey, err) + } + // Shared Credentials creds := aws.Credentials{ AccessKeyID: section.String(accessKeyIDKey), @@ -1207,6 +1225,42 @@ func updateAIDEndpointMode(m *aws.AccountIDEndpointMode, sec ini.Section, key st return nil } +func updateRequestChecksumCalculation(m *aws.RequestChecksumCalculation, sec ini.Section, key string) error { + if !sec.Has(key) { + return nil + } + + v := sec.String(key) + switch strings.ToLower(v) { + case checksumWhenSupported: + *m = aws.RequestChecksumCalculationWhenSupported + case checksumWhenRequired: + *m = aws.RequestChecksumCalculationWhenRequired + default: + return fmt.Errorf("invalid value for shared config profile field, %s=%s, must be when_supported/when_required", key, v) + } + + return nil +} + +func updateResponseChecksumValidation(m *aws.ResponseChecksumValidation, sec ini.Section, key string) error { + if !sec.Has(key) { + return nil + } + + v := sec.String(key) + switch strings.ToLower(v) { + case checksumWhenSupported: + *m = aws.ResponseChecksumValidationWhenSupported + case checksumWhenRequired: + *m = aws.ResponseChecksumValidationWhenRequired + default: + return fmt.Errorf("invalid value for shared config profile field, %s=%s, must be when_supported/when_required", key, v) + } + + return nil +} + func (c SharedConfig) getRequestMinCompressSizeBytes(ctx context.Context) (int64, bool, error) { if c.RequestMinCompressSizeBytes == nil { return 0, false, nil @@ -1225,6 +1279,14 @@ func (c SharedConfig) getAccountIDEndpointMode(ctx context.Context) (aws.Account return c.AccountIDEndpointMode, len(c.AccountIDEndpointMode) > 0, nil } +func (c SharedConfig) getRequestChecksumCalculation(ctx context.Context) (aws.RequestChecksumCalculation, bool, error) { + return c.RequestChecksumCalculation, c.RequestChecksumCalculation > 0, nil +} + +func (c SharedConfig) getResponseChecksumValidation(ctx context.Context) (aws.ResponseChecksumValidation, bool, error) { + return c.ResponseChecksumValidation, c.ResponseChecksumValidation > 0, nil +} + func updateDefaultsMode(mode *aws.DefaultsMode, section ini.Section, key string) error { if !section.Has(key) { return nil diff --git a/config/shared_config_test.go b/config/shared_config_test.go index ee3f0705e8b..8d71818c8d4 100644 --- a/config/shared_config_test.go +++ b/config/shared_config_test.go @@ -758,6 +758,54 @@ func TestNewSharedConfig(t *testing.T) { }, Err: fmt.Errorf("invalid value for shared config profile field, account_id_endpoint_mode=blabla, must be preferred/required/disabled"), }, + "profile with request checksum calculation when supported": { + ConfigFilenames: []string{testConfigFilename}, + Profile: "request_checksum_calculation_when_supported", + Expected: SharedConfig{ + Profile: "request_checksum_calculation_when_supported", + RequestChecksumCalculation: aws.RequestChecksumCalculationWhenSupported, + }, + }, + "profile with request checksum calculation when required": { + ConfigFilenames: []string{testConfigFilename}, + Profile: "request_checksum_calculation_when_required", + Expected: SharedConfig{ + Profile: "request_checksum_calculation_when_required", + RequestChecksumCalculation: aws.RequestChecksumCalculationWhenRequired, + }, + }, + "profile with invalid request checksum calculation": { + ConfigFilenames: []string{testConfigFilename}, + Profile: "request_checksum_calculation_error", + Expected: SharedConfig{ + Profile: "request_checksum_calculation_error", + }, + Err: fmt.Errorf("invalid value for shared config profile field, request_checksum_calculation=blabla, must be when_supported/when_required"), + }, + "profile with response checksum validation when supported": { + ConfigFilenames: []string{testConfigFilename}, + Profile: "response_checksum_validation_when_supported", + Expected: SharedConfig{ + Profile: "response_checksum_validation_when_supported", + ResponseChecksumValidation: aws.ResponseChecksumValidationWhenSupported, + }, + }, + "profile with response checksum validation when required": { + ConfigFilenames: []string{testConfigFilename}, + Profile: "response_checksum_validation_when_required", + Expected: SharedConfig{ + Profile: "response_checksum_validation_when_required", + ResponseChecksumValidation: aws.ResponseChecksumValidationWhenRequired, + }, + }, + "profile with invalid response checksum validation": { + ConfigFilenames: []string{testConfigFilename}, + Profile: "response_checksum_validation_error", + Expected: SharedConfig{ + Profile: "response_checksum_validation_error", + }, + Err: fmt.Errorf("invalid value for shared config profile field, response_checksum_validation=blabla, must be when_supported/when_required"), + }, } for name, c := range cases { diff --git a/config/testdata/shared_config b/config/testdata/shared_config index b2cbc81873b..c7159d52bff 100644 --- a/config/testdata/shared_config +++ b/config/testdata/shared_config @@ -328,3 +328,22 @@ account_id_endpoint_mode = disabled [profile account_id_endpoint_mode_error] account_id_endpoint_mode = blabla + +[profile request_checksum_calculation_when_supported] +request_checksum_calculation = when_supported + +[profile request_checksum_calculation_when_required] +request_checksum_calculation = WHEN_REQUIRED + +[profile request_checksum_calculation_error] +request_checksum_calculation = blabla + +[profile response_checksum_validation_when_supported] +response_checksum_validation = When_SUPPORTED + +[profile response_checksum_validation_when_required] +response_checksum_validation = when_required + +[profile response_checksum_validation_error] +response_checksum_validation = blabla + diff --git a/service/internal/checksum/algorithms.go b/service/internal/checksum/algorithms.go index a17041c35d0..d241bf59fbb 100644 --- a/service/internal/checksum/algorithms.go +++ b/service/internal/checksum/algorithms.go @@ -30,6 +30,9 @@ const ( // AlgorithmSHA256 represents SHA256 hash algorithm AlgorithmSHA256 Algorithm = "SHA256" + + // AlgorithmCRC64NVME represents CRC64NVME hash algorithm + AlgorithmCRC64NVME Algorithm = "CRC64NVME" ) var supportedAlgorithms = []Algorithm{ diff --git a/service/internal/checksum/middleware_add.go b/service/internal/checksum/middleware_add.go index 1b727acbe17..11243a8048f 100644 --- a/service/internal/checksum/middleware_add.go +++ b/service/internal/checksum/middleware_add.go @@ -1,6 +1,7 @@ package checksum import ( + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/smithy-go/middleware" ) @@ -14,11 +15,16 @@ type InputMiddlewareOptions struct { // and true, or false if no algorithm is specified. GetAlgorithm func(interface{}) (string, bool) - // Forces the middleware to compute the input payload's checksum. The - // request will fail if the algorithm is not specified or unable to compute - // the checksum. + // RequireChecksum indicates whether operation model forces middleware to compute the input payload's checksum. + // If RequireChecksum is set to true, checksum will be calculated and RequestChecksumCalculation will be ignored, + // otherwise RequestChecksumCalculation will be used to indicate if checksum will be calculated RequireChecksum bool + // RequestChecksumCalculation is the user config to opt-in/out request checksum calculation. If RequireChecksum is + // set to true, checksum will be calculated and this field will be ignored, otherwise + // RequestChecksumCalculation will be used to indicate if checksum will be calculated + RequestChecksumCalculation aws.RequestChecksumCalculation + // Enables support for wrapping the serialized input payload with a // content-encoding: aws-check wrapper, and including a trailer for the // algorithm's checksum value. @@ -72,7 +78,9 @@ func AddInputMiddleware(stack *middleware.Stack, options InputMiddlewareOptions) // Initial checksum configuration look up middleware err = stack.Initialize.Add(&setupInputContext{ - GetAlgorithm: options.GetAlgorithm, + GetAlgorithm: options.GetAlgorithm, + RequireChecksum: options.RequireChecksum, + RequestChecksumCalculation: options.RequestChecksumCalculation, }, middleware.Before) if err != nil { return err @@ -81,7 +89,6 @@ func AddInputMiddleware(stack *middleware.Stack, options InputMiddlewareOptions) stack.Build.Remove("ContentChecksum") inputChecksum := &computeInputPayloadChecksum{ - RequireChecksum: options.RequireChecksum, EnableTrailingChecksum: options.EnableTrailingChecksum, EnableComputePayloadHash: options.EnableComputeSHA256PayloadHash, EnableDecodedContentLengthHeader: options.EnableDecodedContentLengthHeader, @@ -94,7 +101,6 @@ func AddInputMiddleware(stack *middleware.Stack, options InputMiddlewareOptions) if options.EnableTrailingChecksum { trailerMiddleware := &addInputChecksumTrailer{ EnableTrailingChecksum: inputChecksum.EnableTrailingChecksum, - RequireChecksum: inputChecksum.RequireChecksum, EnableComputePayloadHash: inputChecksum.EnableComputePayloadHash, EnableDecodedContentLengthHeader: inputChecksum.EnableDecodedContentLengthHeader, } @@ -126,6 +132,9 @@ type OutputMiddlewareOptions struct { // mode and true, or false if no mode is specified. GetValidationMode func(interface{}) (string, bool) + // ResponseChecksumValidation is the user config to opt-in/out response checksum validation + ResponseChecksumValidation aws.ResponseChecksumValidation + // The set of checksum algorithms that should be used for response payload // checksum validation. The algorithm(s) used will be a union of the // output's returned algorithms and this set. @@ -134,7 +143,7 @@ type OutputMiddlewareOptions struct { ValidationAlgorithms []string // If set the middleware will ignore output multipart checksums. Otherwise - // an checksum format error will be returned by the middleware. + // a checksum format error will be returned by the middleware. IgnoreMultipartValidation bool // When set the middleware will log when output does not have checksum or @@ -150,7 +159,8 @@ type OutputMiddlewareOptions struct { // checksum. func AddOutputMiddleware(stack *middleware.Stack, options OutputMiddlewareOptions) error { err := stack.Initialize.Add(&setupOutputContext{ - GetValidationMode: options.GetValidationMode, + GetValidationMode: options.GetValidationMode, + ResponseChecksumValidation: options.ResponseChecksumValidation, }, middleware.Before) if err != nil { return err diff --git a/service/internal/checksum/middleware_add_test.go b/service/internal/checksum/middleware_add_test.go index da6efe94a26..7219c4a4172 100644 --- a/service/internal/checksum/middleware_add_test.go +++ b/service/internal/checksum/middleware_add_test.go @@ -87,7 +87,6 @@ func TestAddInputMiddleware(t *testing.T) { }, }, expectFinalize: &computeInputPayloadChecksum{ - RequireChecksum: true, EnableTrailingChecksum: true, }, }, @@ -167,9 +166,6 @@ func TestAddInputMiddleware(t *testing.T) { var computeInput *computeInputPayloadChecksum if c.expectFinalize != nil && ok { computeInput = finalizeMW.(*computeInputPayloadChecksum) - if e, a := c.expectFinalize.RequireChecksum, computeInput.RequireChecksum; e != a { - t.Errorf("expect %v require checksum, got %v", e, a) - } if e, a := c.expectFinalize.EnableTrailingChecksum, computeInput.EnableTrailingChecksum; e != a { t.Errorf("expect %v enable trailing checksum, got %v", e, a) } diff --git a/service/internal/checksum/middleware_compute_input_checksum.go b/service/internal/checksum/middleware_compute_input_checksum.go index 7ffca33f0ef..1d8d470c7ef 100644 --- a/service/internal/checksum/middleware_compute_input_checksum.go +++ b/service/internal/checksum/middleware_compute_input_checksum.go @@ -7,6 +7,7 @@ import ( "hash" "io" "strconv" + "strings" v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" internalcontext "github.com/aws/aws-sdk-go-v2/internal/context" @@ -16,7 +17,6 @@ import ( ) const ( - contentMD5Header = "Content-Md5" streamingUnsignedPayloadTrailerPayloadHash = "STREAMING-UNSIGNED-PAYLOAD-TRAILER" ) @@ -49,13 +49,6 @@ type computeInputPayloadChecksum struct { // the Algorithm's header is already set on the request. EnableTrailingChecksum bool - // States that a checksum is required to be included for the operation. If - // Input does not specify a checksum, fallback to built in MD5 checksum is - // used. - // - // Replaces smithy-go's ContentChecksum middleware. - RequireChecksum bool - // Enables support for computing the SHA256 checksum of input payloads // along with the algorithm specified checksum. Prevents downstream // middleware handlers (computePayloadSHA256) re-reading the payload. @@ -110,6 +103,15 @@ func (m *computeInputPayloadChecksum) HandleFinalize( ) ( out middleware.FinalizeOutput, metadata middleware.Metadata, err error, ) { + var checksum string + algorithm, ok, err := getInputAlgorithm(ctx) + if err != nil { + return out, metadata, err + } + if !ok { + return next.HandleFinalize(ctx, in) + } + req, ok := in.Request.(*smithyhttp.Request) if !ok { return out, metadata, computeInputHeaderChecksumError{ @@ -117,8 +119,6 @@ func (m *computeInputPayloadChecksum) HandleFinalize( } } - var algorithm Algorithm - var checksum string defer func() { if algorithm == "" || checksum == "" || err != nil { return @@ -130,29 +130,14 @@ func (m *computeInputPayloadChecksum) HandleFinalize( }) }() - // If no algorithm was specified, and the operation requires a checksum, - // fallback to the legacy content MD5 checksum. - algorithm, ok, err = getInputAlgorithm(ctx) - if err != nil { - return out, metadata, err - } else if !ok { - if m.RequireChecksum { - checksum, err = setMD5Checksum(ctx, req) - if err != nil { - return out, metadata, computeInputHeaderChecksumError{ - Msg: "failed to compute stream's MD5 checksum", - Err: err, - } - } - algorithm = Algorithm("MD5") + // If any checksum header is already set nothing to do. + for header, _ := range req.Header { + h := strings.ToUpper(header) + if strings.HasPrefix(h, "X-AMZ-CHECKSUM-") { + algorithm = Algorithm(strings.TrimPrefix(h, "X-AMZ-CHECKSUM-")) + checksum = req.Header.Get(header) + return next.HandleFinalize(ctx, in) } - return next.HandleFinalize(ctx, in) - } - - // If the checksum header is already set nothing to do. - checksumHeader := AlgorithmHTTPHeader(algorithm) - if checksum = req.Header.Get(checksumHeader); checksum != "" { - return next.HandleFinalize(ctx, in) } computePayloadHash := m.EnableComputePayloadHash @@ -195,9 +180,7 @@ func (m *computeInputPayloadChecksum) HandleFinalize( // Only seekable streams are supported for non-trailing checksums, because // the stream needs to be rewound before the handler can continue. if stream != nil && !req.IsStreamSeekable() { - return out, metadata, computeInputHeaderChecksumError{ - Msg: "unseekable stream is not supported without TLS and trailing checksum", - } + return next.HandleFinalize(ctx, in) } var sha256Checksum string @@ -217,6 +200,7 @@ func (m *computeInputPayloadChecksum) HandleFinalize( } } + checksumHeader := AlgorithmHTTPHeader(algorithm) req.Header.Set(checksumHeader, checksum) if computePayloadHash { @@ -248,7 +232,6 @@ func (e computeInputTrailingChecksumError) Unwrap() error { return e.Err } // - Trailing checksums are supported. type addInputChecksumTrailer struct { EnableTrailingChecksum bool - RequireChecksum bool EnableComputePayloadHash bool EnableDecodedContentLengthHeader bool } @@ -264,6 +247,16 @@ func (m *addInputChecksumTrailer) HandleFinalize( ) ( out middleware.FinalizeOutput, metadata middleware.Metadata, err error, ) { + algorithm, ok, err := getInputAlgorithm(ctx) + if err != nil { + return out, metadata, computeInputTrailingChecksumError{ + Msg: "failed to get algorithm", + Err: err, + } + } else if !ok { + return next.HandleFinalize(ctx, in) + } + if enabled, _ := middleware.GetStackValue(ctx, useTrailer{}).(bool); !enabled { return next.HandleFinalize(ctx, in) } @@ -281,26 +274,13 @@ func (m *addInputChecksumTrailer) HandleFinalize( } } - // If no algorithm was specified, there is nothing to do. - algorithm, ok, err := getInputAlgorithm(ctx) - if err != nil { - return out, metadata, computeInputTrailingChecksumError{ - Msg: "failed to get algorithm", - Err: err, - } - } else if !ok { - return out, metadata, computeInputTrailingChecksumError{ - Msg: "no algorithm specified", + // If any checksum header is already set nothing to do. + for header, _ := range req.Header { + if strings.HasPrefix(strings.ToLower(header), "x-amz-checksum-") { + return next.HandleFinalize(ctx, in) } } - // If the checksum header is already set before finalize could run, there - // is nothing to do. - checksumHeader := AlgorithmHTTPHeader(algorithm) - if req.Header.Get(checksumHeader) != "" { - return next.HandleFinalize(ctx, in) - } - stream := req.GetStream() streamLength, err := getRequestStreamLength(req) if err != nil { @@ -444,39 +424,3 @@ func getRequestStreamLength(req *smithyhttp.Request) (int64, error) { return -1, nil } - -// setMD5Checksum computes the MD5 of the request payload and sets it to the -// Content-MD5 header. Returning the MD5 base64 encoded string or error. -// -// If the MD5 is already set as the Content-MD5 header, that value will be -// returned, and nothing else will be done. -// -// If the payload is empty, no MD5 will be computed. No error will be returned. -// Empty payloads do not have an MD5 value. -// -// Replaces the smithy-go middleware for httpChecksum trait. -func setMD5Checksum(ctx context.Context, req *smithyhttp.Request) (string, error) { - if v := req.Header.Get(contentMD5Header); len(v) != 0 { - return v, nil - } - stream := req.GetStream() - if stream == nil { - return "", nil - } - - if !req.IsStreamSeekable() { - return "", fmt.Errorf( - "unseekable stream is not supported for computing md5 checksum") - } - - v, err := computeMD5Checksum(stream) - if err != nil { - return "", err - } - if err := req.RewindStream(); err != nil { - return "", fmt.Errorf("failed to rewind stream after computing MD5 checksum, %w", err) - } - // set the 'Content-MD5' header - req.Header.Set(contentMD5Header, string(v)) - return string(v), nil -} diff --git a/service/internal/checksum/middleware_compute_input_checksum_test.go b/service/internal/checksum/middleware_compute_input_checksum_test.go index c8e97de22de..42117cb0382 100644 --- a/service/internal/checksum/middleware_compute_input_checksum_test.go +++ b/service/internal/checksum/middleware_compute_input_checksum_test.go @@ -93,33 +93,96 @@ func TestComputeInputPayloadChecksum(t *testing.T) { "CRC32": "AAAAAA==", }, }, - "no algorithm": { + "http no algorithm checksum header preset": { buildInput: middleware.BuildInput{ Request: func() *smithyhttp.Request { r := smithyhttp.NewStackRequest().(*smithyhttp.Request) - r.URL, _ = url.Parse("https://example.aws") + r.URL, _ = url.Parse("http://example.aws") + r.ContentLength = 11 + r.Header.Set(AlgorithmHTTPHeader(AlgorithmCRC32), "AAAAAA==") + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc32": []string{"AAAAAA=="}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + }, + "http no algorithm set": { + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("http://example.aws") r = requestMust(r.SetStream(strings.NewReader("hello world"))) r.ContentLength = 11 return r }(), }, + expectContentLength: 11, expectHeader: http.Header{}, + expectPayload: []byte("hello world"), + }, + "https no algorithm set": { + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r = requestMust(r.SetStream(strings.NewReader("hello world"))) + r.ContentLength = 11 + return r + }(), + }, expectContentLength: 11, + expectHeader: http.Header{}, expectPayload: []byte("hello world"), }, - "nil stream no algorithm require checksum": { - optionsFn: func(o *computeInputPayloadChecksum) { - o.RequireChecksum = true + "http crc64 checksum header preset": { + initContext: func(ctx context.Context) context.Context { + return internalcontext.SetChecksumInputAlgorithm(ctx, string(AlgorithmCRC32)) }, buildInput: middleware.BuildInput{ Request: func() *smithyhttp.Request { r := smithyhttp.NewStackRequest().(*smithyhttp.Request) r.URL, _ = url.Parse("http://example.aws") + r.ContentLength = 11 + r.Header.Set(AlgorithmHTTPHeader(AlgorithmCRC64NVME), "S2Zv/ZHmbVs=") + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) return r }(), }, - expectContentLength: -1, - expectHeader: http.Header{}, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc64nvme": []string{"S2Zv/ZHmbVs="}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + expectChecksumMetadata: map[string]string{ + "CRC64NVME": "S2Zv/ZHmbVs=", + }, + }, + "https crc64 checksum header preset": { + initContext: func(ctx context.Context) context.Context { + return internalcontext.SetChecksumInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, + buildInput: middleware.BuildInput{ + Request: func() *smithyhttp.Request { + r := smithyhttp.NewStackRequest().(*smithyhttp.Request) + r.URL, _ = url.Parse("https://example.aws") + r.ContentLength = 11 + r.Header.Set(AlgorithmHTTPHeader(AlgorithmCRC64NVME), "S2Zv/ZHmbVs=") + r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + return r + }(), + }, + expectHeader: http.Header{ + "X-Amz-Checksum-Crc64nvme": []string{"S2Zv/ZHmbVs="}, + }, + expectContentLength: 11, + expectPayload: []byte("hello world"), + expectChecksumMetadata: map[string]string{ + "CRC64NVME": "S2Zv/ZHmbVs=", + }, }, }, @@ -254,94 +317,27 @@ func TestComputeInputPayloadChecksum(t *testing.T) { "CRC32": "AAAAAA==", }, }, - "http no algorithm require checksum": { - optionsFn: func(o *computeInputPayloadChecksum) { - o.RequireChecksum = true - }, - buildInput: middleware.BuildInput{ - Request: func() *smithyhttp.Request { - r := smithyhttp.NewStackRequest().(*smithyhttp.Request) - r.URL, _ = url.Parse("http://example.aws") - r.ContentLength = 11 - r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) - return r - }(), - }, - expectHeader: http.Header{ - "Content-Md5": []string{"XrY7u+Ae7tCTyyK7j1rNww=="}, - }, - expectContentLength: 11, - expectPayload: []byte("hello world"), - expectChecksumMetadata: map[string]string{ - "MD5": "XrY7u+Ae7tCTyyK7j1rNww==", - }, - }, - "http no algorithm require checksum header preset": { - optionsFn: func(o *computeInputPayloadChecksum) { - o.RequireChecksum = true - }, - buildInput: middleware.BuildInput{ - Request: func() *smithyhttp.Request { - r := smithyhttp.NewStackRequest().(*smithyhttp.Request) - r.URL, _ = url.Parse("http://example.aws") - r.ContentLength = 11 - r.Header.Set("Content-MD5", "XrY7u+Ae7tCTyyK7j1rNww==") - r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) - return r - }(), - }, - expectHeader: http.Header{ - "Content-Md5": []string{"XrY7u+Ae7tCTyyK7j1rNww=="}, - }, - expectContentLength: 11, - expectPayload: []byte("hello world"), - expectChecksumMetadata: map[string]string{ - "MD5": "XrY7u+Ae7tCTyyK7j1rNww==", - }, - }, - "https no algorithm require checksum": { - optionsFn: func(o *computeInputPayloadChecksum) { - o.RequireChecksum = true - }, - buildInput: middleware.BuildInput{ - Request: func() *smithyhttp.Request { - r := smithyhttp.NewStackRequest().(*smithyhttp.Request) - r.URL, _ = url.Parse("https://example.aws") - r.ContentLength = 11 - r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) - return r - }(), - }, - expectHeader: http.Header{ - "Content-Md5": []string{"XrY7u+Ae7tCTyyK7j1rNww=="}, - }, - expectContentLength: 11, - expectPayload: []byte("hello world"), - expectChecksumMetadata: map[string]string{ - "MD5": "XrY7u+Ae7tCTyyK7j1rNww==", - }, - }, "http seekable": { initContext: func(ctx context.Context) context.Context { - return internalcontext.SetChecksumInputAlgorithm(ctx, string(AlgorithmCRC32)) + return internalcontext.SetChecksumInputAlgorithm(ctx, string(AlgorithmCRC32C)) }, buildInput: middleware.BuildInput{ Request: func() *smithyhttp.Request { r := smithyhttp.NewStackRequest().(*smithyhttp.Request) r.URL, _ = url.Parse("http://example.aws") r.ContentLength = 11 - r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + r = requestMust(r.SetStream(bytes.NewReader([]byte("Hello world")))) return r }(), }, expectHeader: http.Header{ - "X-Amz-Checksum-Crc32": []string{"DUoRhQ=="}, + "X-Amz-Checksum-Crc32c": []string{"crUfeA=="}, }, expectContentLength: 11, - expectPayload: []byte("hello world"), - expectPayloadHash: "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9", + expectPayload: []byte("Hello world"), + expectPayloadHash: "64ec88ca00b268e5ba1a35678a1b5316d212f4f366b2477232534a8aeca37f3c", expectChecksumMetadata: map[string]string{ - "CRC32": "DUoRhQ==", + "CRC32C": "crUfeA==", }, }, "http payload hash already set": { @@ -474,7 +470,7 @@ func TestComputeInputPayloadChecksum(t *testing.T) { "build error": { "unknown algorithm": { initContext: func(ctx context.Context) context.Context { - return internalcontext.SetChecksumInputAlgorithm(ctx, string("unknown")) + return internalcontext.SetChecksumInputAlgorithm(ctx, "unknown") }, buildInput: middleware.BuildInput{ Request: func() *smithyhttp.Request { @@ -487,24 +483,9 @@ func TestComputeInputPayloadChecksum(t *testing.T) { expectErr: "failed to parse algorithm", expectBuildErr: true, }, - "no algorithm require checksum unseekable stream": { - optionsFn: func(o *computeInputPayloadChecksum) { - o.RequireChecksum = true - }, - buildInput: middleware.BuildInput{ - Request: func() *smithyhttp.Request { - r := smithyhttp.NewStackRequest().(*smithyhttp.Request) - r.URL, _ = url.Parse("http://example.aws") - r = requestMust(r.SetStream(bytes.NewBuffer([]byte("hello world")))) - return r - }(), - }, - expectErr: "unseekable stream is not supported", - expectBuildErr: true, - }, "http unseekable stream": { initContext: func(ctx context.Context) context.Context { - return internalcontext.SetChecksumInputAlgorithm(ctx, string(AlgorithmCRC32)) + return internalcontext.SetChecksumInputAlgorithm(ctx, string(AlgorithmSHA1)) }, buildInput: middleware.BuildInput{ Request: func() *smithyhttp.Request { @@ -514,8 +495,9 @@ func TestComputeInputPayloadChecksum(t *testing.T) { return r }(), }, - expectErr: "unseekable stream is not supported", - expectBuildErr: true, + expectContentLength: -1, + expectHeader: http.Header{}, + expectPayload: []byte("hello world"), }, "http stream read error": { initContext: func(ctx context.Context) context.Context { @@ -568,8 +550,9 @@ func TestComputeInputPayloadChecksum(t *testing.T) { return r }(), }, - expectErr: "unseekable stream is not supported", - expectBuildErr: true, + expectContentLength: -1, + expectHeader: http.Header{}, + expectPayload: []byte("hello world"), }, }, @@ -627,54 +610,54 @@ func TestComputeInputPayloadChecksum(t *testing.T) { }, "https seekable": { initContext: func(ctx context.Context) context.Context { - return internalcontext.SetChecksumInputAlgorithm(ctx, string(AlgorithmCRC32)) + return internalcontext.SetChecksumInputAlgorithm(ctx, string(AlgorithmSHA1)) }, buildInput: middleware.BuildInput{ Request: func() *smithyhttp.Request { r := smithyhttp.NewStackRequest().(*smithyhttp.Request) r.URL, _ = url.Parse("https://example.aws") r.ContentLength = 11 - r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + r = requestMust(r.SetStream(bytes.NewReader([]byte("Hello world")))) return r }(), }, expectHeader: http.Header{ "Content-Encoding": []string{"aws-chunked"}, "X-Amz-Decoded-Content-Length": []string{"11"}, - "X-Amz-Trailer": []string{"x-amz-checksum-crc32"}, + "X-Amz-Trailer": []string{"x-amz-checksum-sha1"}, }, - expectContentLength: 52, - expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"), + expectContentLength: 71, + expectPayload: []byte("b\r\nHello world\r\n0\r\nx-amz-checksum-sha1:e1AsOh9IyGCa4hLN+2Od7jlnP14=\r\n\r\n"), expectPayloadHash: "STREAMING-UNSIGNED-PAYLOAD-TRAILER", expectDeferToFinalize: true, expectChecksumMetadata: map[string]string{ - "CRC32": "DUoRhQ==", + "SHA1": "e1AsOh9IyGCa4hLN+2Od7jlnP14=", }, }, "https seekable unknown length": { initContext: func(ctx context.Context) context.Context { - return internalcontext.SetChecksumInputAlgorithm(ctx, string(AlgorithmCRC32)) + return internalcontext.SetChecksumInputAlgorithm(ctx, string(AlgorithmCRC32C)) }, buildInput: middleware.BuildInput{ Request: func() *smithyhttp.Request { r := smithyhttp.NewStackRequest().(*smithyhttp.Request) r.URL, _ = url.Parse("https://example.aws") r.ContentLength = -1 - r = requestMust(r.SetStream(bytes.NewReader([]byte("hello world")))) + r = requestMust(r.SetStream(bytes.NewReader([]byte("Hello world")))) return r }(), }, expectHeader: http.Header{ "Content-Encoding": []string{"aws-chunked"}, "X-Amz-Decoded-Content-Length": []string{"11"}, - "X-Amz-Trailer": []string{"x-amz-checksum-crc32"}, + "X-Amz-Trailer": []string{"x-amz-checksum-crc32c"}, }, - expectContentLength: 52, - expectPayload: []byte("b\r\nhello world\r\n0\r\nx-amz-checksum-crc32:DUoRhQ==\r\n\r\n"), + expectContentLength: 53, + expectPayload: []byte("b\r\nHello world\r\n0\r\nx-amz-checksum-crc32c:crUfeA==\r\n\r\n"), expectPayloadHash: "STREAMING-UNSIGNED-PAYLOAD-TRAILER", expectDeferToFinalize: true, expectChecksumMetadata: map[string]string{ - "CRC32": "DUoRhQ==", + "CRC32C": "crUfeA==", }, }, "https no compute payload hash": { @@ -706,12 +689,12 @@ func TestComputeInputPayloadChecksum(t *testing.T) { }, }, "https no decode content length": { - initContext: func(ctx context.Context) context.Context { - return internalcontext.SetChecksumInputAlgorithm(ctx, string(AlgorithmCRC32)) - }, optionsFn: func(o *computeInputPayloadChecksum) { o.EnableDecodedContentLengthHeader = false }, + initContext: func(ctx context.Context) context.Context { + return internalcontext.SetChecksumInputAlgorithm(ctx, string(AlgorithmCRC32)) + }, buildInput: middleware.BuildInput{ Request: func() *smithyhttp.Request { r := smithyhttp.NewStackRequest().(*smithyhttp.Request) @@ -778,7 +761,6 @@ func TestComputeInputPayloadChecksum(t *testing.T) { } trailerMiddleware := &addInputChecksumTrailer{ EnableTrailingChecksum: m.EnableTrailingChecksum, - RequireChecksum: m.RequireChecksum, EnableComputePayloadHash: m.EnableComputePayloadHash, EnableDecodedContentLengthHeader: m.EnableDecodedContentLengthHeader, } @@ -920,7 +902,7 @@ func TestComputeInputPayloadChecksum(t *testing.T) { // assert computed input checksums metadata computedMetadata, ok := GetComputedInputChecksums(metadata) - if e, a := ok, (c.expectChecksumMetadata != nil); e != a { + if e, a := (c.expectChecksumMetadata != nil), ok; e != a { t.Fatalf("expect checksum metadata %t, got %t, %v", e, a, computedMetadata) } if c.expectChecksumMetadata != nil { diff --git a/service/internal/checksum/middleware_setup_context.go b/service/internal/checksum/middleware_setup_context.go index 3db73afe7e8..7fcd29a3814 100644 --- a/service/internal/checksum/middleware_setup_context.go +++ b/service/internal/checksum/middleware_setup_context.go @@ -2,11 +2,16 @@ package checksum import ( "context" + "github.com/aws/aws-sdk-go-v2/aws" internalcontext "github.com/aws/aws-sdk-go-v2/internal/context" "github.com/aws/smithy-go/middleware" ) +const ( + checksumValidationModeEnabled = "ENABLED" +) + // setupChecksumContext is the initial middleware that looks up the input // used to configure checksum behavior. This middleware must be executed before // input validation step or any other checksum middleware. @@ -17,6 +22,16 @@ type setupInputContext struct { // Given the input parameter value, the function must return the algorithm // and true, or false if no algorithm is specified. GetAlgorithm func(interface{}) (string, bool) + + // RequireChecksum indicates whether operation model forces middleware to compute the input payload's checksum. + // If RequireChecksum is set to true, checksum will be calculated and RequestChecksumCalculation will be ignored, + // otherwise RequestChecksumCalculation will be used to indicate if checksum will be calculated + RequireChecksum bool + + // RequestChecksumCalculation is the user config to opt-in/out request checksum calculation. If RequireChecksum is + // set to true, checksum will be calculated and this field will be ignored, otherwise + // RequestChecksumCalculation will be used to indicate if checksum will be calculated + RequestChecksumCalculation aws.RequestChecksumCalculation } // ID for the middleware @@ -31,13 +46,13 @@ func (m *setupInputContext) HandleInitialize( ) ( out middleware.InitializeOutput, metadata middleware.Metadata, err error, ) { - // Check if validation algorithm is specified. - if m.GetAlgorithm != nil { - // check is input resource has a checksum algorithm - algorithm, ok := m.GetAlgorithm(in.Parameters) - if ok && len(algorithm) != 0 { - ctx = internalcontext.SetChecksumInputAlgorithm(ctx, algorithm) - } + if algorithm, ok := m.GetAlgorithm(in.Parameters); ok { + ctx = internalcontext.SetChecksumInputAlgorithm(ctx, algorithm) + return next.HandleInitialize(ctx, in) + } + + if m.RequireChecksum || m.RequestChecksumCalculation == aws.RequestChecksumCalculationWhenSupported { + ctx = internalcontext.SetChecksumInputAlgorithm(ctx, string(AlgorithmCRC32)) } return next.HandleInitialize(ctx, in) @@ -50,6 +65,9 @@ type setupOutputContext struct { // Given the input parameter value, the function must return the validation // mode and true, or false if no mode is specified. GetValidationMode func(interface{}) (string, bool) + + // ResponseChecksumValidation states user config to opt-in/out checksum validation + ResponseChecksumValidation aws.ResponseChecksumValidation } // ID for the middleware @@ -64,13 +82,11 @@ func (m *setupOutputContext) HandleInitialize( ) ( out middleware.InitializeOutput, metadata middleware.Metadata, err error, ) { - // Check if validation mode is specified. - if m.GetValidationMode != nil { - // check is input resource has a checksum algorithm - mode, ok := m.GetValidationMode(in.Parameters) - if ok && len(mode) != 0 { - ctx = setContextOutputValidationMode(ctx, mode) - } + + mode, _ := m.GetValidationMode(in.Parameters) + + if m.ResponseChecksumValidation == aws.ResponseChecksumValidationWhenSupported || mode == checksumValidationModeEnabled { + ctx = setContextOutputValidationMode(ctx, checksumValidationModeEnabled) } return next.HandleInitialize(ctx, in) diff --git a/service/internal/checksum/middleware_setup_context_test.go b/service/internal/checksum/middleware_setup_context_test.go index e629ee088d7..0f766430eb8 100644 --- a/service/internal/checksum/middleware_setup_context_test.go +++ b/service/internal/checksum/middleware_setup_context_test.go @@ -5,6 +5,7 @@ package checksum import ( "context" + "github.com/aws/aws-sdk-go-v2/aws" "testing" internalcontext "github.com/aws/aws-sdk-go-v2/internal/context" @@ -17,29 +18,47 @@ func TestSetupInput(t *testing.T) { } cases := map[string]struct { - inputParams interface{} - getAlgorithm func(interface{}) (string, bool) - expectValue string + inputParams interface{} + getAlgorithm func(interface{}) (string, bool) + RequireChecksum bool + RequestChecksumCalculation aws.RequestChecksumCalculation + expectValue string }{ - "nil accessor": { + "user config require checksum and algorithm unset": { + RequestChecksumCalculation: aws.RequestChecksumCalculationWhenRequired, + getAlgorithm: func(v interface{}) (string, bool) { + return "", false + }, expectValue: "", }, - "found empty": { - inputParams: Params{Value: ""}, + "require checksum found empty": { + RequireChecksum: true, + inputParams: Params{Value: ""}, getAlgorithm: func(v interface{}) (string, bool) { vv := v.(Params) return vv.Value, true }, expectValue: "", }, - "found not set": { - inputParams: Params{Value: ""}, + "user config require checksum found empty": { + RequestChecksumCalculation: aws.RequestChecksumCalculationWhenRequired, + inputParams: Params{Value: ""}, getAlgorithm: func(v interface{}) (string, bool) { - return "", false + vv := v.(Params) + return vv.Value, true }, expectValue: "", }, - "found": { + "require checksum and found": { + RequireChecksum: true, + inputParams: Params{Value: "abc123"}, + getAlgorithm: func(v interface{}) (string, bool) { + vv := v.(Params) + return vv.Value, true + }, + expectValue: "abc123", + }, + "user config support checksum and found": { inputParams: Params{Value: "abc123"}, getAlgorithm: func(v interface{}) (string, bool) { vv := v.(Params) @@ -47,12 +66,37 @@ func TestSetupInput(t *testing.T) { }, expectValue: "abc123", }, + "user config require checksum and found": { + RequestChecksumCalculation: aws.RequestChecksumCalculationWhenRequired, + inputParams: Params{Value: "abc123"}, + getAlgorithm: func(v interface{}) (string, bool) { + vv := v.(Params) + return vv.Value, true + }, + expectValue: "abc123", + }, + "require checksum unset and use default": { + RequireChecksum: true, + getAlgorithm: func(v interface{}) (string, bool) { + return "", false + }, + expectValue: "CRC32", + }, + "user config support checksum and use default": { + RequestChecksumCalculation: aws.RequestChecksumCalculationWhenSupported, + getAlgorithm: func(v interface{}) (string, bool) { + return "", false + }, + expectValue: "CRC32", + }, } for name, c := range cases { t.Run(name, func(t *testing.T) { m := setupInputContext{ - GetAlgorithm: c.getAlgorithm, + GetAlgorithm: c.getAlgorithm, + RequireChecksum: c.RequireChecksum, + RequestChecksumCalculation: c.RequestChecksumCalculation, } _, _, err := m.HandleInitialize(context.Background(), @@ -83,42 +127,54 @@ func TestSetupOutput(t *testing.T) { } cases := map[string]struct { - inputParams interface{} - getValidationMode func(interface{}) (string, bool) - expectValue string + inputParams interface{} + ResponseChecksumValidation aws.ResponseChecksumValidation + getValidationMode func(interface{}) (string, bool) + expectValue string }{ - "nil accessor": { - expectValue: "", + "user config support checksum found empty": { + ResponseChecksumValidation: aws.ResponseChecksumValidationWhenSupported, + inputParams: Params{Value: ""}, + getValidationMode: func(v interface{}) (string, bool) { + vv := v.(Params) + return vv.Value, true + }, + expectValue: "ENABLED", }, - "found empty": { - inputParams: Params{Value: ""}, + "user config support checksum found invalid value": { + ResponseChecksumValidation: aws.ResponseChecksumValidationWhenSupported, + inputParams: Params{Value: "abc123"}, getValidationMode: func(v interface{}) (string, bool) { vv := v.(Params) return vv.Value, true }, - expectValue: "", + expectValue: "ENABLED", }, - "found not set": { - inputParams: Params{Value: ""}, + "user config require checksum found invalid value": { + ResponseChecksumValidation: aws.ResponseChecksumValidationWhenRequired, + inputParams: Params{Value: "abc123"}, getValidationMode: func(v interface{}) (string, bool) { - return "", false + vv := v.(Params) + return vv.Value, true }, expectValue: "", }, - "found": { - inputParams: Params{Value: "abc123"}, + "user config require checksum found valid value": { + ResponseChecksumValidation: aws.ResponseChecksumValidationWhenRequired, + inputParams: Params{Value: "ENABLED"}, getValidationMode: func(v interface{}) (string, bool) { vv := v.(Params) return vv.Value, true }, - expectValue: "abc123", + expectValue: "ENABLED", }, } for name, c := range cases { t.Run(name, func(t *testing.T) { m := setupOutputContext{ - GetValidationMode: c.getValidationMode, + GetValidationMode: c.getValidationMode, + ResponseChecksumValidation: c.ResponseChecksumValidation, } _, _, err := m.HandleInitialize(context.Background(), diff --git a/service/internal/checksum/middleware_validate_output.go b/service/internal/checksum/middleware_validate_output.go index 9fde12d86d7..14096a1ce31 100644 --- a/service/internal/checksum/middleware_validate_output.go +++ b/service/internal/checksum/middleware_validate_output.go @@ -55,7 +55,7 @@ func (m *validateOutputPayloadChecksum) ID() string { } // HandleDeserialize is a Deserialize middleware that wraps the HTTP response -// body with an io.ReadCloser that will validate the its checksum. +// body with an io.ReadCloser that will validate its checksum. func (m *validateOutputPayloadChecksum) HandleDeserialize( ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler, ) ( @@ -66,8 +66,7 @@ func (m *validateOutputPayloadChecksum) HandleDeserialize( return out, metadata, err } - // If there is no validation mode specified nothing is supported. - if mode := getContextOutputValidationMode(ctx); mode != "ENABLED" { + if mode := getContextOutputValidationMode(ctx); mode != checksumValidationModeEnabled { return out, metadata, err } @@ -90,8 +89,6 @@ func (m *validateOutputPayloadChecksum) HandleDeserialize( algorithmToUse = algorithm } - // TODO this must validate the validation mode is set to enabled. - logger := middleware.GetLogger(ctx) // Skip validation if no checksum algorithm or checksum is available. diff --git a/service/internal/checksum/middleware_validate_output_test.go b/service/internal/checksum/middleware_validate_output_test.go index 0bf923e51e7..618c233dc6f 100644 --- a/service/internal/checksum/middleware_validate_output_test.go +++ b/service/internal/checksum/middleware_validate_output_test.go @@ -49,7 +49,21 @@ func TestValidateOutputPayloadChecksum(t *testing.T) { expectAlgorithmsUsed: []string{"CRC32"}, expectPayload: []byte("hello world"), }, - "failure": { + "no checksum required": { + response: &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 200, + Header: func() http.Header { + h := http.Header{} + h.Set(AlgorithmHTTPHeader(AlgorithmCRC32C), "crUfeA==") + return h + }(), + Body: ioutil.NopCloser(strings.NewReader("Hello world")), + }, + }, + expectPayload: []byte("Hello world"), + }, + "checksum mismatch failure": { modifyContext: func(ctx context.Context) context.Context { return setContextOutputValidationMode(ctx, "ENABLED") }, @@ -101,19 +115,6 @@ func TestValidateOutputPayloadChecksum(t *testing.T) { expectLogged: "no supported checksum", expectPayload: []byte("hello world"), }, - "no output validation model": { - response: &smithyhttp.Response{ - Response: &http.Response{ - StatusCode: 200, - Header: func() http.Header { - h := http.Header{} - return h - }(), - Body: ioutil.NopCloser(strings.NewReader("hello world")), - }, - }, - expectPayload: []byte("hello world"), - }, "unknown output validation model": { modifyContext: func(ctx context.Context) context.Context { return setContextOutputValidationMode(ctx, "something else") @@ -189,7 +190,7 @@ func TestValidateOutputPayloadChecksum(t *testing.T) { validateOutput := validateOutputPayloadChecksum{ Algorithms: []Algorithm{ - AlgorithmSHA1, AlgorithmCRC32, AlgorithmCRC32C, + AlgorithmSHA1, AlgorithmCRC32, AlgorithmCRC32C, AlgorithmSHA256, }, LogValidationSkipped: true, LogMultipartValidationSkipped: true, diff --git a/service/s3/api_client.go b/service/s3/api_client.go index 08e432799a9..d47cc4bfd2b 100644 --- a/service/s3/api_client.go +++ b/service/s3/api_client.go @@ -449,15 +449,17 @@ func setResolvedDefaultsMode(o *Options) { // NewFromConfig returns a new client from the provided config. func NewFromConfig(cfg aws.Config, optFns ...func(*Options)) *Client { opts := Options{ - Region: cfg.Region, - DefaultsMode: cfg.DefaultsMode, - RuntimeEnvironment: cfg.RuntimeEnvironment, - HTTPClient: cfg.HTTPClient, - Credentials: cfg.Credentials, - APIOptions: cfg.APIOptions, - Logger: cfg.Logger, - ClientLogMode: cfg.ClientLogMode, - AppID: cfg.AppID, + Region: cfg.Region, + DefaultsMode: cfg.DefaultsMode, + RuntimeEnvironment: cfg.RuntimeEnvironment, + HTTPClient: cfg.HTTPClient, + Credentials: cfg.Credentials, + APIOptions: cfg.APIOptions, + Logger: cfg.Logger, + ClientLogMode: cfg.ClientLogMode, + AppID: cfg.AppID, + RequestChecksumCalculation: cfg.RequestChecksumCalculation, + ResponseChecksumValidation: cfg.ResponseChecksumValidation, } resolveAWSRetryerProvider(cfg, &opts) resolveAWSRetryMaxAttempts(cfg, &opts) diff --git a/service/s3/api_op_DeleteObjects.go b/service/s3/api_op_DeleteObjects.go index d08261e5d69..438ada291c7 100644 --- a/service/s3/api_op_DeleteObjects.go +++ b/service/s3/api_op_DeleteObjects.go @@ -427,6 +427,7 @@ func addDeleteObjectsInputChecksumMiddlewares(stack *middleware.Stack, options O return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: getDeleteObjectsRequestAlgorithmMember, RequireChecksum: true, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: false, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/api_op_GetObject.go b/service/s3/api_op_GetObject.go index 0a9785d2e32..dd2df9f7676 100644 --- a/service/s3/api_op_GetObject.go +++ b/service/s3/api_op_GetObject.go @@ -794,6 +794,7 @@ func getGetObjectRequestValidationModeMember(input interface{}) (string, bool) { func addGetObjectOutputChecksumMiddlewares(stack *middleware.Stack, options Options) error { return internalChecksum.AddOutputMiddleware(stack, internalChecksum.OutputMiddlewareOptions{ GetValidationMode: getGetObjectRequestValidationModeMember, + ResponseChecksumValidation: options.ResponseChecksumValidation, ValidationAlgorithms: []string{"CRC32", "CRC32C", "SHA256", "SHA1"}, IgnoreMultipartValidation: true, LogValidationSkipped: true, diff --git a/service/s3/api_op_PutBucketAccelerateConfiguration.go b/service/s3/api_op_PutBucketAccelerateConfiguration.go index c400607c8ac..4e5ce8ba633 100644 --- a/service/s3/api_op_PutBucketAccelerateConfiguration.go +++ b/service/s3/api_op_PutBucketAccelerateConfiguration.go @@ -265,6 +265,7 @@ func addPutBucketAccelerateConfigurationInputChecksumMiddlewares(stack *middlewa return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: getPutBucketAccelerateConfigurationRequestAlgorithmMember, RequireChecksum: false, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: false, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/api_op_PutBucketAcl.go b/service/s3/api_op_PutBucketAcl.go index 9562fafb5a5..ec57b045421 100644 --- a/service/s3/api_op_PutBucketAcl.go +++ b/service/s3/api_op_PutBucketAcl.go @@ -412,6 +412,7 @@ func addPutBucketAclInputChecksumMiddlewares(stack *middleware.Stack, options Op return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: getPutBucketAclRequestAlgorithmMember, RequireChecksum: true, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: false, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/api_op_PutBucketCors.go b/service/s3/api_op_PutBucketCors.go index a08c41c1bc4..88ba2ed1755 100644 --- a/service/s3/api_op_PutBucketCors.go +++ b/service/s3/api_op_PutBucketCors.go @@ -290,6 +290,7 @@ func addPutBucketCorsInputChecksumMiddlewares(stack *middleware.Stack, options O return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: getPutBucketCorsRequestAlgorithmMember, RequireChecksum: true, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: false, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/api_op_PutBucketEncryption.go b/service/s3/api_op_PutBucketEncryption.go index e68b84c364b..c1220f74986 100644 --- a/service/s3/api_op_PutBucketEncryption.go +++ b/service/s3/api_op_PutBucketEncryption.go @@ -365,6 +365,7 @@ func addPutBucketEncryptionInputChecksumMiddlewares(stack *middleware.Stack, opt return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: getPutBucketEncryptionRequestAlgorithmMember, RequireChecksum: true, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: false, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/api_op_PutBucketLifecycleConfiguration.go b/service/s3/api_op_PutBucketLifecycleConfiguration.go index 4ab7b63b01e..6e8d307c5d3 100644 --- a/service/s3/api_op_PutBucketLifecycleConfiguration.go +++ b/service/s3/api_op_PutBucketLifecycleConfiguration.go @@ -333,6 +333,7 @@ func addPutBucketLifecycleConfigurationInputChecksumMiddlewares(stack *middlewar return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: getPutBucketLifecycleConfigurationRequestAlgorithmMember, RequireChecksum: true, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: false, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/api_op_PutBucketLogging.go b/service/s3/api_op_PutBucketLogging.go index 10e04be1ddb..71a3c8d4c87 100644 --- a/service/s3/api_op_PutBucketLogging.go +++ b/service/s3/api_op_PutBucketLogging.go @@ -297,6 +297,7 @@ func addPutBucketLoggingInputChecksumMiddlewares(stack *middleware.Stack, option return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: getPutBucketLoggingRequestAlgorithmMember, RequireChecksum: true, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: false, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/api_op_PutBucketOwnershipControls.go b/service/s3/api_op_PutBucketOwnershipControls.go index 6d5517e83cb..1e1e997de8e 100644 --- a/service/s3/api_op_PutBucketOwnershipControls.go +++ b/service/s3/api_op_PutBucketOwnershipControls.go @@ -229,6 +229,7 @@ func addPutBucketOwnershipControlsInputChecksumMiddlewares(stack *middleware.Sta return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: nil, RequireChecksum: true, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: false, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/api_op_PutBucketPolicy.go b/service/s3/api_op_PutBucketPolicy.go index b7e93b2cb4e..bfe3d202c30 100644 --- a/service/s3/api_op_PutBucketPolicy.go +++ b/service/s3/api_op_PutBucketPolicy.go @@ -337,6 +337,7 @@ func addPutBucketPolicyInputChecksumMiddlewares(stack *middleware.Stack, options return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: getPutBucketPolicyRequestAlgorithmMember, RequireChecksum: true, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: false, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/api_op_PutBucketReplication.go b/service/s3/api_op_PutBucketReplication.go index 1b67f7ec331..ed997618143 100644 --- a/service/s3/api_op_PutBucketReplication.go +++ b/service/s3/api_op_PutBucketReplication.go @@ -308,6 +308,7 @@ func addPutBucketReplicationInputChecksumMiddlewares(stack *middleware.Stack, op return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: getPutBucketReplicationRequestAlgorithmMember, RequireChecksum: true, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: false, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/api_op_PutBucketRequestPayment.go b/service/s3/api_op_PutBucketRequestPayment.go index fb9ffd47560..f53a284d733 100644 --- a/service/s3/api_op_PutBucketRequestPayment.go +++ b/service/s3/api_op_PutBucketRequestPayment.go @@ -255,6 +255,7 @@ func addPutBucketRequestPaymentInputChecksumMiddlewares(stack *middleware.Stack, return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: getPutBucketRequestPaymentRequestAlgorithmMember, RequireChecksum: true, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: false, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/api_op_PutBucketTagging.go b/service/s3/api_op_PutBucketTagging.go index 7bd67f9bf98..08b34887784 100644 --- a/service/s3/api_op_PutBucketTagging.go +++ b/service/s3/api_op_PutBucketTagging.go @@ -287,6 +287,7 @@ func addPutBucketTaggingInputChecksumMiddlewares(stack *middleware.Stack, option return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: getPutBucketTaggingRequestAlgorithmMember, RequireChecksum: true, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: false, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/api_op_PutBucketVersioning.go b/service/s3/api_op_PutBucketVersioning.go index 04cbcf08dc4..d3e17fdae4e 100644 --- a/service/s3/api_op_PutBucketVersioning.go +++ b/service/s3/api_op_PutBucketVersioning.go @@ -289,6 +289,7 @@ func addPutBucketVersioningInputChecksumMiddlewares(stack *middleware.Stack, opt return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: getPutBucketVersioningRequestAlgorithmMember, RequireChecksum: true, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: false, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/api_op_PutBucketWebsite.go b/service/s3/api_op_PutBucketWebsite.go index ebcb87f4fa3..665e05dc314 100644 --- a/service/s3/api_op_PutBucketWebsite.go +++ b/service/s3/api_op_PutBucketWebsite.go @@ -310,6 +310,7 @@ func addPutBucketWebsiteInputChecksumMiddlewares(stack *middleware.Stack, option return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: getPutBucketWebsiteRequestAlgorithmMember, RequireChecksum: true, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: false, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/api_op_PutObject.go b/service/s3/api_op_PutObject.go index 9dc442f7d82..d64c9e2b650 100644 --- a/service/s3/api_op_PutObject.go +++ b/service/s3/api_op_PutObject.go @@ -876,6 +876,7 @@ func addPutObjectInputChecksumMiddlewares(stack *middleware.Stack, options Optio return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: getPutObjectRequestAlgorithmMember, RequireChecksum: false, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: true, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/api_op_PutObjectAcl.go b/service/s3/api_op_PutObjectAcl.go index 481384026de..f0d9678a2cd 100644 --- a/service/s3/api_op_PutObjectAcl.go +++ b/service/s3/api_op_PutObjectAcl.go @@ -464,6 +464,7 @@ func addPutObjectAclInputChecksumMiddlewares(stack *middleware.Stack, options Op return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: getPutObjectAclRequestAlgorithmMember, RequireChecksum: true, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: false, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/api_op_PutObjectLegalHold.go b/service/s3/api_op_PutObjectLegalHold.go index 22737391d50..14830434d00 100644 --- a/service/s3/api_op_PutObjectLegalHold.go +++ b/service/s3/api_op_PutObjectLegalHold.go @@ -279,6 +279,7 @@ func addPutObjectLegalHoldInputChecksumMiddlewares(stack *middleware.Stack, opti return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: getPutObjectLegalHoldRequestAlgorithmMember, RequireChecksum: true, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: false, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/api_op_PutObjectLockConfiguration.go b/service/s3/api_op_PutObjectLockConfiguration.go index 3b0501d83bf..b866576ad87 100644 --- a/service/s3/api_op_PutObjectLockConfiguration.go +++ b/service/s3/api_op_PutObjectLockConfiguration.go @@ -270,6 +270,7 @@ func addPutObjectLockConfigurationInputChecksumMiddlewares(stack *middleware.Sta return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: getPutObjectLockConfigurationRequestAlgorithmMember, RequireChecksum: true, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: false, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/api_op_PutObjectRetention.go b/service/s3/api_op_PutObjectRetention.go index 6bb5682fca8..248b9cd1f30 100644 --- a/service/s3/api_op_PutObjectRetention.go +++ b/service/s3/api_op_PutObjectRetention.go @@ -286,6 +286,7 @@ func addPutObjectRetentionInputChecksumMiddlewares(stack *middleware.Stack, opti return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: getPutObjectRetentionRequestAlgorithmMember, RequireChecksum: true, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: false, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/api_op_PutObjectTagging.go b/service/s3/api_op_PutObjectTagging.go index 1f637c939b0..34fec9f1dca 100644 --- a/service/s3/api_op_PutObjectTagging.go +++ b/service/s3/api_op_PutObjectTagging.go @@ -322,6 +322,7 @@ func addPutObjectTaggingInputChecksumMiddlewares(stack *middleware.Stack, option return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: getPutObjectTaggingRequestAlgorithmMember, RequireChecksum: true, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: false, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/api_op_PutPublicAccessBlock.go b/service/s3/api_op_PutPublicAccessBlock.go index 7878fb783d9..1f26e90b654 100644 --- a/service/s3/api_op_PutPublicAccessBlock.go +++ b/service/s3/api_op_PutPublicAccessBlock.go @@ -273,6 +273,7 @@ func addPutPublicAccessBlockInputChecksumMiddlewares(stack *middleware.Stack, op return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: getPutPublicAccessBlockRequestAlgorithmMember, RequireChecksum: true, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: false, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/api_op_RestoreObject.go b/service/s3/api_op_RestoreObject.go index d0ed1312a41..d206f6fa935 100644 --- a/service/s3/api_op_RestoreObject.go +++ b/service/s3/api_op_RestoreObject.go @@ -426,6 +426,7 @@ func addRestoreObjectInputChecksumMiddlewares(stack *middleware.Stack, options O return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: getRestoreObjectRequestAlgorithmMember, RequireChecksum: false, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: false, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/api_op_UploadPart.go b/service/s3/api_op_UploadPart.go index 62ee5938c0f..01811aa607c 100644 --- a/service/s3/api_op_UploadPart.go +++ b/service/s3/api_op_UploadPart.go @@ -595,6 +595,7 @@ func addUploadPartInputChecksumMiddlewares(stack *middleware.Stack, options Opti return internalChecksum.AddInputMiddleware(stack, internalChecksum.InputMiddlewareOptions{ GetAlgorithm: getUploadPartRequestAlgorithmMember, RequireChecksum: false, + RequestChecksumCalculation: options.RequestChecksumCalculation, EnableTrailingChecksum: true, EnableComputeSHA256PayloadHash: true, EnableDecodedContentLengthHeader: true, diff --git a/service/s3/options.go b/service/s3/options.go index 8c67e4c6218..6b98e8802de 100644 --- a/service/s3/options.go +++ b/service/s3/options.go @@ -92,6 +92,12 @@ type Options struct { // The region to send requests to. (Required) Region string + // Indicates how user opt-in/out request checksum calculation + RequestChecksumCalculation aws.RequestChecksumCalculation + + // Indicates how user opt-in/out response checksum validation + ResponseChecksumValidation aws.ResponseChecksumValidation + // RetryMaxAttempts specifies the maximum number attempts an API client will call // an operation that fails with a retryable error. A value of 0 is ignored, and // will not be used to configure the API client created default retryer, or modify