Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add flag to disable IMDSv1 fallback #4748

Merged
merged 10 commits into from
Mar 13, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG_PENDING.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@

### SDK Enhancements

* `aws/ec2metadata`: Added an option to disable fallback to IMDSv1.
* When set the SDK will no longer fallback to IMDSv1 when fetching a token fails. Use `aws.WithEC2MetadataDisableFallback` to enable.

### SDK Bugs
64 changes: 46 additions & 18 deletions aws/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ type RequestRetryer interface{}
// A Config provides service configuration for service clients. By default,
// all clients will use the defaults.DefaultConfig structure.
//
// // Create Session with MaxRetries configuration to be shared by multiple
// // service clients.
// sess := session.Must(session.NewSession(&aws.Config{
// MaxRetries: aws.Int(3),
// }))
// // Create Session with MaxRetries configuration to be shared by multiple
// // service clients.
// sess := session.Must(session.NewSession(&aws.Config{
// MaxRetries: aws.Int(3),
// }))
//
// // Create S3 service client with a specific Region.
// svc := s3.New(sess, &aws.Config{
// Region: aws.String("us-west-2"),
// })
// // Create S3 service client with a specific Region.
// svc := s3.New(sess, &aws.Config{
// Region: aws.String("us-west-2"),
// })
type Config struct {
// Enables verbose error printing of all credential chain errors.
// Should be used when wanting to see all errors while attempting to
Expand Down Expand Up @@ -192,6 +192,23 @@ type Config struct {
//
EC2MetadataDisableTimeoutOverride *bool

// Set this to `false` to disable EC2Metadata client from falling back to IMDSv1.
// By default, EC2 role credentials will fall back to IMDSv1 as needed for backwards compatibility.
// You can disable this behavior by explicitly setting this flag to `false`. When false, the EC2Metadata
// client will return any errors encountered from attempting to fetch a token instead of silently
// using the insecure data flow of IMDSv1.
//
// Example:
// sess := session.Must(session.NewSession(aws.NewConfig()
// .WithEC2MetadataEnableFallback(false)))
//
// svc := s3.New(sess)
//
// See [configuring IMDS] for more information.
//
// [configuring IMDS]: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html
EC2MetadataEnableFallback *bool

// Instructs the endpoint to be generated for a service client to
// be the dual stack endpoint. The dual stack endpoint will support
// both IPv4 and IPv6 addressing.
Expand Down Expand Up @@ -283,16 +300,16 @@ type Config struct {
// NewConfig returns a new Config pointer that can be chained with builder
// methods to set multiple configuration values inline without using pointers.
//
// // Create Session with MaxRetries configuration to be shared by multiple
// // service clients.
// sess := session.Must(session.NewSession(aws.NewConfig().
// WithMaxRetries(3),
// ))
// // Create Session with MaxRetries configuration to be shared by multiple
// // service clients.
// sess := session.Must(session.NewSession(aws.NewConfig().
// WithMaxRetries(3),
// ))
//
// // Create S3 service client with a specific Region.
// svc := s3.New(sess, aws.NewConfig().
// WithRegion("us-west-2"),
// )
// // Create S3 service client with a specific Region.
// svc := s3.New(sess, aws.NewConfig().
// WithRegion("us-west-2"),
// )
func NewConfig() *Config {
return &Config{}
}
Expand Down Expand Up @@ -432,6 +449,13 @@ func (c *Config) WithEC2MetadataDisableTimeoutOverride(enable bool) *Config {
return c
}

// WithEC2MetadataEnableFallback sets a config EC2MetadataEnableFallback value
// returning a Config pointer for chaining.
func (c *Config) WithEC2MetadataEnableFallback(v bool) *Config {
c.EC2MetadataEnableFallback = &v
return c
}

// WithSleepDelay overrides the function used to sleep while waiting for the
// next retry. Defaults to time.Sleep.
func (c *Config) WithSleepDelay(fn func(time.Duration)) *Config {
Expand Down Expand Up @@ -576,6 +600,10 @@ func mergeInConfig(dst *Config, other *Config) {
dst.EC2MetadataDisableTimeoutOverride = other.EC2MetadataDisableTimeoutOverride
}

if other.EC2MetadataEnableFallback != nil {
dst.EC2MetadataEnableFallback = other.EC2MetadataEnableFallback
}

if other.SleepDelay != nil {
dst.SleepDelay = other.SleepDelay
}
Expand Down
42 changes: 34 additions & 8 deletions aws/ec2metadata/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ const (
NotFoundRequestTestType
InvalidTokenRequestTestType
ServerErrorForTokenTestType
pageNotFoundForTokenTestType
pageNotFoundWith401TestType
PageNotFoundForTokenTestType
PageNotFoundWith401TestType
ThrottleErrorForTokenNoFallbackTestType
)

type testServer struct {
Expand Down Expand Up @@ -126,12 +127,15 @@ func newTestServer(t *testing.T, testType testType, testServer *testServer) *htt
case ServerErrorForTokenTestType:
mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.serverErrorGetTokenHandler))
mux.HandleFunc("/", testServer.insecureGetLatestHandler)
case pageNotFoundForTokenTestType:
case PageNotFoundForTokenTestType:
mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.pageNotFoundGetTokenHandler))
mux.HandleFunc("/", testServer.insecureGetLatestHandler)
case pageNotFoundWith401TestType:
case PageNotFoundWith401TestType:
mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.pageNotFoundGetTokenHandler))
mux.HandleFunc("/", testServer.unauthorizedGetLatestHandler)
case ThrottleErrorForTokenNoFallbackTestType:
mux.HandleFunc("/latest/api/token", getTokenRequiredParams(t, testServer.throtleErrorGetTokenHandler))
mux.HandleFunc("/", testServer.unauthorizedGetLatestHandler)

}

Expand Down Expand Up @@ -213,6 +217,10 @@ func (s *testServer) unauthorizedGetLatestHandler(w http.ResponseWriter, r *http
http.Error(w, "", 401)
}

func (s *testServer) throtleErrorGetTokenHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, "", 429)
}

func (opListProvider *operationListProvider) addToOperationPerformedList(r *request.Request) {
opListProvider.operationsPerformed = append(opListProvider.operationsPerformed, r.Operation.Name)
}
Expand Down Expand Up @@ -241,6 +249,7 @@ func TestGetMetadata(t *testing.T) {
expectedData string
expectedError string
expectedOperationsAttempted []string
enableImdsFallback *bool
}{
"Insecure server success case": {
NewServer: func(t *testing.T, tokens []string) *httptest.Server {
Expand Down Expand Up @@ -325,6 +334,21 @@ func TestGetMetadata(t *testing.T) {
expectedData: "IMDSProfileForGoSDK",
expectedOperationsAttempted: []string{"GetToken", "GetMetadata", "GetMetadata"},
},
"No fallback to IMDSv1": {
NewServer: func(t *testing.T, tokens []string) *httptest.Server {
testType := ThrottleErrorForTokenNoFallbackTestType
Ts := &testServer{
t: t,
tokens: []string{},
data: "IMDSProfileForGoSDK",
}
return newTestServer(t, testType, Ts)
},
expectedError: "failed to get IMDS token and fallback is disabled",
// 2 attempts + 2 retries per/attempt
expectedOperationsAttempted: []string{"GetToken", "GetToken", "GetToken", "GetToken", "GetToken", "GetToken"},
enableImdsFallback: aws.Bool(false),
},
}

for name, x := range cases {
Expand All @@ -336,8 +360,10 @@ func TestGetMetadata(t *testing.T) {
op := &operationListProvider{}

c := ec2metadata.New(unit.Session, &aws.Config{
Endpoint: aws.String(server.URL),
Endpoint: aws.String(server.URL),
EC2MetadataEnableFallback: x.enableImdsFallback,
})

c.Handlers.CompleteAttempt.PushBack(op.addToOperationPerformedList)

tokenCounter := -1
Expand Down Expand Up @@ -953,7 +979,7 @@ func TestExhaustiveRetryToFetchToken(t *testing.T) {
data: "IMDSProfileForSDKGo",
}

server := newTestServer(t, pageNotFoundForTokenTestType, ts)
server := newTestServer(t, PageNotFoundForTokenTestType, ts)
defer server.Close()

op := &operationListProvider{}
Expand Down Expand Up @@ -1007,7 +1033,7 @@ func TestExhaustiveRetryWith401(t *testing.T) {
data: "IMDSProfileForSDKGo",
}

server := newTestServer(t, pageNotFoundWith401TestType, ts)
server := newTestServer(t, PageNotFoundWith401TestType, ts)
defer server.Close()

op := &operationListProvider{}
Expand Down Expand Up @@ -1117,7 +1143,7 @@ func TestRequestTimeOut(t *testing.T) {
t.Fatalf("Expected no error, got %v", err)
}

expectedOperationsPerformed = []string{"GetToken", "GetMetadata", "GetMetadata"}
expectedOperationsPerformed = []string{"GetToken", "GetMetadata", "GetToken", "GetMetadata"}
if e, a := expectedOperationsPerformed, op.operationsPerformed; !reflect.DeepEqual(e, a) {
t.Fatalf("expect %v operations, got %v", e, a)
}
Expand Down
10 changes: 5 additions & 5 deletions aws/ec2metadata/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ type EC2Metadata struct {
// New creates a new instance of the EC2Metadata client with a session.
// This client is safe to use across multiple goroutines.
//
//
// Example:
// // Create a EC2Metadata client from just a session.
// svc := ec2metadata.New(mySession)
//
// // Create a EC2Metadata client with additional configuration
// svc := ec2metadata.New(mySession, aws.NewConfig().WithLogLevel(aws.LogDebugHTTPBody))
// // Create a EC2Metadata client from just a session.
// svc := ec2metadata.New(mySession)
//
// // Create a EC2Metadata client with additional configuration
// svc := ec2metadata.New(mySession, aws.NewConfig().WithLogLevel(aws.LogDebugHTTPBody))
func New(p client.ConfigProvider, cfgs ...*aws.Config) *EC2Metadata {
c := p.ClientConfig(ServiceName, cfgs...)
return NewClient(*c.Config, c.Handlers, c.Endpoint, c.SigningRegion)
Expand Down
25 changes: 14 additions & 11 deletions aws/ec2metadata/token_provider.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ec2metadata

import (
"fmt"
"net/http"
"sync/atomic"
"time"
Expand Down Expand Up @@ -33,11 +34,15 @@ func newTokenProvider(c *EC2Metadata, duration time.Duration) *tokenProvider {
return &tokenProvider{client: c, configuredTTL: duration}
}

// check if fallback is enabled
func (t *tokenProvider) fallbackEnabled() bool {
return t.client.Config.EC2MetadataEnableFallback == nil || *t.client.Config.EC2MetadataEnableFallback
}

// fetchTokenHandler fetches token for EC2Metadata service client by default.
func (t *tokenProvider) fetchTokenHandler(r *request.Request) {

// short-circuits to insecure data flow if tokenProvider is disabled.
if v := atomic.LoadUint32(&t.disabled); v == 1 {
if v := atomic.LoadUint32(&t.disabled); v == 1 && t.fallbackEnabled() {
return
}

Expand All @@ -49,23 +54,21 @@ func (t *tokenProvider) fetchTokenHandler(r *request.Request) {
output, err := t.client.getToken(r.Context(), t.configuredTTL)

if err != nil {
// only attempt fallback to insecure data flow if IMDSv1 is enabled
if !t.fallbackEnabled() {
r.Error = awserr.New("EC2MetadataError", "failed to get IMDS token and fallback is disabled", err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: i think this error message could have specified what version of IMDS failed and what version its falling back to. so like:

failed to get IMDSv2 token and fallback to IMDSv1 is disabled

return
}

// change the disabled flag on token provider to true,
// when error is request timeout error.
// change the disabled flag on token provider to true and fallback
if requestFailureError, ok := err.(awserr.RequestFailure); ok {
switch requestFailureError.StatusCode() {
case http.StatusForbidden, http.StatusNotFound, http.StatusMethodNotAllowed:
atomic.StoreUint32(&t.disabled, 1)
t.client.Config.Logger.Log(fmt.Sprintf("WARN: failed to get session token, falling back to IMDSv1: %v", requestFailureError))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

case http.StatusBadRequest:
r.Error = requestFailureError
}

// Check if request timed out while waiting for response
if e, ok := requestFailureError.OrigErr().(awserr.Error); ok {
if e.Code() == request.ErrCodeRequestError {
Copy link
Contributor

@lucix-aws lucix-aws Mar 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: did we ever get a handle on what was triggering this code path? From when I looked at it, the comment didn't appear accurate, it seemed like anything in the handler stack could have made it here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is my concern and maybe a reason that fallback is happening more often than it should. This seems as though it can trigger too easily.

Copy link
Contributor

@isaiahvita isaiahvita Mar 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed. the original Standard states that this should be for "timeouts". im wondering if it should be checking for ErrCodeResponseTimeout rather than ErrCodeRequestError

EDIT: this was meant to be a reply to @lucix-aws and @aajtodd thread but github seemed to format it as a separate thread

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We talked offline about this and decided to just remove this. This was triggering permanent fallback for likely far too many error response codes. In the event of a timeout error it will still fallback for the current request but will retry imdsv2 for subsequent requests.

atomic.StoreUint32(&t.disabled, 1)
}
}
}
return
}
Expand Down