Skip to content
This repository has been archived by the owner on Jul 31, 2023. It is now read-only.

Commit

Permalink
Proxy (http.Hijacker).Hijack() calls to Handler.responseWriter in och…
Browse files Browse the repository at this point in the history
…ttp (#648)

Fixes #642

Ensure that plugin/ochttp.Handler proxies (http.Hijacker).Hijack()
calls to its responseWriter that we captured. Also add tests to
ensure that the behavior with the Go standard library is preserved
that is:
a) HTTP/1.X connections successfully can use (http.Hijacker).Hijack()
b) HTTP/2.X connections are incompatible with (http.Hijacker).Hijack()
because with HTTP/2.X, multiple requests can be sent on the same
connection thus it doesn't make sense to hijack a connection. The
standard library also panics and enforces this behavior.

Note that both a) and b) are only enforced at runtime, like we do.
We don't need to do any work, except solidify and lock this behavior
in our tests.

As an advantage, we can now successfully use ochttp.Handler with
websockets!

If the underlying ResponseWriter doesn't implement http.Hijacker,
return an error signifying this. This is consensus after offline
discussions with @Ramonza and @rakyll, given that other packages
might pass in their custom response writers not just net/http.

Also make the initial Hijacker tests integration tests and spell
out the actual basic testing in an added and small unit test.
  • Loading branch information
Emmanuel T Odeke authored and rakyll committed Mar 30, 2018
1 parent 26cab1e commit 8069fe5
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 1 deletion.
15 changes: 14 additions & 1 deletion plugin/ochttp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
package ochttp

import (
"bufio"
"context"
"errors"
"net"
"net/http"
"strconv"
"sync"
Expand Down Expand Up @@ -65,7 +68,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer traceEnd()
w, statsEnd = h.startStats(w, r)
defer statsEnd()

handler := h.Handler
if handler == nil {
handler = http.DefaultServeMux
Expand Down Expand Up @@ -140,6 +142,17 @@ type trackingResponseWriter struct {
}

var _ http.ResponseWriter = (*trackingResponseWriter)(nil)
var _ http.Hijacker = (*trackingResponseWriter)(nil)

var errHijackerUnimplemented = errors.New("ResponseWriter does not implement http.Hijacker")

func (t *trackingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hj, ok := t.writer.(http.Hijacker)
if !ok {
return nil, nil, errHijackerUnimplemented
}
return hj.Hijack()
}

func (t *trackingResponseWriter) end() {
t.endOnce.Do(func() {
Expand Down
157 changes: 157 additions & 0 deletions plugin/ochttp/server_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
package ochttp

import (
"bufio"
"bytes"
"crypto/tls"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"

"golang.org/x/net/http2"

"go.opencensus.io/stats/view"
"go.opencensus.io/trace"
)
Expand Down Expand Up @@ -116,3 +124,152 @@ func TestHandlerStatsCollection(t *testing.T) {
}
}
}

type testResponseWriterHijacker struct {
httptest.ResponseRecorder
}

func (trw *testResponseWriterHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, nil
}

func TestUnitTestHandlerProxiesHijack(t *testing.T) {
tests := []struct {
w http.ResponseWriter
wantErr string
}{
{httptest.NewRecorder(), "ResponseWriter does not implement http.Hijacker"},
{nil, "ResponseWriter does not implement http.Hijacker"},
{new(testResponseWriterHijacker), ""},
}

for i, tt := range tests {
tw := &trackingResponseWriter{writer: tt.w}
conn, buf, err := tw.Hijack()
if tt.wantErr != "" {
if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
t.Errorf("#%d got error (%v) want error substring (%q)", i, err, tt.wantErr)
}
if conn != nil {
t.Errorf("#%d inconsistent state got non-nil conn (%v)", i, conn)
}
if buf != nil {
t.Errorf("#%d inconsistent state got non-nil buf (%v)", i, buf)
}
continue
}

if err != nil {
t.Errorf("#%d got unexpected error %v", i, err)
}
}
}

// Integration test with net/http to ensure that our Handler proxies to its
// response the call to (http.Hijack).Hijacker() and that that successfully
// passes with HTTP/1.1 connections. See Issue #642
func TestHandlerProxiesHijack_HTTP1(t *testing.T) {
cst := httptest.NewServer(&Handler{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var writeMsg func(string)
defer func() {
err := recover()
writeMsg(fmt.Sprintf("Proto=%s\npanic=%v", r.Proto, err != nil))
}()
conn, _, _ := w.(http.Hijacker).Hijack()
writeMsg = func(msg string) {
fmt.Fprintf(conn, "%s 200\nContentLength: %d", r.Proto, len(msg))
fmt.Fprintf(conn, "\r\n\r\n%s", msg)
conn.Close()
}
}),
})
defer cst.Close()

testCases := []struct {
name string
tr *http.Transport
want string
}{
{
name: "http1-transport",
tr: new(http.Transport),
want: "Proto=HTTP/1.1\npanic=false",
},
{
name: "http2-transport",
tr: func() *http.Transport {
tr := new(http.Transport)
http2.ConfigureTransport(tr)
return tr
}(),
want: "Proto=HTTP/1.1\npanic=false",
},
}

for _, tc := range testCases {
c := &http.Client{Transport: &Transport{Base: tc.tr}}
res, err := c.Get(cst.URL)
if err != nil {
t.Errorf("(%s) unexpected error %v", tc.name, err)
continue
}
blob, _ := ioutil.ReadAll(res.Body)
res.Body.Close()
if g, w := string(blob), tc.want; g != w {
t.Errorf("(%s) got = %q; want = %q", tc.name, g, w)
}
}
}

// Integration test with net/http, x/net/http2 to ensure that our Handler proxies
// to its response the call to (http.Hijack).Hijacker() and that that crashes
// since http.Hijacker and HTTP/2.0 connections are incompatible, but the
// detection is only at runtime and ensure that we can stream and flush to the
// connection even after invoking Hijack(). See Issue #642.
func TestHandlerProxiesHijack_HTTP2(t *testing.T) {
cst := httptest.NewUnstartedServer(&Handler{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, _, err := w.(http.Hijacker).Hijack()
if conn != nil {
data := fmt.Sprintf("Surprisingly got the Hijacker() Proto: %s", r.Proto)
fmt.Fprintf(conn, "%s 200\nContent-Length:%d\r\n\r\n%s", r.Proto, len(data), data)
conn.Close()
return
}

switch {
case err == nil:
fmt.Fprintf(w, "Unexpectedly did not encounter an error!")
default:
fmt.Fprintf(w, "Unexpected error: %v", err)
case strings.Contains(err.(error).Error(), "Hijack"):
// Confirmed HTTP/2.0, let's stream to it
for i := 0; i < 5; i++ {
fmt.Fprintf(w, "%d\n", i)
w.(http.Flusher).Flush()
}
}
}),
})
cst.TLS = &tls.Config{NextProtos: []string{"h2"}}
cst.StartTLS()
defer cst.Close()

if wantPrefix := "https://"; !strings.HasPrefix(cst.URL, wantPrefix) {
t.Fatalf("URL got = %q wantPrefix = %q", cst.URL, wantPrefix)
}

tr := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
http2.ConfigureTransport(tr)
c := &http.Client{Transport: tr}
res, err := c.Get(cst.URL)
if err != nil {
t.Fatalf("Unexpected error %v", err)
}
blob, _ := ioutil.ReadAll(res.Body)
res.Body.Close()
if g, w := string(blob), "0\n1\n2\n3\n4\n"; g != w {
t.Errorf("got = %q; want = %q", g, w)
}
}

0 comments on commit 8069fe5

Please sign in to comment.