Skip to content

Commit

Permalink
feature/ec2/imds: Fix IMDS client's response handling and operation t…
Browse files Browse the repository at this point in the history
…imeout race

Fixes aws#1253 race between reading a IMDS response, and the
operationTimeout middleware cleaning up its timeout context. Changes the
IMDS client to always buffer the response body received, before the
result is deserialized. This ensures that the consumer of the
operation's response body will not race with context cleanup within the
middleware stack.
  • Loading branch information
jasdel committed Oct 5, 2021
1 parent 74bf5cf commit fe480b1
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 3 deletions.
8 changes: 8 additions & 0 deletions .changelog/53dad1d685864ddfb0304829c2a45e4c.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "53dad1d6-8586-4ddf-b030-4829c2a45e4c",
"type": "bugfix",
"description": "Fix IMDS client's response handling and operation timeout race. Fixes #1253",
"modules": [
"feature/ec2/imds"
]
}
21 changes: 18 additions & 3 deletions feature/ec2/imds/request_middleware.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package imds

import (
"bytes"
"context"
"fmt"
"io/ioutil"
"net/url"
"path"
"time"
Expand Down Expand Up @@ -142,12 +144,20 @@ func (m *deserializeResponse) HandleDeserialize(
resp, ok := out.RawResponse.(*smithyhttp.Response)
if !ok {
return out, metadata, fmt.Errorf(
"unexpected transport response type, %T", out.RawResponse)
"unexpected transport response type, %T, want %T", out.RawResponse, resp)
}
defer resp.Body.Close()

// Anything thats not 200 |< 300 is error
// read the full body so that any operation timeouts cleanup will not race
// the body being read.
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return out, metadata, fmt.Errorf("read response body failed, %w", err)
}
resp.Body = ioutil.NopCloser(bytes.NewReader(body))

// Anything that's not 200 |< 300 is error
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
resp.Body.Close()
return out, metadata, &smithyhttp.ResponseError{
Response: resp,
Err: fmt.Errorf("request to EC2 IMDS failed"),
Expand Down Expand Up @@ -213,6 +223,11 @@ const (
defaultOperationTimeout = 5 * time.Second
)

// operationTimeout adds a timeout on the middleware stack. The next middleware
// must complete before the timeout. The next middleware must also ensure that
// any response body payloads are completely read from the response message
// before returning. Otherwise timeouts cleanup will race the response body
// being read upstream.
type operationTimeout struct {
Timeout time.Duration
}
Expand Down
88 changes: 88 additions & 0 deletions feature/ec2/imds/request_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,94 @@ func TestOperationTimeoutMiddleware(t *testing.T) {
}
}

// Ensure that the response body is read in the deserialize middleware,
// ensuring that the timeoutOperation middleware won't race canceling the
// context with the upstream reading the response body.
// * https://github.com/aws/aws-sdk-go-v2/issues/1253
func TestDeserailizeResponse_cacheBody(t *testing.T) {
type Output struct {
Content io.ReadCloser
}
m := &deserializeResponse{
GetOutput: func(resp *smithyhttp.Response) (interface{}, error) {
return &Output{
Content: resp.Body,
}, nil
},
}

expectBody := "hello world!"
originalBody := &bytesReader{
reader: strings.NewReader(expectBody),
}
if originalBody.closed {
t.Fatalf("expect original body not to be closed yet")
}

out, _, err := m.HandleDeserialize(context.Background(), middleware.DeserializeInput{},
middleware.DeserializeHandlerFunc(func(
ctx context.Context, input middleware.DeserializeInput,
) (
out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
) {
out.RawResponse = &smithyhttp.Response{
Response: &http.Response{
StatusCode: 200,
Status: "200 OK",
Header: http.Header{},
ContentLength: int64(originalBody.Len()),
Body: originalBody,
},
}
return out, metadata, nil
}))
if err != nil {
t.Fatalf("expect no error, got %v", err)
}

if !originalBody.closed {
t.Errorf("expect original body to be closed, was not")
}

result, ok := out.Result.(*Output)
if !ok {
t.Fatalf("expect result to be Output, got %T, %v", result, result)
}

actualBody, err := ioutil.ReadAll(result.Content)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := expectBody, string(actualBody); e != a {
t.Errorf("expect %v body, got %v", e, a)
}
if err := result.Content.Close(); err != nil {
t.Fatalf("expect no error, got %v", err)
}
}

type bytesReader struct {
reader interface {
io.Reader
Len() int
}
closed bool
}

func (r *bytesReader) Len() int {
return r.reader.Len()
}
func (r *bytesReader) Close() error {
r.closed = true
return nil
}
func (r *bytesReader) Read(p []byte) (int, error) {
if r.closed {
return 0, io.EOF
}
return r.reader.Read(p)
}

type successAPIResponseHandler struct {
t *testing.T
path string
Expand Down

0 comments on commit fe480b1

Please sign in to comment.