Skip to content

Commit

Permalink
feat(header forwarding): add HeaderForwarder and plumbing (#507)
Browse files Browse the repository at this point in the history
* add HeaderForwarder and plumbing

* PR feedback

* make add-license

* make lint

* update templates

* codegen from templates

* dedupe template function

* fix linting

* fix tests
  • Loading branch information
potterbm-cb authored Oct 18, 2024
1 parent 348e01b commit bad12fa
Show file tree
Hide file tree
Showing 21 changed files with 479 additions and 61 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ GO_MOD_PACKAGES=./types/...
GO_FOLDERS=$(shell echo ${GO_PACKAGES} | sed -e "s/\.\///g" | sed -e "s/\/\.\.\.//g")
GO_MOD_FOLDERS=$(shell echo ${GO_MOD_PACKAGES} | sed -e "s/\.\///g" | sed -e "s/\/\.\.\.//g")
TEST_SCRIPT=go test ${GO_PACKAGES}
LINT_SETTINGS=golint,misspell,gocyclo,gocritic,whitespace,goconst,gocognit,bodyclose,unconvert,lll,unparam
LINT_SETTINGS=misspell,gocyclo,gocritic,whitespace,goconst,gocognit,bodyclose,unconvert,lll,unparam

build:
go build ./...

deps:
go get ./...

Expand Down
2 changes: 1 addition & 1 deletion asserter/asserter.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ func NewGenericRosettaClient(
ignoreRosettaSpecValidation: true,
}

//init default operation statuses for generic rosetta client
// init default operation statuses for generic rosetta client
InitOperationStatus(asserter)

return asserter, nil
Expand Down
2 changes: 1 addition & 1 deletion constructor/worker/worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1848,7 +1848,7 @@ func TestHTTPRequestWorker(t *testing.T) {

w.Header().Set("Content-Type", test.contentType)
w.WriteHeader(test.statusCode)
fmt.Fprintf(w, test.response)
fmt.Fprint(w, test.response)
}))

defer ts.Close()
Expand Down
2 changes: 2 additions & 0 deletions examples/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@ func NewBlockchainRouter(
networkAPIController := server.NewNetworkAPIController(
networkAPIService,
asserter,
nil,
)

blockAPIService := services.NewBlockAPIService(network)
blockAPIController := server.NewBlockAPIController(
blockAPIService,
asserter,
nil,
)

return server.NewRouter(networkAPIController, blockAPIController)
Expand Down
2 changes: 1 addition & 1 deletion fetcher/block.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ func (f *Fetcher) UnsafeBlock(
}

// Exit early if no need to fetch txs
if blockResponse.OtherTransactions == nil || len(blockResponse.OtherTransactions) == 0 {
if len(blockResponse.OtherTransactions) == 0 {
return blockResponse.Block, nil
}

Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ require (
github.com/dgraph-io/badger/v2 v2.2007.4
github.com/ethereum/go-ethereum v1.10.21
github.com/fatih/color v1.13.0
github.com/google/uuid v1.6.0
github.com/gorilla/mux v1.8.0
github.com/lucasjones/reggen v0.0.0-20180717132126-cdb49ff09d77
github.com/neilotoole/errgroup v0.1.6
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEW
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
Expand Down
43 changes: 43 additions & 0 deletions headerforwarder/context_headers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright 2024 Coinbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package headerforwarder

import (
"context"
"net/http"

"github.com/google/uuid"
)

type contextKey string

const requestIDKey = contextKey("request_id")

func ContextWithRosettaID(ctx context.Context) context.Context {
return context.WithValue(ctx, requestIDKey, uuid.NewString())
}

func RosettaIDFromContext(ctx context.Context) string {
return ctx.Value(requestIDKey).(string)
}

func RosettaIDFromRequest(r *http.Request) string {
switch value := r.Context().Value(requestIDKey).(type) {
case string:
return value
default:
return ""
}
}
142 changes: 142 additions & 0 deletions headerforwarder/forwarder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// Copyright 2024 Coinbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package headerforwarder

import (
"net/http"
)

// HeaderExtractingTransport is a utility to help a rosetta server forward headers to and from
// native node requests. It implements several interfaces to achieve that:
// - http.RoundTripper: this can be used to create an http Client that will automatically save headers
// if necessary
// - func(http.Handler) http.Handler: this can be used to wrap an http.Handler to set headers
// on the response
//
// the headers can be requested later.
//
// TODO: this should expire entries after a certain amount of time
type HeaderForwarder struct {
requestHeaders map[string]http.Header
interestingHeaders []string
actualTransport http.RoundTripper
}

func NewHeaderForwarder(interestingHeaders []string, transport http.RoundTripper) *HeaderForwarder {
return &HeaderForwarder{
requestHeaders: make(map[string]http.Header),
interestingHeaders: interestingHeaders,
actualTransport: transport,
}
}

// RoundTrip implements http.RoundTripper and will be used to construct an http Client which
// saves the native node response headers if necessary.
func (hf *HeaderForwarder) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := hf.actualTransport.RoundTrip(req)

if err == nil && hf.shouldRememberHeaders(req, resp) {
hf.rememberHeaders(req, resp)
}

return resp, err
}

// shouldRememberHeaders is called to determine if response headers should be remembered for a
// given request. Response headers will only be remembered if the request does not contain all of
// the interesting headers and the response contains at least one of the interesting headers.
//
// It should be noted that the request and response here are for a request to the native node,
// not a request to the Rosetta server.
func (hf *HeaderForwarder) shouldRememberHeaders(req *http.Request, resp *http.Response) bool {
requestHasAllHeaders := true
responseHasSomeHeaders := false

for _, interestingHeader := range hf.interestingHeaders {
_, requestHasHeader := req.Header[http.CanonicalHeaderKey(interestingHeader)]
_, responseHasHeader := resp.Header[http.CanonicalHeaderKey(interestingHeader)]

if !requestHasHeader {
requestHasAllHeaders = false
}

if responseHasHeader {
responseHasSomeHeaders = true
}
}

// only remember headers if the request does not contain all of the interesting headers and the
// response contains at least one
return !requestHasAllHeaders && responseHasSomeHeaders
}

// rememberHeaders is called to save the native node response headers. The request object
// here is a native node request (constructed by go-ethereum for geth-based rosetta implementations).
// The response object is a native node response.
func (hf *HeaderForwarder) rememberHeaders(req *http.Request, resp *http.Response) {
ctx := req.Context()
// rosettaRequestID := services.osettaIdFromContext(ctx)
rosettaRequestID := RosettaIDFromContext(ctx)

// Only remember interesting headers
headersToRemember := make(http.Header)
for _, interestingHeader := range hf.interestingHeaders {
headersToRemember.Set(interestingHeader, resp.Header.Get(interestingHeader))
}

hf.requestHeaders[rosettaRequestID] = headersToRemember
}

// GetResponseHeaders returns any native node response headers that were recorded for a request ID.
func (hf *HeaderForwarder) getResponseHeaders(rosettaRequestID string) (http.Header, bool) {
headers, ok := hf.requestHeaders[rosettaRequestID]

// Delete the headers from the map after they are retrieved
// This is safe to call even if the key doesn't exist
delete(hf.requestHeaders, rosettaRequestID)

return headers, ok
}

// HeaderForwarderHandler will allow the next handler to serve the request, and then checks
// if there are any native node response headers recorded for the request. If there are, it will set
// those headers on the response
func (hf *HeaderForwarder) HeaderForwarderHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// add a unique ID to the request context, and make a new request for it
requestWithID := hf.WithRequestID(r)

// Serve the request
// NOTE: ResponseWriter::WriteHeader() WILL be called here, so we can't set headers after this happens
// We include a wrapper around the response writer that allows us to set headers just before
// WriteHeader is called
wrappedResponseWriter := NewResponseWriter(
w,
RosettaIDFromRequest(requestWithID),
hf.getResponseHeaders,
)
next.ServeHTTP(wrappedResponseWriter, requestWithID)
})
}

// WithRequestID adds a unique ID to the request context. A new request is returned that contains the
// new context
func (hf *HeaderForwarder) WithRequestID(req *http.Request) *http.Request {
ctx := req.Context()
ctxWithID := ContextWithRosettaID(ctx)
requestWithID := req.WithContext(ctxWithID)

return requestWithID
}
68 changes: 68 additions & 0 deletions headerforwarder/response_writer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright 2024 Coinbase, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package headerforwarder

import (
"net/http"
)

// ResponseWriter is a wrapper around a http.ResponseWriter that allows us to set headers
// just before the WriteHeader function is called. These headers will be extracted from native node
// responses, and set on the rosetta response.
type ResponseWriter struct {
writer http.ResponseWriter
RosettaRequestID string
GetAdditionalHeaders func(string) (http.Header, bool)
}

func NewResponseWriter(
writer http.ResponseWriter,
rosettaRequestID string,
getAdditionalHeaders func(string) (http.Header, bool),
) *ResponseWriter {
return &ResponseWriter{
writer: writer,
RosettaRequestID: rosettaRequestID,
GetAdditionalHeaders: getAdditionalHeaders,
}
}

// Header passes through to the underlying ResponseWriter instance
func (hfrw *ResponseWriter) Header() http.Header {
return hfrw.writer.Header()
}

// Write passes through to the underlying ResponseWriter instance
func (hfrw *ResponseWriter) Write(b []byte) (int, error) {
return hfrw.writer.Write(b)
}

// WriteHeader will add any final extracted headers, and then pass through to the underlying ResponseWriter instance
func (hfrw *ResponseWriter) WriteHeader(statusCode int) {
hfrw.AddExtractedHeaders()
hfrw.writer.WriteHeader(statusCode)
}

func (hfrw *ResponseWriter) AddExtractedHeaders() {
headers, hasAdditionalHeaders := hfrw.GetAdditionalHeaders(hfrw.RosettaRequestID)

if hasAdditionalHeaders {
for key, values := range headers {
for _, value := range values {
hfrw.writer.Header().Add(key, value)
}
}
}
}
26 changes: 20 additions & 6 deletions server/api_account.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package server

import (
"context"
"encoding/json"
"net/http"
"strings"
Expand All @@ -28,18 +29,21 @@ import (
// A AccountAPIController binds http requests to an api service and writes the service results to
// the http response
type AccountAPIController struct {
service AccountAPIServicer
asserter *asserter.Asserter
service AccountAPIServicer
asserter *asserter.Asserter
contextFromRequest func(*http.Request) context.Context
}

// NewAccountAPIController creates a default api controller
func NewAccountAPIController(
s AccountAPIServicer,
asserter *asserter.Asserter,
contextFromRequest func(*http.Request) context.Context,
) Router {
return &AccountAPIController{
service: s,
asserter: asserter,
service: s,
asserter: asserter,
contextFromRequest: contextFromRequest,
}
}

Expand All @@ -61,6 +65,16 @@ func (c *AccountAPIController) Routes() Routes {
}
}

func (c *AccountAPIController) ContextFromRequest(r *http.Request) context.Context {
ctx := r.Context()

if c.contextFromRequest != nil {
ctx = c.contextFromRequest(r)
}

return ctx
}

// AccountBalance - Get an Account's Balance
func (c *AccountAPIController) AccountBalance(w http.ResponseWriter, r *http.Request) {
accountBalanceRequest := &types.AccountBalanceRequest{}
Expand All @@ -81,7 +95,7 @@ func (c *AccountAPIController) AccountBalance(w http.ResponseWriter, r *http.Req
return
}

result, serviceErr := c.service.AccountBalance(r.Context(), accountBalanceRequest)
result, serviceErr := c.service.AccountBalance(c.ContextFromRequest(r), accountBalanceRequest)
if serviceErr != nil {
EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w)

Expand Down Expand Up @@ -111,7 +125,7 @@ func (c *AccountAPIController) AccountCoins(w http.ResponseWriter, r *http.Reque
return
}

result, serviceErr := c.service.AccountCoins(r.Context(), accountCoinsRequest)
result, serviceErr := c.service.AccountCoins(c.ContextFromRequest(r), accountCoinsRequest)
if serviceErr != nil {
EncodeJSONResponse(serviceErr, http.StatusInternalServerError, w)

Expand Down
Loading

0 comments on commit bad12fa

Please sign in to comment.