Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] httpbp middleware doesn't flush chunked responses #573

Merged
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions httpbp/middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,10 @@ func recordStatusCode(counters counterGenerator) Middleware {
return func(name string, next HandlerFunc) HandlerFunc {
counter := counters.Counter("baseplate.http." + name + ".response")
return func(ctx context.Context, w http.ResponseWriter, r *http.Request) (err error) {
wrapped := &statusCodeRecorder{ResponseWriter: w}
rec := &statusCodeRecorder{ResponseWriter: w}
wrapped := wrapResponseWriter(w, rec)
defer func() {
code := wrapped.getCode(err)
code := rec.getCode(err)
counter.With("status", statusCodeFamily(code)).Add(1)
}()

Expand Down Expand Up @@ -453,9 +454,11 @@ func PrometheusServerMetrics(_ string) Middleware {
}
serverActiveRequests.With(activeRequestLabels).Inc()

wrapped := &responseRecorder{ResponseWriter: w}
rec := &responseRecorder{ResponseWriter: w}
wrapped := wrapResponseWriter(w, rec)

defer func() {
code := errorCodeForMetrics(wrapped.responseCode, err)
code := errorCodeForMetrics(rec.responseCode, err)
success := isRequestSuccessful(code, err)

labels := prometheus.Labels{
Expand All @@ -465,7 +468,7 @@ func PrometheusServerMetrics(_ string) Middleware {
}
serverLatency.With(labels).Observe(time.Since(start).Seconds())
serverRequestSize.With(labels).Observe(float64(r.ContentLength))
serverResponseSize.With(labels).Observe(float64(wrapped.bytesWritten))
serverResponseSize.With(labels).Observe(float64(rec.bytesWritten))

totalRequestLabels := prometheus.Labels{
methodLabel: method,
Expand Down
156 changes: 156 additions & 0 deletions httpbp/middlewares_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
package httpbp_test

import (
"bufio"
"context"
"encoding/json"
"errors"
"net"
"net/http"
"net/http/httptest"
"sort"
"strings"
"testing"

"github.com/reddit/baseplate.go"
"github.com/reddit/baseplate.go/ecinterface"
"github.com/reddit/baseplate.go/httpbp"
"github.com/reddit/baseplate.go/log"
Expand Down Expand Up @@ -323,3 +326,156 @@ func TestSupportedMethods(t *testing.T) {
)
}
}

func TestMiddlewareResponseWrapping(t *testing.T) {
store := newSecretsStore(t)
defer store.Close()

bp := baseplate.NewTestBaseplate(baseplate.NewTestBaseplateArgs{
Config: baseplate.Config{Addr: ":8080"},
Store: store,
EdgeContextImpl: ecinterface.Mock(),
})

args := httpbp.ServerArgs{
Baseplate: bp,
Middlewares: []httpbp.Middleware{
func(name string, next httpbp.HandlerFunc) httpbp.HandlerFunc {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
if flusher, isFlusher := w.(http.Flusher); isFlusher {
flusher.Flush()
}

next(ctx, w, r)
return nil
}
},
func(name string, next httpbp.HandlerFunc) httpbp.HandlerFunc {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
if hijacker, isHijacker := w.(http.Hijacker); isHijacker {
hijacker.Hijack()
}

next(ctx, w, r)
return nil
}
},
func(name string, next httpbp.HandlerFunc) httpbp.HandlerFunc {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
if pusher, isPusher := w.(http.Pusher); isPusher {
pusher.Push("target", &http.PushOptions{})
}

next(ctx, w, r)
return nil
}
},
},
Endpoints: map[httpbp.Pattern]httpbp.Endpoint{
"/test": {
Name: "test",
Methods: []string{http.MethodGet},
Handle: func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
w.Write([]byte("endpoint"))
return nil
},
},
},
}

// register our middleware to the EndpointRegistry
args, err := args.SetupEndpoints()

if err != nil {
t.Fatal(err)
}

t.Run("non-flushable-non-hijackable", func(tt *testing.T) {
type baseResponseWriter struct {
http.ResponseWriter
}

r := httptest.NewRequest(http.MethodGet, "/test", nil)
inner := httptest.NewRecorder()
w := baseResponseWriter{inner}
args.EndpointRegistry.ServeHTTP(w, r)

if inner.Flushed {
tt.Error("expected response to not be flushed")
}
})

// Test the a flushable response
t.Run("flushable", func(tt *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
args.EndpointRegistry.ServeHTTP(w, r)

if !w.Flushed {
tt.Error("expected http response to be flushed")
}
})

t.Run("hijackable", func(tt *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/test", nil)
w := &hijackableResponseRecorder{httptest.NewRecorder(), false}
args.EndpointRegistry.ServeHTTP(w, r)

if !w.Hijacked {
tt.Error("expected http response to be hijacked")
}
})

t.Run("pushable", func(tt *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/test", nil)
w := &pushableResponseRecorder{httptest.NewRecorder(), false}
args.EndpointRegistry.ServeHTTP(w, r)

if !w.Pushed {
tt.Error("expected http response to be pushed")
}
})

t.Run("hijackable-flushable", func(tt *testing.T) {
type hijackableFlushableRecorder struct {
hijackableResponseRecorder
http.Flusher
}

r := httptest.NewRequest(http.MethodGet, "/test", nil)
inner := httptest.NewRecorder()
w := &hijackableFlushableRecorder{
hijackableResponseRecorder{inner, false},
inner,
}
args.EndpointRegistry.ServeHTTP(w, r)

if !w.Hijacked {
tt.Error("expected http response to be hijacked")
}

if !inner.Flushed {
tt.Error("expected http response to be flushed")
}
})
adamthesax marked this conversation as resolved.
Show resolved Hide resolved
}

type hijackableResponseRecorder struct {
http.ResponseWriter
Hijacked bool
}

func (h *hijackableResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
h.Hijacked = true
return nil, nil, nil
}

type pushableResponseRecorder struct {
http.ResponseWriter
Pushed bool
}

func (p *pushableResponseRecorder) Push(target string, opts *http.PushOptions) error {
p.Pushed = true
return nil
}
33 changes: 33 additions & 0 deletions httpbp/response_wrappers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// DO NOT EDIT.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DO NOT EDIT is incompatible with "this was generated and then edited" :D

Copy link
Member

@fishy fishy Oct 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1.

I would say it's better to just hand write this code:

type optionalResponseWriter int

const (
  flusher optionalResponseWriter = 1 << iota  
  hijacker
  pusher
)

func wrapResponseWriter(...) ... {
  var f optionalResponseWriter
  ...

  switch f {
  case 0:
    ...
  case flusher:
    ...
  case flusher | hijacker:
    ...
  ...
}

// This code was partially generated and then made human readable.
// This approach was adapted from https://github.com/badgerodon/contextaware/blob/4c442dfd39106512496bdd13c42c451da8ddeff3/internal/generate-wrap/main.go
// See also https://www.doxsey.net/blog/fixing-interface-erasure-in-go/ for more context

package httpbp

import (
"net/http"
)

func wrapResponseWriter(orig, wrapped http.ResponseWriter) http.ResponseWriter {
var f uint64
flusher, isFlusher := orig.(http.Flusher)
if isFlusher { f |= 0x0001 }
hijacker, isHijacker := orig.(http.Hijacker)
if isHijacker { f |= 0x0002 }
pusher, isPusher := orig.(http.Pusher)
if isPusher { f |= 0x0004 }

switch f {
case 0x0000: return wrapped
case 0x0001: return struct{http.ResponseWriter;http.Flusher}{wrapped, flusher}
case 0x0002: return struct{http.ResponseWriter;http.Hijacker}{wrapped, hijacker}
case 0x0003: return struct{http.ResponseWriter;http.Flusher;http.Hijacker}{wrapped, flusher,hijacker}
case 0x0004: return struct{http.ResponseWriter;http.Pusher}{wrapped, pusher}
case 0x0005: return struct{http.ResponseWriter;http.Flusher;http.Pusher}{wrapped,flusher,pusher}
case 0x0006: return struct{http.ResponseWriter;http.Hijacker;http.Pusher}{wrapped, hijacker,pusher}
case 0x0007: return struct{http.ResponseWriter;http.Flusher;http.Hijacker;http.Pusher}{wrapped, flusher,hijacker,pusher}
}

return wrapped
}