Skip to content
This repository has been archived by the owner on Feb 15, 2019. It is now read-only.

Commit

Permalink
Transfer TLSConfig from roundtripper in websocket dialer
Browse files Browse the repository at this point in the history
  • Loading branch information
juliens authored and traefiker committed Sep 7, 2017
1 parent 484d48a commit 6c94d28
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 34 deletions.
1 change: 1 addition & 0 deletions forward/fwd.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ func New(setters ...optSetter) (*Forwarder, error) {
if f.httpForwarder.roundTripper == nil {
f.httpForwarder.roundTripper = http.DefaultTransport
}
f.websocketForwarder.TLSClientConfig = f.httpForwarder.roundTripper.(*http.Transport).TLSClientConfig
if f.httpForwarder.rewriter == nil {
h, err := os.Hostname()
if err != nil {
Expand Down
132 changes: 98 additions & 34 deletions forward/fwd_websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package forward

import (
"bufio"
"crypto/tls"
"fmt"
"net"
"net/http"
Expand Down Expand Up @@ -75,13 +76,7 @@ func (s *FwdSuite) TestWebsocketServerWithoutCheckOrigin(c *C) {
}))
defer srv.Close()

proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path // keep the original path
// Set new backend URL
req.URL = testutils.ParseURI(srv.URL)
req.URL.Path = path
f.ServeHTTP(w, req)
})
proxy := createProxyWithForwarder(f, srv.URL)
defer proxy.Close()

proxyAddr := proxy.Listener.Addr().String()
Expand Down Expand Up @@ -119,13 +114,7 @@ func (s *FwdSuite) TestWebsocketRequestWithOrigin(c *C) {
}))
defer srv.Close()

proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path // keep the original path
// Set new backend URL
req.URL = testutils.ParseURI(srv.URL)
req.URL.Path = path
f.ServeHTTP(w, req)
})
proxy := createProxyWithForwarder(f, srv.URL)
defer proxy.Close()

proxyAddr := proxy.Listener.Addr().String()
Expand Down Expand Up @@ -174,13 +163,7 @@ func (s *FwdSuite) TestWebsocketRequestWithQueryParams(c *C) {
}))
defer srv.Close()

proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path // keep the original path
// Set new backend URL
req.URL = testutils.ParseURI(srv.URL)
req.URL.Path = path
f.ServeHTTP(w, req)
})
proxy := createProxyWithForwarder(f, srv.URL)
defer proxy.Close()

proxyAddr := proxy.Listener.Addr().String()
Expand Down Expand Up @@ -220,13 +203,7 @@ func (s *FwdSuite) TestWebsocketRequestWithEncodedChar(c *C) {
}))
defer srv.Close()

proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path // keep the original path
// Set new backend URL
req.URL = testutils.ParseURI(srv.URL)
req.URL.Path = path
f.ServeHTTP(w, req)
})
proxy := createProxyWithForwarder(f, srv.URL)
defer proxy.Close()

proxyAddr := proxy.Listener.Addr().String()
Expand Down Expand Up @@ -339,21 +316,108 @@ func (s *FwdSuite) TestForwardsWebsocketTraffic(c *C) {
})
defer srv.Close()

proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) {
proxy := createProxyWithForwarder(f, srv.URL)
defer proxy.Close()

proxyAddr := proxy.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("echo"),
).send()

c.Assert(err, IsNil)
c.Assert(resp, Equals, "ok")
}

func createTLSWebsocketServer() *httptest.Server {
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
return srv
}

func createProxyWithForwarder(forwarder *Forwarder, URL string) *httptest.Server {
return testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path // keep the original path
// Set new backend URL
req.URL = testutils.ParseURI(srv.URL)
req.URL = testutils.ParseURI(URL)
req.URL.Path = path

f.ServeHTTP(w, req)
forwarder.ServeHTTP(w, req)
})
defer proxy.Close()
}

func (s *FwdSuite) TestWebsocketTransferTLSConfig(c *C) {
srv := createTLSWebsocketServer()
defer srv.Close()

forwarderWithoutTLSConfig, err := New()
c.Assert(err, IsNil)

proxyWithoutTLSConfig := createProxyWithForwarder(forwarderWithoutTLSConfig, srv.URL)
defer proxyWithoutTLSConfig.Close()

proxyAddr := proxyWithoutTLSConfig.Listener.Addr().String()

_, err = newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
).send()

c.Assert(err, NotNil)
c.Assert(err, ErrorMatches, "bad status")

transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
forwarderWithTLSConfig, err := New(RoundTripper(transport))
c.Assert(err, IsNil)

proxyWithTLSConfig := createProxyWithForwarder(forwarderWithTLSConfig, srv.URL)
defer proxyWithTLSConfig.Close()

proxyAddr = proxyWithTLSConfig.Listener.Addr().String()

proxyAddr := proxy.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("echo"),
withData("ok"),
).send()

c.Assert(err, IsNil)
c.Assert(resp, Equals, "ok")

http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}

forwarderWithTLSConfigFromDefaultTransport, err := New()
c.Assert(err, IsNil)

proxyWithTLSConfigFromDefaultTransport := createProxyWithForwarder(forwarderWithTLSConfigFromDefaultTransport, srv.URL)
defer proxyWithTLSConfig.Close()

proxyAddr = proxyWithTLSConfigFromDefaultTransport.Listener.Addr().String()

resp, err = newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
).send()

c.Assert(err, IsNil)
Expand Down

0 comments on commit 6c94d28

Please sign in to comment.