Skip to content

Commit

Permalink
proxy: pass url via request context (#2058)
Browse files Browse the repository at this point in the history
Signed-off-by: Alexander Yastrebov <[email protected]>

Signed-off-by: Alexander Yastrebov <[email protected]>
  • Loading branch information
AlexanderYastrebov committed Aug 18, 2022
1 parent bc39a90 commit 8426343
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 47 deletions.
33 changes: 17 additions & 16 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ const (
unknownRouteID = "_unknownroute_"
unknownRouteBackendType = "<unknown>"
unknownRouteBackend = "<unknown>"
backendIsProxyHeader = "X-Skipper-Proxy"

// Number of loops allowed by default.
DefaultMaxLoopbacks = 9
Expand Down Expand Up @@ -514,19 +513,21 @@ func mapRequest(ctx *context, requestContext stdlibcontext.Context, removeHopHea
rr.Header.Add("Authorization", fmt.Sprintf("Basic %s", upBase64))
}

if _, ok := stateBag[filters.BackendIsProxyKey]; ok {
forwardToProxy(r, rr)
}

ctxspan := ot.SpanFromContext(r.Context())
if ctxspan != nil {
rr = rr.WithContext(ot.ContextWithSpan(rr.Context(), ctxspan))
}

if _, ok := stateBag[filters.BackendIsProxyKey]; ok {
rr = forwardToProxy(r, rr)
}

return rr, endpoint, nil
}

func forwardToProxy(incoming, outgoing *http.Request) {
type proxyUrlContextKey struct{}

func forwardToProxy(incoming, outgoing *http.Request) *http.Request {
proxyURL := &url.URL{
Scheme: outgoing.URL.Scheme,
Host: outgoing.URL.Host,
Expand All @@ -535,7 +536,15 @@ func forwardToProxy(incoming, outgoing *http.Request) {
outgoing.URL.Host = incoming.Host
outgoing.URL.Scheme = schemeFromRequest(incoming)

outgoing.Header.Set(backendIsProxyHeader, proxyURL.String())
return outgoing.WithContext(stdlibcontext.WithValue(outgoing.Context(), proxyUrlContextKey{}, proxyURL))
}

func proxyFromContext(req *http.Request) (*url.URL, error) {
proxyURL, _ := req.Context().Value(proxyUrlContextKey{}).(*url.URL)
if proxyURL != nil {
return proxyURL, nil
}
return nil, nil
}

type skipperDialer struct {
Expand Down Expand Up @@ -628,7 +637,7 @@ func WithParams(p Params) *Proxy {
MaxIdleConnsPerHost: p.IdleConnectionsPerHost,
IdleConnTimeout: p.CloseIdleConnsPeriod,
DisableKeepAlives: p.DisableHTTPKeepalives,
Proxy: proxyFromHeader,
Proxy: proxyFromContext,
}

quit := make(chan struct{})
Expand Down Expand Up @@ -728,14 +737,6 @@ func tryCatch(p func(), onErr func(err interface{}, stack string)) {
p()
}

func proxyFromHeader(req *http.Request) (*url.URL, error) {
if u := req.Header.Get(backendIsProxyHeader); u != "" {
req.Header.Del(backendIsProxyHeader)
return url.Parse(u)
}
return nil, nil
}

// applies filters to a request
func (p *Proxy) applyFiltersToRequest(f []*routing.RouteFilter, ctx *context) []*routing.RouteFilter {
if len(f) == 0 {
Expand Down
40 changes: 9 additions & 31 deletions proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2213,44 +2213,22 @@ func TestForwardToProxy(t *testing.T) {
TLS: ti.tls,
}

forwardToProxy(incoming, outgoing)
outgoing = forwardToProxy(incoming, outgoing)

if outgoing.URL.String() != ti.expectedRequestURL {
t.Errorf("request URLs are not equal, expected %s got %s",
ti.expectedRequestURL, outgoing.URL.String())
}
assert.Equal(t, ti.expectedRequestURL, outgoing.URL.String())

proxyURL := outgoing.Header.Get(backendIsProxyHeader)
proxyURL, err := proxyFromContext(outgoing)

if proxyURL != ti.expectedProxyURL {
t.Errorf("proxy URLs are not equal, expected %s got %s",
ti.expectedProxyURL, proxyURL)
}
assert.NoError(t, err)
assert.Equal(t, ti.expectedProxyURL, proxyURL.String())
}
}

func TestProxyFromHeader(t *testing.T) {
u1, err := proxyFromHeader(&http.Request{})
if err != nil {
t.Error(err)
}
if u1 != nil {
t.Errorf("expected nil but got %v", u1)
}

expectedProxyURL := "http://proxy.example.com"
func TestProxyFromEmptyContext(t *testing.T) {
proxyUrl, err := proxyFromContext(&http.Request{})

u2, err := proxyFromHeader(&http.Request{
Header: http.Header{
backendIsProxyHeader: []string{expectedProxyURL},
},
})
if err != nil {
t.Error(err)
}
if u2.String() != expectedProxyURL {
t.Errorf("expected '%s' but got '%v'", expectedProxyURL, u2)
}
assert.NoError(t, err)
assert.Nil(t, proxyUrl)
}

func BenchmarkAccessLogNoFilter(b *testing.B) { benchmarkAccessLog(b, "", 200) }
Expand Down

0 comments on commit 8426343

Please sign in to comment.