Skip to content

Commit

Permalink
[fix] httpbp middleware doesn't flush chunked responses (#573)
Browse files Browse the repository at this point in the history
* wrapped middleware supports Flush() and Hijack()

* Update httpbp/middlewares.go

Co-authored-by: Kyle Lemons <[email protected]>

* added unit test

* chain wrapper functions together

* test tweaks

* additional test cases

* update to use jump table and support http.Pusher

* remove generated code and write all wrappers by hand

* PR feedback

Co-authored-by: Adam Sax <[email protected]>
Co-authored-by: Kyle Lemons <[email protected]>
  • Loading branch information
3 people authored Nov 1, 2022
1 parent 87a56c4 commit fae00cb
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 5 deletions.
13 changes: 8 additions & 5 deletions httpbp/middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,10 @@ func recordStatusCode(counters counterGenerator) Middleware {
return func(name string, next HandlerFunc) HandlerFunc {
counter := counters.Counter("baseplate.http." + name + ".response")
return func(ctx context.Context, w http.ResponseWriter, r *http.Request) (err error) {
wrapped := &statusCodeRecorder{ResponseWriter: w}
rec := &statusCodeRecorder{ResponseWriter: w}
wrapped := wrapResponseWriter(w, rec)
defer func() {
code := wrapped.getCode(err)
code := rec.getCode(err)
counter.With("status", statusCodeFamily(code)).Add(1)
}()

Expand Down Expand Up @@ -453,9 +454,11 @@ func PrometheusServerMetrics(_ string) Middleware {
}
serverActiveRequests.With(activeRequestLabels).Inc()

wrapped := &responseRecorder{ResponseWriter: w}
rec := &responseRecorder{ResponseWriter: w}
wrapped := wrapResponseWriter(w, rec)

defer func() {
code := errorCodeForMetrics(wrapped.responseCode, err)
code := errorCodeForMetrics(rec.responseCode, err)
success := isRequestSuccessful(code, err)

labels := prometheus.Labels{
Expand All @@ -465,7 +468,7 @@ func PrometheusServerMetrics(_ string) Middleware {
}
serverLatency.With(labels).Observe(time.Since(start).Seconds())
serverRequestSize.With(labels).Observe(float64(r.ContentLength))
serverResponseSize.With(labels).Observe(float64(wrapped.bytesWritten))
serverResponseSize.With(labels).Observe(float64(rec.bytesWritten))

totalRequestLabels := prometheus.Labels{
methodLabel: method,
Expand Down
156 changes: 156 additions & 0 deletions httpbp/middlewares_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
package httpbp_test

import (
"bufio"
"context"
"encoding/json"
"errors"
"net"
"net/http"
"net/http/httptest"
"sort"
"strings"
"testing"

"github.com/reddit/baseplate.go"
"github.com/reddit/baseplate.go/ecinterface"
"github.com/reddit/baseplate.go/httpbp"
"github.com/reddit/baseplate.go/log"
Expand Down Expand Up @@ -323,3 +326,156 @@ func TestSupportedMethods(t *testing.T) {
)
}
}

func TestMiddlewareResponseWrapping(t *testing.T) {
store := newSecretsStore(t)
defer store.Close()

bp := baseplate.NewTestBaseplate(baseplate.NewTestBaseplateArgs{
Config: baseplate.Config{Addr: ":8080"},
Store: store,
EdgeContextImpl: ecinterface.Mock(),
})

args := httpbp.ServerArgs{
Baseplate: bp,
Middlewares: []httpbp.Middleware{
func(name string, next httpbp.HandlerFunc) httpbp.HandlerFunc {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
if flusher, isFlusher := w.(http.Flusher); isFlusher {
flusher.Flush()
}

next(ctx, w, r)
return nil
}
},
func(name string, next httpbp.HandlerFunc) httpbp.HandlerFunc {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
if hijacker, isHijacker := w.(http.Hijacker); isHijacker {
hijacker.Hijack()
}

next(ctx, w, r)
return nil
}
},
func(name string, next httpbp.HandlerFunc) httpbp.HandlerFunc {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
if pusher, isPusher := w.(http.Pusher); isPusher {
pusher.Push("target", &http.PushOptions{})
}

next(ctx, w, r)
return nil
}
},
},
Endpoints: map[httpbp.Pattern]httpbp.Endpoint{
"/test": {
Name: "test",
Methods: []string{http.MethodGet},
Handle: func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
w.Write([]byte("endpoint"))
return nil
},
},
},
}

// register our middleware to the EndpointRegistry
args, err := args.SetupEndpoints()

if err != nil {
t.Fatal(err)
}

t.Run("non-flushable-non-hijackable", func(tt *testing.T) {
type baseResponseWriter struct {
http.ResponseWriter
}

r := httptest.NewRequest(http.MethodGet, "/test", nil)
inner := httptest.NewRecorder()
w := baseResponseWriter{inner}
args.EndpointRegistry.ServeHTTP(w, r)

if inner.Flushed {
tt.Error("expected response to not be flushed")
}
})

// Test the a flushable response
t.Run("flushable", func(tt *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
args.EndpointRegistry.ServeHTTP(w, r)

if !w.Flushed {
tt.Error("expected http response to be flushed")
}
})

t.Run("hijackable", func(tt *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/test", nil)
w := &hijackableResponseRecorder{httptest.NewRecorder(), false}
args.EndpointRegistry.ServeHTTP(w, r)

if !w.Hijacked {
tt.Error("expected http response to be hijacked")
}
})

t.Run("pushable", func(tt *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/test", nil)
w := &pushableResponseRecorder{httptest.NewRecorder(), false}
args.EndpointRegistry.ServeHTTP(w, r)

if !w.Pushed {
tt.Error("expected http response to be pushed")
}
})

t.Run("hijackable-flushable", func(tt *testing.T) {
type hijackableFlushableRecorder struct {
hijackableResponseRecorder
http.Flusher
}

r := httptest.NewRequest(http.MethodGet, "/test", nil)
inner := httptest.NewRecorder()
w := &hijackableFlushableRecorder{
hijackableResponseRecorder{inner, false},
inner,
}
args.EndpointRegistry.ServeHTTP(w, r)

if !w.Hijacked {
tt.Error("expected http response to be hijacked")
}

if !inner.Flushed {
tt.Error("expected http response to be flushed")
}
})
}

type hijackableResponseRecorder struct {
http.ResponseWriter
Hijacked bool
}

func (h *hijackableResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
h.Hijacked = true
return nil, nil, nil
}

type pushableResponseRecorder struct {
http.ResponseWriter
Pushed bool
}

func (p *pushableResponseRecorder) Push(target string, opts *http.PushOptions) error {
p.Pushed = true
return nil
}
76 changes: 76 additions & 0 deletions httpbp/response_wrappers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package httpbp

import (
"net/http"
)

type optionalResponseWriter uint64

const (
flusher optionalResponseWriter = 1 << iota
hijacker
pusher
)

func wrapResponseWriter(orig, wrapped http.ResponseWriter) http.ResponseWriter {
var w optionalResponseWriter
f, isFlusher := orig.(http.Flusher)
if isFlusher {
w |= flusher
}
h, isHijacker := orig.(http.Hijacker)
if isHijacker {
w |= hijacker
}
p, isPusher := orig.(http.Pusher)
if isPusher {
w |= pusher
}

switch w {
case 0:
return wrapped
case flusher:
return struct {
http.ResponseWriter
http.Flusher
}{wrapped, f}
case hijacker:
return struct {
http.ResponseWriter
http.Hijacker
}{wrapped, h}
case flusher | hijacker:
return struct {
http.ResponseWriter
http.Flusher
http.Hijacker
}{wrapped, f, h}
case pusher:
return struct {
http.ResponseWriter
http.Pusher
}{wrapped, p}
case flusher | pusher:
return struct {
http.ResponseWriter
http.Flusher
http.Pusher
}{wrapped, f, p}
case hijacker | pusher:
return struct {
http.ResponseWriter
http.Hijacker
http.Pusher
}{wrapped, h, p}
case flusher | hijacker | pusher:
return struct {
http.ResponseWriter
http.Flusher
http.Hijacker
http.Pusher
}{wrapped, f, h, p}
}

return wrapped
}

0 comments on commit fae00cb

Please sign in to comment.