From 8069fe56546c2cdcfbeb5290f5c6130590927308 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Thu, 29 Mar 2018 20:07:35 -0600 Subject: [PATCH] Proxy (http.Hijacker).Hijack() calls to Handler.responseWriter in ochttp (#648) Fixes #642 Ensure that plugin/ochttp.Handler proxies (http.Hijacker).Hijack() calls to its responseWriter that we captured. Also add tests to ensure that the behavior with the Go standard library is preserved that is: a) HTTP/1.X connections successfully can use (http.Hijacker).Hijack() b) HTTP/2.X connections are incompatible with (http.Hijacker).Hijack() because with HTTP/2.X, multiple requests can be sent on the same connection thus it doesn't make sense to hijack a connection. The standard library also panics and enforces this behavior. Note that both a) and b) are only enforced at runtime, like we do. We don't need to do any work, except solidify and lock this behavior in our tests. As an advantage, we can now successfully use ochttp.Handler with websockets! If the underlying ResponseWriter doesn't implement http.Hijacker, return an error signifying this. This is consensus after offline discussions with @ramonza and @rakyll, given that other packages might pass in their custom response writers not just net/http. Also make the initial Hijacker tests integration tests and spell out the actual basic testing in an added and small unit test. --- plugin/ochttp/server.go | 15 +++- plugin/ochttp/server_test.go | 157 +++++++++++++++++++++++++++++++++++ 2 files changed, 171 insertions(+), 1 deletion(-) diff --git a/plugin/ochttp/server.go b/plugin/ochttp/server.go index b402a65f3..a92c7c1ed 100644 --- a/plugin/ochttp/server.go +++ b/plugin/ochttp/server.go @@ -15,7 +15,10 @@ package ochttp import ( + "bufio" "context" + "errors" + "net" "net/http" "strconv" "sync" @@ -65,7 +68,6 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer traceEnd() w, statsEnd = h.startStats(w, r) defer statsEnd() - handler := h.Handler if handler == nil { handler = http.DefaultServeMux @@ -140,6 +142,17 @@ type trackingResponseWriter struct { } var _ http.ResponseWriter = (*trackingResponseWriter)(nil) +var _ http.Hijacker = (*trackingResponseWriter)(nil) + +var errHijackerUnimplemented = errors.New("ResponseWriter does not implement http.Hijacker") + +func (t *trackingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hj, ok := t.writer.(http.Hijacker) + if !ok { + return nil, nil, errHijackerUnimplemented + } + return hj.Hijack() +} func (t *trackingResponseWriter) end() { t.endOnce.Do(func() { diff --git a/plugin/ochttp/server_test.go b/plugin/ochttp/server_test.go index bdfe8b0b0..7c1fee660 100644 --- a/plugin/ochttp/server_test.go +++ b/plugin/ochttp/server_test.go @@ -1,11 +1,19 @@ package ochttp import ( + "bufio" "bytes" + "crypto/tls" + "fmt" + "io/ioutil" + "net" "net/http" "net/http/httptest" + "strings" "testing" + "golang.org/x/net/http2" + "go.opencensus.io/stats/view" "go.opencensus.io/trace" ) @@ -116,3 +124,152 @@ func TestHandlerStatsCollection(t *testing.T) { } } } + +type testResponseWriterHijacker struct { + httptest.ResponseRecorder +} + +func (trw *testResponseWriterHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, nil +} + +func TestUnitTestHandlerProxiesHijack(t *testing.T) { + tests := []struct { + w http.ResponseWriter + wantErr string + }{ + {httptest.NewRecorder(), "ResponseWriter does not implement http.Hijacker"}, + {nil, "ResponseWriter does not implement http.Hijacker"}, + {new(testResponseWriterHijacker), ""}, + } + + for i, tt := range tests { + tw := &trackingResponseWriter{writer: tt.w} + conn, buf, err := tw.Hijack() + if tt.wantErr != "" { + if err == nil || !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("#%d got error (%v) want error substring (%q)", i, err, tt.wantErr) + } + if conn != nil { + t.Errorf("#%d inconsistent state got non-nil conn (%v)", i, conn) + } + if buf != nil { + t.Errorf("#%d inconsistent state got non-nil buf (%v)", i, buf) + } + continue + } + + if err != nil { + t.Errorf("#%d got unexpected error %v", i, err) + } + } +} + +// Integration test with net/http to ensure that our Handler proxies to its +// response the call to (http.Hijack).Hijacker() and that that successfully +// passes with HTTP/1.1 connections. See Issue #642 +func TestHandlerProxiesHijack_HTTP1(t *testing.T) { + cst := httptest.NewServer(&Handler{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var writeMsg func(string) + defer func() { + err := recover() + writeMsg(fmt.Sprintf("Proto=%s\npanic=%v", r.Proto, err != nil)) + }() + conn, _, _ := w.(http.Hijacker).Hijack() + writeMsg = func(msg string) { + fmt.Fprintf(conn, "%s 200\nContentLength: %d", r.Proto, len(msg)) + fmt.Fprintf(conn, "\r\n\r\n%s", msg) + conn.Close() + } + }), + }) + defer cst.Close() + + testCases := []struct { + name string + tr *http.Transport + want string + }{ + { + name: "http1-transport", + tr: new(http.Transport), + want: "Proto=HTTP/1.1\npanic=false", + }, + { + name: "http2-transport", + tr: func() *http.Transport { + tr := new(http.Transport) + http2.ConfigureTransport(tr) + return tr + }(), + want: "Proto=HTTP/1.1\npanic=false", + }, + } + + for _, tc := range testCases { + c := &http.Client{Transport: &Transport{Base: tc.tr}} + res, err := c.Get(cst.URL) + if err != nil { + t.Errorf("(%s) unexpected error %v", tc.name, err) + continue + } + blob, _ := ioutil.ReadAll(res.Body) + res.Body.Close() + if g, w := string(blob), tc.want; g != w { + t.Errorf("(%s) got = %q; want = %q", tc.name, g, w) + } + } +} + +// Integration test with net/http, x/net/http2 to ensure that our Handler proxies +// to its response the call to (http.Hijack).Hijacker() and that that crashes +// since http.Hijacker and HTTP/2.0 connections are incompatible, but the +// detection is only at runtime and ensure that we can stream and flush to the +// connection even after invoking Hijack(). See Issue #642. +func TestHandlerProxiesHijack_HTTP2(t *testing.T) { + cst := httptest.NewUnstartedServer(&Handler{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, _, err := w.(http.Hijacker).Hijack() + if conn != nil { + data := fmt.Sprintf("Surprisingly got the Hijacker() Proto: %s", r.Proto) + fmt.Fprintf(conn, "%s 200\nContent-Length:%d\r\n\r\n%s", r.Proto, len(data), data) + conn.Close() + return + } + + switch { + case err == nil: + fmt.Fprintf(w, "Unexpectedly did not encounter an error!") + default: + fmt.Fprintf(w, "Unexpected error: %v", err) + case strings.Contains(err.(error).Error(), "Hijack"): + // Confirmed HTTP/2.0, let's stream to it + for i := 0; i < 5; i++ { + fmt.Fprintf(w, "%d\n", i) + w.(http.Flusher).Flush() + } + } + }), + }) + cst.TLS = &tls.Config{NextProtos: []string{"h2"}} + cst.StartTLS() + defer cst.Close() + + if wantPrefix := "https://"; !strings.HasPrefix(cst.URL, wantPrefix) { + t.Fatalf("URL got = %q wantPrefix = %q", cst.URL, wantPrefix) + } + + tr := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + http2.ConfigureTransport(tr) + c := &http.Client{Transport: tr} + res, err := c.Get(cst.URL) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + blob, _ := ioutil.ReadAll(res.Body) + res.Body.Close() + if g, w := string(blob), "0\n1\n2\n3\n4\n"; g != w { + t.Errorf("got = %q; want = %q", g, w) + } +}