Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature/ec2/imds: Fix Client's response handling and operation timeout race #1448

Merged
merged 3 commits into from
Oct 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changelog/17ac89419cd94e598fcc93444709cc1a.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "17ac8941-9cd9-4e59-8fcc-93444709cc1a",
"type": "feature",
"description": "Respect passed in Context Deadline/Timeout. Updates the IMDS Client operations to not override the passed in Context's Deadline or Timeout options. If a Client operation is called with a Context with a Deadline or Timeout, the client will no longer override it with the client's default timeout.",
"modules": [
"feature/ec2/imds"
]
}
8 changes: 8 additions & 0 deletions .changelog/53dad1d685864ddfb0304829c2a45e4c.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "53dad1d6-8586-4ddf-b030-4829c2a45e4c",
"type": "bugfix",
"description": "Fix IMDS client's response handling and operation timeout race. Fixes #1253",
"modules": [
"feature/ec2/imds"
]
}
5 changes: 5 additions & 0 deletions feature/ec2/imds/doc.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
// Package imds provides the API client for interacting with the Amazon EC2
// Instance Metadata Service.
//
// All Client operation calls have a default timeout. If the operation is not
// completed before this timeout expires, the operation will be canceled. This
// timeout can be overridden by providing a Context with a timeout or deadline
// when calling the client's operations.
//
// See the EC2 IMDS user guide for more information on using the API.
// https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html
package imds
40 changes: 31 additions & 9 deletions feature/ec2/imds/request_middleware.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package imds

import (
"bytes"
"context"
"fmt"
"io/ioutil"
"net/url"
"path"
"time"
Expand Down Expand Up @@ -52,7 +54,7 @@ func addRequestMiddleware(stack *middleware.Stack,

// Operation timeout
err = stack.Initialize.Add(&operationTimeout{
Timeout: defaultOperationTimeout,
DefaultTimeout: defaultOperationTimeout,
}, middleware.Before)
if err != nil {
return err
Expand Down Expand Up @@ -142,12 +144,20 @@ func (m *deserializeResponse) HandleDeserialize(
resp, ok := out.RawResponse.(*smithyhttp.Response)
if !ok {
return out, metadata, fmt.Errorf(
"unexpected transport response type, %T", out.RawResponse)
"unexpected transport response type, %T, want %T", out.RawResponse, resp)
}
defer resp.Body.Close()

// Anything thats not 200 |< 300 is error
// read the full body so that any operation timeouts cleanup will not race
// the body being read.
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return out, metadata, fmt.Errorf("read response body failed, %w", err)
}
resp.Body = ioutil.NopCloser(bytes.NewReader(body))

// Anything that's not 200 |< 300 is error
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
resp.Body.Close()
return out, metadata, &smithyhttp.ResponseError{
Response: resp,
Err: fmt.Errorf("request to EC2 IMDS failed"),
Expand Down Expand Up @@ -213,8 +223,19 @@ const (
defaultOperationTimeout = 5 * time.Second
)

// operationTimeout adds a timeout on the middleware stack if the Context the
// stack was called with does not have a deadline. The next middleware must
// complete before the timeout, or the context will be canceled.
//
// If DefaultTimeout is zero, no default timeout will be used if the Context
// does not have a timeout.
//
// The next middleware must also ensure that any resources that are also
// canceled by the stack's context are completely consumed before returning.
// Otherwise the timeout cleanup will race the resource being consumed
// upstream.
type operationTimeout struct {
Timeout time.Duration
DefaultTimeout time.Duration
}

func (*operationTimeout) ID() string { return "OperationTimeout" }
Expand All @@ -224,10 +245,11 @@ func (m *operationTimeout) HandleInitialize(
) (
output middleware.InitializeOutput, metadata middleware.Metadata, err error,
) {
var cancelFn func()

ctx, cancelFn = context.WithTimeout(ctx, m.Timeout)
defer cancelFn()
if _, ok := ctx.Deadline(); !ok && m.DefaultTimeout != 0 {
var cancelFn func()
ctx, cancelFn = context.WithTimeout(ctx, m.DefaultTimeout)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe this could use the constant defaultOperationTimeout directly and remove the timeout var entirely from operationTimeout struct? Seems m.DefaultTimeout will always be defaultOperationTimeout in the end.

Not sure though of the impact/necessity on tests

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having the DefaultTimeout member on operationTimeout middleware is handy if we want to change the timeout depending on the operation. I'm not aware of a need for this at the moment. Since operationTimeout is not exported, having the member is preferred so we can reuse the operationTimeout middleware later.

There is a test that uses a custom timeout, but I think this test needs to be tweaked a little bit to more closely align with the intention behind it. e.g. if there was a context with timeout use it, if there wasn't use a default. This test is relying on the behavior of a canceled context, and the configured nanosecond timeout input.

// TestOperationTimeoutMiddleware exercises operationTimeout with a one
// nanosecond DefaultTimeout and a Context that carries no deadline, and
// expects HandleInitialize to return a "deadline exceeded" error.
func TestOperationTimeoutMiddleware(t *testing.T) {
m := &operationTimeout{
DefaultTimeout: time.Nanosecond,
}
_, _, err := m.HandleInitialize(context.Background(), middleware.InitializeInput{},
middleware.InitializeHandlerFunc(func(
ctx context.Context, input middleware.InitializeInput,
) (
out middleware.InitializeOutput, metadata middleware.Metadata, err error,
) {
// NOTE(review): presumably sdk.SleepWithContext returns early with the
// Context's error once ctx is done; confirm against the sdk package.
if err := sdk.SleepWithContext(ctx, time.Second); err != nil {
return out, metadata, err
}
return out, metadata, nil
}))
// The handler is asked to sleep for a full second, far longer than the
// configured timeout, so a nil error here is treated as a failure.
if err == nil {
t.Fatalf("expect error got none")
}
// Match on a substring so the check does not depend on how the error is wrapped.
if e, a := "deadline exceeded", err.Error(); !strings.Contains(a, e) {
t.Errorf("expect %q error in %q", e, a)
}
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated operationTimeout to not set a default timeout, if DefaultTimeout is zero. I think this was the missing behavior in keeping DefaultTimeout member in operationTimeout.

defer cancelFn()
}

return next.HandleInitialize(ctx, input)
}
Expand Down
145 changes: 144 additions & 1 deletion feature/ec2/imds/request_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/hex"
"fmt"
"io"
"io/ioutil"
"net/http"
Expand Down Expand Up @@ -126,7 +127,7 @@ func TestAddRequestMiddleware(t *testing.T) {

func TestOperationTimeoutMiddleware(t *testing.T) {
m := &operationTimeout{
Timeout: time.Nanosecond,
DefaultTimeout: time.Nanosecond,
}

_, _, err := m.HandleInitialize(context.Background(), middleware.InitializeInput{},
Expand All @@ -135,6 +136,10 @@ func TestOperationTimeoutMiddleware(t *testing.T) {
) (
out middleware.InitializeOutput, metadata middleware.Metadata, err error,
) {
if _, ok := ctx.Deadline(); !ok {
return out, metadata, fmt.Errorf("expect context deadline to be set")
}

if err := sdk.SleepWithContext(ctx, time.Second); err != nil {
return out, metadata, err
}
Expand All @@ -150,6 +155,144 @@ func TestOperationTimeoutMiddleware(t *testing.T) {
}
}

func TestOperationTimeoutMiddleware_noDefaultTimeout(t *testing.T) {
m := &operationTimeout{}

_, _, err := m.HandleInitialize(context.Background(), middleware.InitializeInput{},
middleware.InitializeHandlerFunc(func(
ctx context.Context, input middleware.InitializeInput,
) (
out middleware.InitializeOutput, metadata middleware.Metadata, err error,
) {
if t, ok := ctx.Deadline(); ok {
return out, metadata, fmt.Errorf("expect no context deadline, got %v", t)
}

return out, metadata, nil
}))
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
}

func TestOperationTimeoutMiddleware_withCustomDeadline(t *testing.T) {
m := &operationTimeout{
DefaultTimeout: time.Nanosecond,
}

expectDeadline := time.Now().Add(time.Hour)
ctx, cancelFn := context.WithDeadline(context.Background(), expectDeadline)
defer cancelFn()

_, _, err := m.HandleInitialize(ctx, middleware.InitializeInput{},
middleware.InitializeHandlerFunc(func(
ctx context.Context, input middleware.InitializeInput,
) (
out middleware.InitializeOutput, metadata middleware.Metadata, err error,
) {
t, ok := ctx.Deadline()
if !ok {
return out, metadata, fmt.Errorf("expect context deadline to be set")
}
if e, a := expectDeadline, t; !e.Equal(a) {
return out, metadata, fmt.Errorf("expect %v deadline, got %v", e, a)
}

return out, metadata, nil
}))
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
}

// TestDeserailizeResponse_cacheBody ensures that the response body is read in
// the deserialize middleware, so that the operationTimeout middleware won't
// race canceling the context with the upstream reading the response body.
//
// The test checks that the original body has been closed by the time
// HandleDeserialize returns, and that the body exposed through the result can
// still be read in full afterwards.
//   - https://github.com/aws/aws-sdk-go-v2/issues/1253
func TestDeserailizeResponse_cacheBody(t *testing.T) {
	type Output struct {
		Content io.ReadCloser
	}
	m := &deserializeResponse{
		GetOutput: func(resp *smithyhttp.Response) (interface{}, error) {
			return &Output{
				Content: resp.Body,
			}, nil
		},
	}

	expectBody := "hello world!"
	originalBody := &bytesReader{
		reader: strings.NewReader(expectBody),
	}
	if originalBody.closed {
		t.Fatalf("expect original body not to be closed yet")
	}

	out, _, err := m.HandleDeserialize(context.Background(), middleware.DeserializeInput{},
		middleware.DeserializeHandlerFunc(func(
			ctx context.Context, input middleware.DeserializeInput,
		) (
			out middleware.DeserializeOutput, metadata middleware.Metadata, err error,
		) {
			out.RawResponse = &smithyhttp.Response{
				Response: &http.Response{
					StatusCode:    200,
					Status:        "200 OK",
					Header:        http.Header{},
					ContentLength: int64(originalBody.Len()),
					Body:          originalBody,
				},
			}
			return out, metadata, nil
		}))
	if err != nil {
		t.Fatalf("expect no error, got %v", err)
	}

	if !originalBody.closed {
		t.Errorf("expect original body to be closed, was not")
	}

	result, ok := out.Result.(*Output)
	if !ok {
		// Report the value that was actually returned. After a failed
		// assertion result is only a nil *Output, which would hide the real
		// type and value of out.Result.
		t.Fatalf("expect result to be Output, got %T, %v", out.Result, out.Result)
	}

	// Reading after the original body was closed only succeeds if the
	// result carries its own copy of the content.
	actualBody, err := ioutil.ReadAll(result.Content)
	if err != nil {
		t.Fatalf("expect no error, got %v", err)
	}
	if e, a := expectBody, string(actualBody); e != a {
		t.Errorf("expect %v body, got %v", e, a)
	}
	if err := result.Content.Close(); err != nil {
		t.Fatalf("expect no error, got %v", err)
	}
}

// bytesReader is a test io.ReadCloser that wraps a reader which also exposes
// Len, and records whether Close has been called.
type bytesReader struct {
	// reader is the wrapped source; it must provide both Read and Len.
	reader interface {
		io.Reader
		Len() int
	}
	// closed is set to true by Close.
	closed bool
}

// Read delegates to the wrapped reader until Close has been called; after
// that it returns 0, io.EOF without touching the wrapped reader.
func (b *bytesReader) Read(dst []byte) (int, error) {
	if !b.closed {
		return b.reader.Read(dst)
	}
	return 0, io.EOF
}

// Close marks the reader as closed. It always returns nil.
func (b *bytesReader) Close() error {
	b.closed = true
	return nil
}

// Len returns the value reported by the wrapped reader's Len.
func (b *bytesReader) Len() int {
	return b.reader.Len()
}

type successAPIResponseHandler struct {
t *testing.T
path string
Expand Down