From fe480b13105dd4fa754e9913316a1419c175ac92 Mon Sep 17 00:00:00 2001 From: Jason Del Ponte <961963+jasdel@users.noreply.github.com> Date: Tue, 5 Oct 2021 14:16:11 -0700 Subject: [PATCH] feature/ec2/imds: Fix IMDS client's response handling and operation timeout race Fixes #1253 race between reading a IMDS response, and the operationTimeout middleware cleaning up its timeout context. Changes the IMDS client to always buffer the response body received, before the result is deserialized. This ensures that the consumer of the operation's response body will not race with context cleanup within the middleware stack. --- .../53dad1d685864ddfb0304829c2a45e4c.json | 8 ++ feature/ec2/imds/request_middleware.go | 21 ++++- feature/ec2/imds/request_middleware_test.go | 88 +++++++++++++++++++ 3 files changed, 114 insertions(+), 3 deletions(-) create mode 100644 .changelog/53dad1d685864ddfb0304829c2a45e4c.json diff --git a/.changelog/53dad1d685864ddfb0304829c2a45e4c.json b/.changelog/53dad1d685864ddfb0304829c2a45e4c.json new file mode 100644 index 00000000000..c25231e2a2f --- /dev/null +++ b/.changelog/53dad1d685864ddfb0304829c2a45e4c.json @@ -0,0 +1,8 @@ +{ + "id": "53dad1d6-8586-4ddf-b030-4829c2a45e4c", + "type": "bugfix", + "description": "Fix IMDS client's response handling and operation timeout race. Fixes #1253", + "modules": [ + "feature/ec2/imds" + ] +} \ No newline at end of file diff --git a/feature/ec2/imds/request_middleware.go b/feature/ec2/imds/request_middleware.go index 93f02405f99..a10916acf51 100644 --- a/feature/ec2/imds/request_middleware.go +++ b/feature/ec2/imds/request_middleware.go @@ -1,8 +1,10 @@ package imds import ( + "bytes" "context" "fmt" + "io/ioutil" "net/url" "path" "time" @@ -142,12 +144,20 @@ func (m *deserializeResponse) HandleDeserialize( resp, ok := out.RawResponse.(*smithyhttp.Response) if !ok { return out, metadata, fmt.Errorf( - "unexpected transport response type, %T", out.RawResponse) + "unexpected transport response type, %T, want %T", out.RawResponse, resp) } + defer resp.Body.Close() - // Anything thats not 200 |< 300 is error + // read the full body so that any operation timeouts cleanup will not race + // the body being read. + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return out, metadata, fmt.Errorf("read response body failed, %w", err) + } + resp.Body = ioutil.NopCloser(bytes.NewReader(body)) + + // Anything that's not 200 |< 300 is error if resp.StatusCode < 200 || resp.StatusCode >= 300 { - resp.Body.Close() return out, metadata, &smithyhttp.ResponseError{ Response: resp, Err: fmt.Errorf("request to EC2 IMDS failed"), @@ -213,6 +223,11 @@ const ( defaultOperationTimeout = 5 * time.Second ) +// operationTimeout adds a timeout on the middleware stack. The next middleware +// must complete before the timeout. The next middleware must also ensure that +// any response body payloads are completely read from the response message +// before returning. Otherwise timeouts cleanup will race the response body +// being read upstream. type operationTimeout struct { Timeout time.Duration } diff --git a/feature/ec2/imds/request_middleware_test.go b/feature/ec2/imds/request_middleware_test.go index 629947e096a..614eacfa75a 100644 --- a/feature/ec2/imds/request_middleware_test.go +++ b/feature/ec2/imds/request_middleware_test.go @@ -150,6 +150,94 @@ func TestOperationTimeoutMiddleware(t *testing.T) { } } +// Ensure that the response body is read in the deserialize middleware, +// ensuring that the timeoutOperation middleware won't race canceling the +// context with the upstream reading the response body. +// * https://github.com/aws/aws-sdk-go-v2/issues/1253 +func TestDeserailizeResponse_cacheBody(t *testing.T) { + type Output struct { + Content io.ReadCloser + } + m := &deserializeResponse{ + GetOutput: func(resp *smithyhttp.Response) (interface{}, error) { + return &Output{ + Content: resp.Body, + }, nil + }, + } + + expectBody := "hello world!" + originalBody := &bytesReader{ + reader: strings.NewReader(expectBody), + } + if originalBody.closed { + t.Fatalf("expect original body not to be closed yet") + } + + out, _, err := m.HandleDeserialize(context.Background(), middleware.DeserializeInput{}, + middleware.DeserializeHandlerFunc(func( + ctx context.Context, input middleware.DeserializeInput, + ) ( + out middleware.DeserializeOutput, metadata middleware.Metadata, err error, + ) { + out.RawResponse = &smithyhttp.Response{ + Response: &http.Response{ + StatusCode: 200, + Status: "200 OK", + Header: http.Header{}, + ContentLength: int64(originalBody.Len()), + Body: originalBody, + }, + } + return out, metadata, nil + })) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if !originalBody.closed { + t.Errorf("expect original body to be closed, was not") + } + + result, ok := out.Result.(*Output) + if !ok { + t.Fatalf("expect result to be Output, got %T, %v", result, result) + } + + actualBody, err := ioutil.ReadAll(result.Content) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if e, a := expectBody, string(actualBody); e != a { + t.Errorf("expect %v body, got %v", e, a) + } + if err := result.Content.Close(); err != nil { + t.Fatalf("expect no error, got %v", err) + } +} + +type bytesReader struct { + reader interface { + io.Reader + Len() int + } + closed bool +} + +func (r *bytesReader) Len() int { + return r.reader.Len() +} +func (r *bytesReader) Close() error { + r.closed = true + return nil +} +func (r *bytesReader) Read(p []byte) (int, error) { + if r.closed { + return 0, io.EOF + } + return r.reader.Read(p) +} + type successAPIResponseHandler struct { t *testing.T path string