Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix] httpbp middleware doesn't flush chunked responses #573

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions httpbp/middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,10 @@ func recordStatusCode(counters counterGenerator) Middleware {
return func(name string, next HandlerFunc) HandlerFunc {
counter := counters.Counter("baseplate.http." + name + ".response")
return func(ctx context.Context, w http.ResponseWriter, r *http.Request) (err error) {
wrapped := &statusCodeRecorder{ResponseWriter: w}
rec := &statusCodeRecorder{ResponseWriter: w}
wrapped := wrapResponseWriter(w, rec)
defer func() {
code := wrapped.getCode(err)
code := rec.getCode(err)
counter.With("status", statusCodeFamily(code)).Add(1)
}()

Expand Down Expand Up @@ -453,9 +454,11 @@ func PrometheusServerMetrics(_ string) Middleware {
}
serverActiveRequests.With(activeRequestLabels).Inc()

wrapped := &responseRecorder{ResponseWriter: w}
rec := &responseRecorder{ResponseWriter: w}
wrapped := wrapResponseWriter(w, rec)

defer func() {
code := errorCodeForMetrics(wrapped.responseCode, err)
code := errorCodeForMetrics(rec.responseCode, err)
success := isRequestSuccessful(code, err)

labels := prometheus.Labels{
Expand All @@ -465,7 +468,7 @@ func PrometheusServerMetrics(_ string) Middleware {
}
serverLatency.With(labels).Observe(time.Since(start).Seconds())
serverRequestSize.With(labels).Observe(float64(r.ContentLength))
serverResponseSize.With(labels).Observe(float64(wrapped.bytesWritten))
serverResponseSize.With(labels).Observe(float64(rec.bytesWritten))

totalRequestLabels := prometheus.Labels{
methodLabel: method,
Expand Down
156 changes: 156 additions & 0 deletions httpbp/middlewares_test.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
package httpbp_test

import (
"bufio"
"context"
"encoding/json"
"errors"
"net"
"net/http"
"net/http/httptest"
"sort"
"strings"
"testing"

"github.com/reddit/baseplate.go"
"github.com/reddit/baseplate.go/ecinterface"
"github.com/reddit/baseplate.go/httpbp"
"github.com/reddit/baseplate.go/log"
Expand Down Expand Up @@ -323,3 +326,156 @@ func TestSupportedMethods(t *testing.T) {
)
}
}

func TestMiddlewareResponseWrapping(t *testing.T) {
store := newSecretsStore(t)
defer store.Close()

bp := baseplate.NewTestBaseplate(baseplate.NewTestBaseplateArgs{
Config: baseplate.Config{Addr: ":8080"},
Store: store,
EdgeContextImpl: ecinterface.Mock(),
})

args := httpbp.ServerArgs{
Baseplate: bp,
Middlewares: []httpbp.Middleware{
func(name string, next httpbp.HandlerFunc) httpbp.HandlerFunc {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
if flusher, isFlusher := w.(http.Flusher); isFlusher {
flusher.Flush()
}

next(ctx, w, r)
return nil
}
},
func(name string, next httpbp.HandlerFunc) httpbp.HandlerFunc {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
if hijacker, isHijacker := w.(http.Hijacker); isHijacker {
hijacker.Hijack()
}

next(ctx, w, r)
return nil
}
},
func(name string, next httpbp.HandlerFunc) httpbp.HandlerFunc {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
if pusher, isPusher := w.(http.Pusher); isPusher {
pusher.Push("target", &http.PushOptions{})
}

next(ctx, w, r)
return nil
}
},
},
Endpoints: map[httpbp.Pattern]httpbp.Endpoint{
"/test": {
Name: "test",
Methods: []string{http.MethodGet},
Handle: func(ctx context.Context, w http.ResponseWriter, r *http.Request) error {
w.Write([]byte("endpoint"))
return nil
},
},
},
}

// register our middleware to the EndpointRegistry
args, err := args.SetupEndpoints()

if err != nil {
t.Fatal(err)
}

t.Run("non-flushable-non-hijackable", func(tt *testing.T) {
type baseResponseWriter struct {
http.ResponseWriter
}

r := httptest.NewRequest(http.MethodGet, "/test", nil)
inner := httptest.NewRecorder()
w := baseResponseWriter{inner}
args.EndpointRegistry.ServeHTTP(w, r)

if inner.Flushed {
tt.Error("expected response to not be flushed")
}
})

// Test the a flushable response
t.Run("flushable", func(tt *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/test", nil)
w := httptest.NewRecorder()
args.EndpointRegistry.ServeHTTP(w, r)

if !w.Flushed {
tt.Error("expected http response to be flushed")
}
})

t.Run("hijackable", func(tt *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/test", nil)
w := &hijackableResponseRecorder{httptest.NewRecorder(), false}
args.EndpointRegistry.ServeHTTP(w, r)

if !w.Hijacked {
tt.Error("expected http response to be hijacked")
}
})

t.Run("pushable", func(tt *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/test", nil)
w := &pushableResponseRecorder{httptest.NewRecorder(), false}
args.EndpointRegistry.ServeHTTP(w, r)

if !w.Pushed {
tt.Error("expected http response to be pushed")
}
})

t.Run("hijackable-flushable", func(tt *testing.T) {
type hijackableFlushableRecorder struct {
hijackableResponseRecorder
http.Flusher
}

r := httptest.NewRequest(http.MethodGet, "/test", nil)
inner := httptest.NewRecorder()
w := &hijackableFlushableRecorder{
hijackableResponseRecorder{inner, false},
inner,
}
args.EndpointRegistry.ServeHTTP(w, r)

if !w.Hijacked {
tt.Error("expected http response to be hijacked")
}

if !inner.Flushed {
tt.Error("expected http response to be flushed")
}
})
adamthesax marked this conversation as resolved.
Show resolved Hide resolved
}

type hijackableResponseRecorder struct {
http.ResponseWriter
Hijacked bool
}

func (h *hijackableResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
h.Hijacked = true
return nil, nil, nil
}

type pushableResponseRecorder struct {
http.ResponseWriter
Pushed bool
}

func (p *pushableResponseRecorder) Push(target string, opts *http.PushOptions) error {
p.Pushed = true
return nil
}
76 changes: 76 additions & 0 deletions httpbp/response_wrappers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package httpbp

import (
"net/http"
)

type optionalResponseWriter uint64

const (
flusher optionalResponseWriter = 1 << iota
hijacker
pusher
)

func wrapResponseWriter(orig, wrapped http.ResponseWriter) http.ResponseWriter {
var w optionalResponseWriter
f, isFlusher := orig.(http.Flusher)
if isFlusher {
w |= flusher
}
h, isHijacker := orig.(http.Hijacker)
if isHijacker {
w |= hijacker
}
p, isPusher := orig.(http.Pusher)
if isPusher {
w |= pusher
}

switch w {
case 0:
return wrapped
case flusher:
return struct {
http.ResponseWriter
http.Flusher
}{wrapped, f}
case hijacker:
return struct {
http.ResponseWriter
http.Hijacker
}{wrapped, h}
case flusher | hijacker:
return struct {
http.ResponseWriter
http.Flusher
http.Hijacker
}{wrapped, f, h}
case pusher:
return struct {
http.ResponseWriter
http.Pusher
}{wrapped, p}
case flusher | pusher:
return struct {
http.ResponseWriter
http.Flusher
http.Pusher
}{wrapped, f, p}
case hijacker | pusher:
return struct {
http.ResponseWriter
http.Hijacker
http.Pusher
}{wrapped, h, p}
case flusher | hijacker | pusher:
return struct {
http.ResponseWriter
http.Flusher
http.Hijacker
http.Pusher
}{wrapped, f, h, p}
}

return wrapped
}