Skip to content
This repository has been archived by the owner on Feb 15, 2019. It is now read-only.

WebSocket use gorilla/websocket #40

Merged
merged 3 commits into from
Nov 16, 2017
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
language: go

go:
- 1.6.x
- 1.7.x
- 1.8.x
- 1.9.x
- master
Expand Down
130 changes: 72 additions & 58 deletions forward/fwd.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ import (
"time"

"crypto/tls"
"net"
"net/http/httputil"
"reflect"

"github.com/gorilla/websocket"
log "github.com/sirupsen/logrus"
"github.com/vulcand/oxy/utils"
)
Expand Down Expand Up @@ -165,6 +165,11 @@ func New(setters ...optSetter) (*Forwarder, error) {
if f.errHandler == nil {
f.errHandler = utils.DefaultHandler
}

if f.tlsClientConfig == nil {
f.tlsClientConfig = f.httpForwarder.roundTripper.(*http.Transport).TLSClientConfig
}

return f, nil
}

Expand Down Expand Up @@ -286,57 +291,58 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request,
}

outReq := f.copyWebSocketRequest(req)
host := outReq.URL.Host

// if host does not specify a port, use the default http port
if !strings.Contains(host, ":") {
if outReq.URL.Scheme == "wss" {
host = host + ":443"
} else {
host = host + ":80"
}
}

var targetConn net.Conn
var err error

dialer := websocket.DefaultDialer
if outReq.URL.Scheme == "wss" && f.tlsClientConfig != nil {
f.log.Debugf("vulcand/oxy/forward/websocket: Dialing secure (tls) tcp connection to host %s with TLS Client Config %v", host, f.tlsClientConfig)
targetConn, err = tls.Dial("tcp", host, f.tlsClientConfig)
} else {
f.log.Debugf("vulcand/oxy/forward/websocket: Dialing insecure (non-tls) tcp connection to host %s", host)
targetConn, err = net.Dial("tcp", host)
dialer.TLSClientConfig = f.tlsClientConfig.Clone()
//WebSocket is only in http/1.1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add a space after //

dialer.TLSClientConfig.NextProtos = []string{"http/1.1"}
}

targetConn, resp, err := dialer.Dial(outReq.URL.String(), outReq.Header)
if err != nil {
f.log.Errorf("vulcand/oxy/forward/websocket: Error dialing `%v`: %v", host, err)
ctx.errHandler.ServeHTTP(w, req, err)
return
}
hijacker, ok := w.(http.Hijacker)
if !ok {
f.log.Errorf("vulcand/oxy/forward/websocket: Unable to hijack the connection: does not implement http.Hijacker. ResponseWriter implementation type: %v", reflect.TypeOf(w))
ctx.errHandler.ServeHTTP(w, req, err)
if resp == nil {
ctx.errHandler.ServeHTTP(w, req, err)
} else {
log.Errorf("vulcand/oxy/forward/websocket: Error dialing %q: %v with resp: %d %s", outReq.Host, err, resp.StatusCode, resp.Status)
hijacker, ok := w.(http.Hijacker)
if !ok {
log.Errorf("vulcand/oxy/forward/websocket: %s can not be hijack", reflect.TypeOf(w))
ctx.errHandler.ServeHTTP(w, req, err)
return
}

conn, _, err := hijacker.Hijack()
if err != nil {
log.Errorf("vulcand/oxy/forward/websocket: Failed to hijack responseWriter")
ctx.errHandler.ServeHTTP(w, req, err)
return
}
defer conn.Close()

err = resp.Write(conn)
if err != nil {
log.Errorf("vulcand/oxy/forward/websocket: Failed to forward response")
ctx.errHandler.ServeHTTP(w, req, err)
return
}
}
return
}
underlyingConn, _, err := hijacker.Hijack()

//Only the targetConn choose to CheckOrigin or not
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add a space after //

upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool {
return true
}}

utils.RemoveHeaders(resp.Header, WebsocketUpgradeHeaders...)
underlyingConn, err := upgrader.Upgrade(w, req, resp.Header)
if err != nil {
f.log.Errorf("vulcand/oxy/forward/websocket: Unable to hijack the connection: %v", err)
ctx.errHandler.ServeHTTP(w, req, err)
log.Errorf("vulcand/oxy/forward/websocket: Error while upgrading connection : %v", err)
return
}
// it is now caller's responsibility to Close the underlying connection
defer underlyingConn.Close()
defer targetConn.Close()

f.log.Infof("vulcand/oxy/forward/websocket: Writing outgoing Websocket request to target connection: %+v", outReq)

// write the modified incoming request to the dialed connection
if err = outReq.Write(targetConn); err != nil {
f.log.Errorf("vulcand/oxy/forward/websocket: Unable to copy request to target: %v", err)
ctx.errHandler.ServeHTTP(w, req, err)
return
}
errc := make(chan error, 2)
replicate := func(dst io.Writer, src io.Reader, dstName string, srcName string) {
_, err := io.Copy(dst, src)
Expand All @@ -347,21 +353,27 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request,
}
errc <- err
}
go replicate(targetConn, underlyingConn, "backend", "client")
go replicate(underlyingConn, targetConn, "client", "backend")
err = <-errc // One goroutine complete
f.log.Infof("vulcand/oxy/forward/websocket: first proxying connection closed: %v", err)
err = <-errc // Both goroutines complete
f.log.Infof("vulcand/oxy/forward/websocket: second proxying connection closed: %v", err)

go replicate(targetConn.UnderlyingConn(), underlyingConn.UnderlyingConn(), "backend", "client")

// Try to read the first message
t, msg, err := targetConn.ReadMessage()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please rename t to msgType

if err != nil {
log.Errorf("vulcand/oxy/forward/websocket: Couldn't read first message : %v", err)
} else {
underlyingConn.WriteMessage(t, msg)
}

go replicate(underlyingConn.UnderlyingConn(), targetConn.UnderlyingConn(), "client", "backend")
<-errc
}

// copyRequest makes a copy of the specified request.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please fix your comment // copyWebSocketRequest makes a copy of the specified web socket request

func (f *httpForwarder) copyWebSocketRequest(req *http.Request) (outReq *http.Request) {
outReq = new(http.Request)
*outReq = *req
outReq.URL = utils.CopyURL(req.URL)
*outReq = *req // includes shallow copies of maps, but we handle this below

//a good working default
outReq.URL = utils.CopyURL(req.URL)
outReq.URL.Scheme = req.URL.Scheme

//sometimes backends might be registered as HTTP/HTTPS servers so translate URLs to websocket URLs.
Expand All @@ -372,24 +384,26 @@ func (f *httpForwarder) copyWebSocketRequest(req *http.Request) (outReq *http.Re
outReq.URL.Scheme = "ws"
}

outReq.URL.Host = req.URL.Host
outReq.URL.Opaque = req.RequestURI

// Do not pass client Host header unless optsetter PassHostHeader is set.
if !f.passHost {
outReq.Host = req.Host
if requestURI, err := url.ParseRequestURI(outReq.RequestURI); err == nil {
if requestURI.RawPath != "" {
outReq.URL.Path = requestURI.RawPath
} else {
outReq.URL.Path = requestURI.Path
}
outReq.URL.RawQuery = requestURI.RawQuery
}

// Overwrite close flag so we can keep persistent connection for the backend servers
outReq.Close = false
outReq.URL.Host = req.URL.Host

outReq.Header = make(http.Header)
//gorilla websocket use this header to set the request.Host tested in checkSameOrigin
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add space after //

outReq.Header.Set("Host", outReq.Host)
utils.CopyHeaders(outReq.Header, req.Header)
utils.RemoveHeaders(outReq.Header, WebsocketDialHeaders...)

if f.rewriter != nil {
f.rewriter.Rewrite(outReq)
}

return outReq
}

Expand Down
82 changes: 0 additions & 82 deletions forward/fwd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package forward

import (
"fmt"
"net"
"net/http"
"net/http/httptest"
"strings"
Expand All @@ -12,7 +11,6 @@ import (
"github.com/vulcand/oxy/testutils"
"github.com/vulcand/oxy/utils"

"golang.org/x/net/websocket"
. "gopkg.in/check.v1"
)

Expand Down Expand Up @@ -279,83 +277,3 @@ func (s *FwdSuite) TestChunkedResponseConversion(c *C) {
c.Assert(re.StatusCode, Equals, http.StatusOK)
c.Assert(re.Header.Get("Content-Length"), Equals, fmt.Sprintf("%d", len("testtest1test2")))
}

func (s *FwdSuite) TestDetectsWebsocketRequest(c *C) {
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
conn.Write([]byte("ok"))
conn.Close()
}))
srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) {
websocketRequest := IsWebsocketRequest(req)
c.Assert(websocketRequest, Equals, true)
mux.ServeHTTP(w, req)
})
defer srv.Close()

serverAddr := srv.Listener.Addr().String()
resp, err := sendWebsocketRequest(serverAddr, "/ws", "echo", c)
c.Assert(err, IsNil)
c.Assert(resp, Equals, "ok")
}

func (s *FwdSuite) TestForwardsWebsocketTraffic(c *C) {
f, err := New()
c.Assert(err, IsNil)

mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
conn.Write([]byte("ok"))
conn.Close()
}))
srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
})
defer srv.Close()

proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path // keep the original path
// Set new backend URL
req.URL = testutils.ParseURI(srv.URL)
req.URL.Path = path
f.ServeHTTP(w, req)
})
defer proxy.Close()

proxyAddr := proxy.Listener.Addr().String()
resp, err := sendWebsocketRequest(proxyAddr, "/ws", "echo", c)
c.Assert(err, IsNil)
c.Assert(resp, Equals, "ok")
}

const dialTimeout = time.Second

func sendWebsocketRequest(serverAddr, path, data string, c *C) (received string, err error) {
client, err := net.DialTimeout("tcp", serverAddr, dialTimeout)
if err != nil {
return "", err
}
config := newWebsocketConfig(serverAddr, path)
conn, err := websocket.NewClient(config, client)
if err != nil {
return "", err
}
defer conn.Close()
if _, err := conn.Write([]byte(data)); err != nil {
return "", err
}
var msg = make([]byte, 512)
var n int
n, err = conn.Read(msg)
if err != nil {
return "", err
}

received = string(msg[:n])
return received, nil
}

func newWebsocketConfig(serverAddr, path string) *websocket.Config {
config, _ := websocket.NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost")
return config
}
Loading