Skip to content

Commit

Permalink
Merge pull request #875 from aws/s3errorDeser
Browse files Browse the repository at this point in the history
support s3control using different xml error format than s3
  • Loading branch information
skotambkar authored Nov 9, 2020
2 parents 7eedf3f + e027394 commit 8e108cc
Show file tree
Hide file tree
Showing 5 changed files with 460 additions and 133 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SmithyGoDependency;
import software.amazon.smithy.go.codegen.SymbolUtils;
import software.amazon.smithy.go.codegen.SyntheticClone;
import software.amazon.smithy.go.codegen.integration.ProtocolGenerator;
import software.amazon.smithy.model.Model;
Expand Down Expand Up @@ -274,12 +275,28 @@ public static void writeXmlErrorMessageCodeDeserializer(ProtocolGenerator.Genera
ServiceShape service = context.getService();

if (requiresS3Customization(service)) {
writer.addUseImports(AwsCustomGoDependency.S3_SHARED_CUSTOMIZATION);
Symbol getErrorComponentFunction = SymbolUtils.createValueSymbolBuilder(
"GetErrorResponseComponents",
AwsCustomGoDependency.S3_SHARED_CUSTOMIZATION
).build();

Symbol errorOptions = SymbolUtils.createValueSymbolBuilder(
"ErrorResponseDeserializerOptions",
AwsCustomGoDependency.S3_SHARED_CUSTOMIZATION
).build();

if (isS3Service(service)){
writer.write("errorComponents, err := s3shared.GetS3ErrorResponseComponents(errorBody, response.StatusCode)");
// s3 service
writer.openBlock("errorComponents, err := $T(errorBody, $T{",
"})", getErrorComponentFunction, errorOptions, () -> {
writer.write("UseStatusCode : true, StatusCode : response.StatusCode,");
});
} else {
// s3 control
writer.write("errorComponents, err := s3shared.GetErrorResponseComponents(errorBody)");
writer.openBlock("errorComponents, err := $T(errorBody, $T{",
"})", getErrorComponentFunction, errorOptions, () -> {
writer.write("IsWrappedWithErrorTag: true,");
});
}

writer.write("if err != nil { return err }");
Expand Down
62 changes: 53 additions & 9 deletions service/internal/s3shared/xml_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,73 @@ type ErrorComponents struct {
HostID string `xml:"HostId"`
}

// GetErrorResponseComponents returns the error fields from an xml error response body
func GetErrorResponseComponents(r io.Reader) (ErrorComponents, error) {
// GetUnwrappedErrorResponseComponents returns the error fields from an xml error response body
func GetUnwrappedErrorResponseComponents(r io.Reader) (ErrorComponents, error) {
var errComponents ErrorComponents
if err := xml.NewDecoder(r).Decode(&errComponents); err != nil && err != io.EOF {
return ErrorComponents{}, fmt.Errorf("error while deserializing xml error response : %w", err)
}
return errComponents, nil
}

// GetS3ErrorResponseComponents returns the error fields from an S3 xml error response body.
// If an error code or message is not retrieved, it is derived from the http status code
func GetS3ErrorResponseComponents(r io.Reader, statusCode int) (ErrorComponents, error) {
errComponents, err := GetErrorResponseComponents(r)
// GetWrappedErrorResponseComponents returns the error fields from an xml error response body
// in which error code, and message are wrapped by a <Error> tag
func GetWrappedErrorResponseComponents(r io.Reader) (ErrorComponents, error) {
var errComponents struct {
Code string `xml:"Error>Code"`
Message string `xml:"Error>Message"`
RequestID string `xml:"RequestId"`
HostID string `xml:"HostId"`
}

if err := xml.NewDecoder(r).Decode(&errComponents); err != nil && err != io.EOF {
return ErrorComponents{}, fmt.Errorf("error while deserializing xml error response : %w", err)
}

return ErrorComponents{
Code: errComponents.Code,
Message: errComponents.Message,
RequestID: errComponents.RequestID,
HostID: errComponents.HostID,
}, nil
}

// GetErrorResponseComponents retrieves error components according to passed in options
func GetErrorResponseComponents(r io.Reader, options ErrorResponseDeserializerOptions) (ErrorComponents, error) {
var errComponents ErrorComponents
var err error

if options.IsWrappedWithErrorTag {
errComponents, err = GetWrappedErrorResponseComponents(r)
} else {
errComponents, err = GetUnwrappedErrorResponseComponents(r)
}

if err != nil {
return ErrorComponents{}, err
}

// for S3 service, we derive err code and message, if none is found
if len(errComponents.Code) == 0 && len(errComponents.Message) == 0 {
// If an error code or message is not retrieved, it is derived from the http status code
// eg, for S3 service, we derive err code and message, if none is found
if options.UseStatusCode && len(errComponents.Code) == 0 &&
len(errComponents.Message) == 0 {
// derive code and message from status code
statusText := http.StatusText(statusCode)
statusText := http.StatusText(options.StatusCode)
errComponents.Code = strings.Replace(statusText, " ", "", -1)
errComponents.Message = statusText
}
return errComponents, nil
}

// ErrorResponseDeserializerOptions represents error response deserializer options for s3 and s3-control service
type ErrorResponseDeserializerOptions struct {
// UseStatusCode denotes if status code should be used to retrieve error code, msg
UseStatusCode bool

// StatusCode is status code of error response
StatusCode int

//IsWrappedWithErrorTag represents if error response's code, msg is wrapped within an
// additional <Error> tag
IsWrappedWithErrorTag bool
}
36 changes: 33 additions & 3 deletions service/internal/s3shared/xml_utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ func TestGetResponseErrorCode(t *testing.T) {
<RequestId>foo-id</RequestId>
</Error>`

const wrappedXMLErrorResponse = `<ErrorResponse><Error>
<Type>Sender</Type>
<Code>InvalidGreeting</Code>
<Message>Hi</Message>
</Error>
<HostId>bar-id</HostId>
<RequestId>foo-id</RequestId>
</ErrorResponse>`

cases := map[string]struct {
getErr func() (ErrorComponents, error)
expectedErrorCode string
Expand All @@ -24,7 +33,11 @@ func TestGetResponseErrorCode(t *testing.T) {
"standard xml error": {
getErr: func() (ErrorComponents, error) {
errResp := strings.NewReader(xmlErrorResponse)
return GetErrorResponseComponents(errResp)
return GetErrorResponseComponents(errResp, ErrorResponseDeserializerOptions{
UseStatusCode: false,
StatusCode: 0,
IsWrappedWithErrorTag: false,
})
},
expectedErrorCode: "InvalidGreeting",
expectedErrorMessage: "Hi",
Expand All @@ -35,17 +48,34 @@ func TestGetResponseErrorCode(t *testing.T) {
"s3 no response body": {
getErr: func() (ErrorComponents, error) {
errResp := strings.NewReader("")
return GetS3ErrorResponseComponents(errResp, 400)
return GetErrorResponseComponents(errResp, ErrorResponseDeserializerOptions{
UseStatusCode: true,
StatusCode: 400,
})
},
expectedErrorCode: "BadRequest",
expectedErrorMessage: "Bad Request",
},
"s3control no response body": {
getErr: func() (ErrorComponents, error) {
errResp := strings.NewReader("")
return GetErrorResponseComponents(errResp)
return GetErrorResponseComponents(errResp, ErrorResponseDeserializerOptions{
IsWrappedWithErrorTag: true,
})
},
},
"s3control standard response body": {
getErr: func() (ErrorComponents, error) {
errResp := strings.NewReader(wrappedXMLErrorResponse)
return GetErrorResponseComponents(errResp, ErrorResponseDeserializerOptions{
IsWrappedWithErrorTag: true,
})
},
expectedErrorCode: "InvalidGreeting",
expectedErrorMessage: "Hi",
expectedErrorRequestID: "foo-id",
expectedErrorHostID: "bar-id",
},
}

for name, c := range cases {
Expand Down
Loading

0 comments on commit 8e108cc

Please sign in to comment.