Skip to content

Commit

Permalink
feature/ec2/imds: Fix Client's response handling and operation timeou…
Browse files Browse the repository at this point in the history
…t race (#1448)

Fixes #1253 race between reading a IMDS response, and the
operationTimeout middleware cleaning up its timeout context. Changes the
IMDS client to always buffer the response body received, before the
result is deserialized. This ensures that the consumer of the
operation's response body will not race with context cleanup within the
middleware stack.

Updates the IMDS Client operations to not override the passed in
Context's Deadline or Timeout options. If an Client operation is called
with a Context with a Deadline or Timeout, the client will no longer
override it with the client's default timeout.

Updates operationTimeout so that if DefaultTimeout is unset (aka zero)
operationTimeout will not set a default timeout on the context.
  • Loading branch information
jasdel authored Oct 6, 2021
1 parent f1baf2d commit 0d3bd7a
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 10 deletions.
8 changes: 8 additions & 0 deletions .changelog/17ac89419cd94e598fcc93444709cc1a.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "17ac8941-9cd9-4e59-8fcc-93444709cc1a",
"type": "feature",
"description": "Respect passed in Context Deadline/Timeout. Updates the IMDS Client operations to not override the passed in Context's Deadline or Timeout options. If an Client operation is called with a Context with a Deadline or Timeout, the client will no longer override it with the client's default timeout.",
"modules": [
"feature/ec2/imds"
]
}
8 changes: 8 additions & 0 deletions .changelog/53dad1d685864ddfb0304829c2a45e4c.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "53dad1d6-8586-4ddf-b030-4829c2a45e4c",
"type": "bugfix",
"description": "Fix IMDS client's response handling and operation timeout race. Fixes #1253",
"modules": [
"feature/ec2/imds"
]
}
5 changes: 5 additions & 0 deletions feature/ec2/imds/doc.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
// Package imds provides the API client for interacting with the Amazon EC2
// Instance Metadata Service.
//
// All Client operation calls have a default timeout. If the operation is not
// completed before this timeout expires, the operation will be canceled. This
// timeout can be overridden by providing Context with a timeout or deadline
// with calling the client's operations.
//
// See the EC2 IMDS user guide for more information on using the API.
// https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html
package imds
40 changes: 31 additions & 9 deletions feature/ec2/imds/request_middleware.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package imds

import (
"bytes"
"context"
"fmt"
"io/ioutil"
"net/url"
"path"
"time"
Expand Down Expand Up @@ -52,7 +54,7 @@ func addRequestMiddleware(stack *middleware.Stack,

// Operation timeout
err = stack.Initialize.Add(&operationTimeout{
Timeout: defaultOperationTimeout,
DefaultTimeout: defaultOperationTimeout,
}, middleware.Before)
if err != nil {
return err
Expand Down Expand Up @@ -142,12 +144,20 @@ func (m *deserializeResponse) HandleDeserialize(
resp, ok := out.RawResponse.(*smithyhttp.Response)
if !ok {
return out, metadata, fmt.Errorf(
"unexpected transport response type, %T", out.RawResponse)
"unexpected transport response type, %T, want %T", out.RawResponse, resp)
}
defer resp.Body.Close()

// Anything thats not 200 |< 300 is error
// read the full body so that any operation timeouts cleanup will not race
// the body being read.
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return out, metadata, fmt.Errorf("read response body failed, %w", err)
}
resp.Body = ioutil.NopCloser(bytes.NewReader(body))

// Anything that's not 200 |< 300 is error
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
resp.Body.Close()
return out, metadata, &smithyhttp.ResponseError{
Response: resp,
Err: fmt.Errorf("request to EC2 IMDS failed"),
Expand Down Expand Up @@ -213,8 +223,19 @@ const (
defaultOperationTimeout = 5 * time.Second
)

// operationTimeout adds a timeout on the middleware stack if the Context the
// stack was called with does not have a deadline. The next middleware must
// complete before the timeout, or the context will be canceled.
//
// If DefaultTimeout is zero, no default timeout will be used if the Context
// does not have a timeout.
//
// The next middleware must also ensure that any resources that are also
// canceled by the stack's context are completely consumed before returning.
// Otherwise the timeout cleanup will race the resource being consumed
// upstream.
type operationTimeout struct {
Timeout time.Duration
DefaultTimeout time.Duration
}

func (*operationTimeout) ID() string { return "OperationTimeout" }
Expand All @@ -224,10 +245,11 @@ func (m *operationTimeout) HandleInitialize(
) (
output middleware.InitializeOutput, metadata middleware.Metadata, err error,
) {
var cancelFn func()

ctx, cancelFn = context.WithTimeout(ctx, m.Timeout)
defer cancelFn()
if _, ok := ctx.Deadline(); !ok && m.DefaultTimeout != 0 {
var cancelFn func()
ctx, cancelFn = context.WithTimeout(ctx, m.DefaultTimeout)
defer cancelFn()
}

return next.HandleInitialize(ctx, input)
}
Expand Down
145 changes: 144 additions & 1 deletion feature/ec2/imds/request_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/hex"
"fmt"
"io"
"io/ioutil"
"net/http"
Expand Down Expand Up @@ -126,7 +127,7 @@ func TestAddRequestMiddleware(t *testing.T) {

func TestOperationTimeoutMiddleware(t *testing.T) {
m := &operationTimeout{
Timeout: time.Nanosecond,
DefaultTimeout: time.Nanosecond,
}

_, _, err := m.HandleInitialize(context.Background(), middleware.InitializeInput{},
Expand All @@ -135,6 +136,10 @@ func TestOperationTimeoutMiddleware(t *testing.T) {
) (
out middleware.InitializeOutput, metadata middleware.Metadata, err error,
) {
if _, ok := ctx.Deadline(); !ok {
return out, metadata, fmt.Errorf("expect context deadline to be set")
}

if err := sdk.SleepWithContext(ctx, time.Second); err != nil {
return out, metadata, err
}
Expand All @@ -150,6 +155,144 @@ func TestOperationTimeoutMiddleware(t *testing.T) {
}
}

func TestOperationTimeoutMiddleware_noDefaultTimeout(t *testing.T) {
m := &operationTimeout{}

_, _, err := m.HandleInitialize(context.Background(), middleware.InitializeInput{},
middleware.InitializeHandlerFunc(func(
ctx context.Context, input middleware.InitializeInput,
) (
out middleware.InitializeOutput, metadata middleware.Metadata, err error,
) {
if t, ok := ctx.Deadline(); ok {
return out, metadata, fmt.Errorf("expect no context deadline, got %v", t)
}

return out, metadata, nil
}))
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
}

func TestOperationTimeoutMiddleware_withCustomDeadline(t *testing.T) {
m := &operationTimeout{
DefaultTimeout: time.Nanosecond,
}

expectDeadline := time.Now().Add(time.Hour)
ctx, cancelFn := context.WithDeadline(context.Background(), expectDeadline)
defer cancelFn()

_, _, err := m.HandleInitialize(ctx, middleware.InitializeInput{},
middleware.InitializeHandlerFunc(func(
ctx context.Context, input middleware.InitializeInput,
) (
out middleware.InitializeOutput, metadata middleware.Metadata, err error,
) {
t, ok := ctx.Deadline()
if !ok {
return out, metadata, fmt.Errorf("expect context deadline to be set")
}
if e, a := expectDeadline, t; !e.Equal(a) {
return out, metadata, fmt.Errorf("expect %v deadline, got %v", e, a)
}

return out, metadata, nil
}))
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
}

// Ensure that the response body is read in the deserialize middleware,
// ensuring that the timeoutOperation middleware won't race canceling the
// context with the upstream reading the response body.
// * https://github.com/aws/aws-sdk-go-v2/issues/1253
func TestDeserailizeResponse_cacheBody(t *testing.T) {
type Output struct {
Content io.ReadCloser
}
m := &deserializeResponse{
GetOutput: func(resp *smithyhttp.Response) (interface{}, error) {
return &Output{
Content: resp.Body,
}, nil
},
}

expectBody := "hello world!"
originalBody := &bytesReader{
reader: strings.NewReader(expectBody),
}
if originalBody.closed {
t.Fatalf("expect original body not to be closed yet")
}

out, _, err := m.HandleDeserialize(context.Background(), middleware.DeserializeInput{},
middleware.DeserializeHandlerFunc(func(
ctx context.Context, input middleware.DeserializeInput,
) (
out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
) {
out.RawResponse = &smithyhttp.Response{
Response: &http.Response{
StatusCode: 200,
Status: "200 OK",
Header: http.Header{},
ContentLength: int64(originalBody.Len()),
Body: originalBody,
},
}
return out, metadata, nil
}))
if err != nil {
t.Fatalf("expect no error, got %v", err)
}

if !originalBody.closed {
t.Errorf("expect original body to be closed, was not")
}

result, ok := out.Result.(*Output)
if !ok {
t.Fatalf("expect result to be Output, got %T, %v", result, result)
}

actualBody, err := ioutil.ReadAll(result.Content)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
if e, a := expectBody, string(actualBody); e != a {
t.Errorf("expect %v body, got %v", e, a)
}
if err := result.Content.Close(); err != nil {
t.Fatalf("expect no error, got %v", err)
}
}

type bytesReader struct {
reader interface {
io.Reader
Len() int
}
closed bool
}

func (r *bytesReader) Len() int {
return r.reader.Len()
}
func (r *bytesReader) Close() error {
r.closed = true
return nil
}
func (r *bytesReader) Read(p []byte) (int, error) {
if r.closed {
return 0, io.EOF
}
return r.reader.Read(p)
}

type successAPIResponseHandler struct {
t *testing.T
path string
Expand Down

0 comments on commit 0d3bd7a

Please sign in to comment.