From 63f8545a4c58dbb0ac0bf29df54e782066410206 Mon Sep 17 00:00:00 2001 From: Yuxuan 'fishy' Wang Date: Tue, 11 Jun 2024 10:38:16 -0700 Subject: [PATCH] kafkabp: Support AWS IMDS v2 for rack id Also modernize the code a bit (use sync.OnceValues, etc.). --- kafkabp/rack.go | 138 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 94 insertions(+), 44 deletions(-) diff --git a/kafkabp/rack.go b/kafkabp/rack.go index e74550827..fdeee5661 100644 --- a/kafkabp/rack.go +++ b/kafkabp/rack.go @@ -5,6 +5,7 @@ import ( "encoding" "fmt" "io" + "log/slog" "net/http" "strings" "sync" @@ -88,7 +89,7 @@ func FixedRackID(id string) RackIDFunc { // Default values for SimpleHTTPRackIDConfig. const ( SimpleHTTPRackIDDefaultLimit = 1024 - SimpleHTTPRackIDDefaultTimeout = time.Second + SimpleHTTPRackIDDefaultTimeout = 1 * time.Second ) // SimpleHTTPRackIDConfig defines the config to be used in SimpleHTTPRackID. @@ -126,12 +127,12 @@ func SimpleHTTPRackID(cfg SimpleHTTPRackIDConfig) RackIDFunc { } return func() string { - client := http.Client{ - Timeout: cfg.Timeout, - } - resp, err := client.Get(cfg.URL) + ctx, cancel := context.WithTimeout(context.Background(), cfg.Timeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, cfg.URL, http.MethodGet, http.NoBody) if err != nil { - cfg.Logger.Log(context.Background(), fmt.Sprintf( + cfg.Logger.Log(ctx, fmt.Sprintf( "Failed to get rack id from %s: %v", cfg.URL, err, @@ -139,42 +140,91 @@ func SimpleHTTPRackID(cfg SimpleHTTPRackIDConfig) RackIDFunc { return "" } - defer func() { - io.Copy(io.Discard, resp.Body) - resp.Body.Close() - }() - content, err := io.ReadAll(io.LimitReader(resp.Body, cfg.Limit)) + content, err := doHTTP(req, cfg.Limit) if err != nil { cfg.Logger.Log(context.Background(), fmt.Sprintf( - "Failed to read rack id response from %s: %v", + "Failed to get rack id from %s: %v", cfg.URL, err, )) return "" } - if resp.StatusCode >= 400 { - cfg.Logger.Log(context.Background(), fmt.Sprintf( - "Rack id URL %s returned status code %d: %s", - cfg.URL, - resp.StatusCode, - content, - )) - return "" - } - return strings.TrimSpace(string(content)) + return content } } -// Global cache for AWSAvailabilityZoneRackID. -var ( - awsCachedRackID string - awsRackIDOnce sync.Once -) +var client http.Client + +// doHTTP executes http request, reads the body up to the limit given, and +// return the body read as string with whitespace trimmed. +func doHTTP(r *http.Request, readLimit int64) (string, error) { + resp, err := client.Do(r) + if err != nil { + return "", fmt.Errorf("kafkabp.doHTTP: request failed: %w", err) + } + + defer func() { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + }() + + content, err := io.ReadAll(io.LimitReader(resp.Body, readLimit)) + if err != nil { + return "", fmt.Errorf("kafkabp.doHTTP: failed to read response body: %w", err) + } + + body := strings.TrimSpace(string(content)) + if resp.StatusCode >= 400 { + return "", fmt.Errorf("kafkabp.doHTTP: got http response with code %d and body %q", resp.StatusCode, body) + } + + return body, nil +} -// References: -// https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html -// https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-categories.html -const awsAZurl = "http://169.254.169.254/latest/meta-data/placement/availability-zone" +var awsRackID = sync.OnceValues(func() (string, error) { + const ( + // References: + // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html + // https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-categories.html + tokenURL = "http://169.254.169.254/latest/api/token" + azURL = "http://169.254.169.254/latest/meta-data/placement/availability-zone" + + timeout = time.Second + readLimit = 1024 + ) + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + token, err := func(ctx context.Context) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPut, tokenURL, http.NoBody) + if err != nil { + return "", fmt.Errorf("kafkabp.awsRackID: failed to create request from url %q: %w", tokenURL, err) + } + req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "21600") + + token, err := doHTTP(req, readLimit) + if err != nil { + return "", fmt.Errorf("kafkabp.awsRackID: failed to get AWS IMDS v2 token from url %q: %w", tokenURL, err) + } + return token, nil + }(ctx) + if err != nil { + return "", err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, azURL, http.NoBody) + if err != nil { + return "", fmt.Errorf("kafkabp.awsRackID: failed to create request from url %q: %w", azURL, err) + } + req.Header.Set("X-aws-ec2-metadata-token", token) + + id, err := doHTTP(req, readLimit) + if err != nil { + err = fmt.Errorf("kafkabp.awsRackID: failed to get AWS availability zone from url %q: %w", azURL, err) + } + return id, err +}) // AWSAvailabilityZoneRackID is a RackIDFunc implementation that returns AWS // availability zone as the rack id. @@ -194,18 +244,18 @@ const awsAZurl = "http://169.254.169.254/latest/meta-data/placement/availability // // other configs // }) // -// It uses SimpleHTTPRackIDConfig underneath with log.DefaultWrapper with a -// prometheus counter of kafkabp_aws_rack_id_failure_total and default -// Limit & Timeout. +// It uses AWS instance metadata HTTP API with 1second overall timeout and 1024 +// HTTP response read limits.. +// +// If there was an error retrieving rack id through AWS instance metadata API, +// the same error will be logged at slog's warning level every time +// AWSAvailabilityZoneRackID is called. func AWSAvailabilityZoneRackID() string { - awsRackIDOnce.Do(func() { - awsCachedRackID = SimpleHTTPRackID(SimpleHTTPRackIDConfig{ - URL: awsAZurl, - Logger: log.CounterWrapper( - nil, // delegate, let it fallback to DefaultWrapper - awsRackFailure, - ), - })() - }) - return awsCachedRackID + id, err := awsRackID() + if err != nil { + awsRackFailure.Inc() + slog.Warn("Failed to get AWS availability zone as rack id", "err", err) + return "" + } + return id }