From 24a88063b255c2d06930cd82b74cacc154790101 Mon Sep 17 00:00:00 2001 From: skotambkar Date: Thu, 15 Oct 2020 12:32:35 -0700 Subject: [PATCH] limits deriving error code from status to s3 --- .../smithy/aws/go/codegen/XMLProtocolUtils.java | 12 +++++++++++- service/internal/s3shared/xml_utils.go | 5 +++-- service/internal/s3shared/xml_utils_test.go | 14 +++++++++++--- 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/XMLProtocolUtils.java b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/XMLProtocolUtils.java index 80932ed71d3..85e19d2123b 100644 --- a/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/XMLProtocolUtils.java +++ b/codegen/smithy-aws-go-codegen/src/main/java/software/amazon/smithy/aws/go/codegen/XMLProtocolUtils.java @@ -275,8 +275,13 @@ public static void writeXmlErrorMessageCodeDeserializer(ProtocolGenerator.Genera if (requiresS3Customization(service)) { writer.addUseImports(AwsCustomGoDependency.S3_SHARED_CUSTOMIZATION); - writer.write("errorComponents, err := s3shared.GetErrorResponseComponents(errorBody, response.StatusCode)"); + String getErrorComponents = String.format("errorComponents, err := " + + "s3shared.GetErrorResponseComponents(errorBody, response.StatusCode, %s)", + isS3Service(service)); + writer.write(getErrorComponents); + writer.write("if err != nil { return err }"); + writer.insertTrailingNewline(); writer.openBlock("if hostID := errorComponents.HostID; len(hostID)!=0 {", "}", () -> { writer.write("s3shared.SetHostIDMetadata(metadata, hostID)"); @@ -307,6 +312,11 @@ private static boolean requiresS3Customization(ServiceShape service) { String serviceId= service.expectTrait(ServiceTrait.class).getSdkId(); return serviceId.equalsIgnoreCase("S3") || serviceId.equalsIgnoreCase("S3 Control"); } + + private static boolean isS3Service(ServiceShape service) { + String serviceId= service.expectTrait(ServiceTrait.class).getSdkId(); + return serviceId.equalsIgnoreCase("S3"); + } } diff --git a/service/internal/s3shared/xml_utils.go b/service/internal/s3shared/xml_utils.go index d26a2fa1657..d4cec6b6d0a 100644 --- a/service/internal/s3shared/xml_utils.go +++ b/service/internal/s3shared/xml_utils.go @@ -18,13 +18,14 @@ type ErrorComponents struct { } // GetErrorResponseComponents returns the error fields from an xml error response body -func GetErrorResponseComponents(r io.Reader, statusCode int) (ErrorComponents, error) { +func GetErrorResponseComponents(r io.Reader, statusCode int, isS3service bool) (ErrorComponents, error) { var errComponents ErrorComponents if err := xml.NewDecoder(r).Decode(&errComponents); err != nil && err != io.EOF { return ErrorComponents{}, fmt.Errorf("error while deserializing xml error response : %w", err) } - if len(errComponents.Code) == 0 && len(errComponents.Message) == 0 { + // for S3 service, we derive err code and message, if none is found + if isS3service && len(errComponents.Code) == 0 && len(errComponents.Message) == 0 { // derive code and message from status code statusText := http.StatusText(statusCode) errComponents.Code = strings.Replace(statusText, " ", "", -1) diff --git a/service/internal/s3shared/xml_utils_test.go b/service/internal/s3shared/xml_utils_test.go index 460c5607998..64bc1ef148a 100644 --- a/service/internal/s3shared/xml_utils_test.go +++ b/service/internal/s3shared/xml_utils_test.go @@ -9,6 +9,7 @@ import ( func TestGetResponseErrorCode(t *testing.T) { cases := map[string]struct { + isS3Service bool status int errorResponse io.Reader expectedErrorCode string @@ -17,7 +18,8 @@ func TestGetResponseErrorCode(t *testing.T) { expectedErrorHostID string }{ "standard xml error": { - status: 400, + isS3Service: true, + status: 400, errorResponse: bytes.NewReader([]byte(` Sender InvalidGreeting @@ -30,17 +32,23 @@ func TestGetResponseErrorCode(t *testing.T) { expectedErrorRequestID: "foo-id", expectedErrorHostID: "bar-id", }, - "no response body": { + "s3 no response body": { + isS3Service: true, status: 400, errorResponse: bytes.NewReader([]byte(``)), expectedErrorCode: "BadRequest", expectedErrorMessage: "Bad Request", }, + "s3control no response body": { + isS3Service: false, + status: 400, + errorResponse: bytes.NewReader([]byte(``)), + }, } for name, c := range cases { t.Run(name, func(t *testing.T) { - ec, err := GetErrorResponseComponents(c.errorResponse, c.status) + ec, err := GetErrorResponseComponents(c.errorResponse, c.status, c.isS3Service) if err != nil { t.Fatalf("expected no error, got %v", err) }