diff --git a/.travis.yml b/.travis.yml index 903daec5..f6bf8479 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,6 @@ language: go go: - - 1.6.x - - 1.7.x - 1.8.x - 1.9.x - master diff --git a/forward/fwd.go b/forward/fwd.go index adb7c8ae..2e99b8a6 100644 --- a/forward/fwd.go +++ b/forward/fwd.go @@ -13,10 +13,10 @@ import ( "time" "crypto/tls" - "net" "net/http/httputil" "reflect" + "github.com/gorilla/websocket" log "github.com/sirupsen/logrus" "github.com/vulcand/oxy/utils" ) @@ -165,6 +165,11 @@ func New(setters ...optSetter) (*Forwarder, error) { if f.errHandler == nil { f.errHandler = utils.DefaultHandler } + + if f.tlsClientConfig == nil { + f.tlsClientConfig = f.httpForwarder.roundTripper.(*http.Transport).TLSClientConfig + } + return f, nil } @@ -286,57 +291,58 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, } outReq := f.copyWebSocketRequest(req) - host := outReq.URL.Host - - // if host does not specify a port, use the default http port - if !strings.Contains(host, ":") { - if outReq.URL.Scheme == "wss" { - host = host + ":443" - } else { - host = host + ":80" - } - } - - var targetConn net.Conn - var err error + dialer := websocket.DefaultDialer if outReq.URL.Scheme == "wss" && f.tlsClientConfig != nil { - f.log.Debugf("vulcand/oxy/forward/websocket: Dialing secure (tls) tcp connection to host %s with TLS Client Config %v", host, f.tlsClientConfig) - targetConn, err = tls.Dial("tcp", host, f.tlsClientConfig) - } else { - f.log.Debugf("vulcand/oxy/forward/websocket: Dialing insecure (non-tls) tcp connection to host %s", host) - targetConn, err = net.Dial("tcp", host) + dialer.TLSClientConfig = f.tlsClientConfig.Clone() + // WebSocket is only in http/1.1 + dialer.TLSClientConfig.NextProtos = []string{"http/1.1"} } - + targetConn, resp, err := dialer.Dial(outReq.URL.String(), outReq.Header) if err != nil { - f.log.Errorf("vulcand/oxy/forward/websocket: Error dialing `%v`: %v", host, err) - ctx.errHandler.ServeHTTP(w, req, err) - return - } - hijacker, ok := w.(http.Hijacker) - if !ok { - f.log.Errorf("vulcand/oxy/forward/websocket: Unable to hijack the connection: does not implement http.Hijacker. ResponseWriter implementation type: %v", reflect.TypeOf(w)) - ctx.errHandler.ServeHTTP(w, req, err) + if resp == nil { + ctx.errHandler.ServeHTTP(w, req, err) + } else { + log.Errorf("vulcand/oxy/forward/websocket: Error dialing %q: %v with resp: %d %s", outReq.Host, err, resp.StatusCode, resp.Status) + hijacker, ok := w.(http.Hijacker) + if !ok { + log.Errorf("vulcand/oxy/forward/websocket: %s can not be hijack", reflect.TypeOf(w)) + ctx.errHandler.ServeHTTP(w, req, err) + return + } + + conn, _, err := hijacker.Hijack() + if err != nil { + log.Errorf("vulcand/oxy/forward/websocket: Failed to hijack responseWriter") + ctx.errHandler.ServeHTTP(w, req, err) + return + } + defer conn.Close() + + err = resp.Write(conn) + if err != nil { + log.Errorf("vulcand/oxy/forward/websocket: Failed to forward response") + ctx.errHandler.ServeHTTP(w, req, err) + return + } + } return } - underlyingConn, _, err := hijacker.Hijack() + + // Only the targetConn choose to CheckOrigin or not + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { + return true + }} + + utils.RemoveHeaders(resp.Header, WebsocketUpgradeHeaders...) + underlyingConn, err := upgrader.Upgrade(w, req, resp.Header) if err != nil { - f.log.Errorf("vulcand/oxy/forward/websocket: Unable to hijack the connection: %v", err) - ctx.errHandler.ServeHTTP(w, req, err) + log.Errorf("vulcand/oxy/forward/websocket: Error while upgrading connection : %v", err) return } - // it is now caller's responsibility to Close the underlying connection defer underlyingConn.Close() defer targetConn.Close() - f.log.Infof("vulcand/oxy/forward/websocket: Writing outgoing Websocket request to target connection: %+v", outReq) - - // write the modified incoming request to the dialed connection - if err = outReq.Write(targetConn); err != nil { - f.log.Errorf("vulcand/oxy/forward/websocket: Unable to copy request to target: %v", err) - ctx.errHandler.ServeHTTP(w, req, err) - return - } errc := make(chan error, 2) replicate := func(dst io.Writer, src io.Reader, dstName string, srcName string) { _, err := io.Copy(dst, src) @@ -347,24 +353,30 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, } errc <- err } - go replicate(targetConn, underlyingConn, "backend", "client") - go replicate(underlyingConn, targetConn, "client", "backend") - err = <-errc // One goroutine complete - f.log.Infof("vulcand/oxy/forward/websocket: first proxying connection closed: %v", err) - err = <-errc // Both goroutines complete - f.log.Infof("vulcand/oxy/forward/websocket: second proxying connection closed: %v", err) + + go replicate(targetConn.UnderlyingConn(), underlyingConn.UnderlyingConn(), "backend", "client") + + // Try to read the first message + msgType, msg, err := targetConn.ReadMessage() + if err != nil { + log.Errorf("vulcand/oxy/forward/websocket: Couldn't read first message : %v", err) + } else { + underlyingConn.WriteMessage(msgType, msg) + } + + go replicate(underlyingConn.UnderlyingConn(), targetConn.UnderlyingConn(), "client", "backend") + <-errc } -// copyRequest makes a copy of the specified request. +// copyWebsocketRequest makes a copy of the specified request. func (f *httpForwarder) copyWebSocketRequest(req *http.Request) (outReq *http.Request) { outReq = new(http.Request) - *outReq = *req - outReq.URL = utils.CopyURL(req.URL) + *outReq = *req // includes shallow copies of maps, but we handle this below - //a good working default + outReq.URL = utils.CopyURL(req.URL) outReq.URL.Scheme = req.URL.Scheme - //sometimes backends might be registered as HTTP/HTTPS servers so translate URLs to websocket URLs. + // sometimes backends might be registered as HTTP/HTTPS servers so translate URLs to websocket URLs. switch req.URL.Scheme { case "https": outReq.URL.Scheme = "wss" @@ -372,24 +384,26 @@ func (f *httpForwarder) copyWebSocketRequest(req *http.Request) (outReq *http.Re outReq.URL.Scheme = "ws" } - outReq.URL.Host = req.URL.Host - outReq.URL.Opaque = req.RequestURI - - // Do not pass client Host header unless optsetter PassHostHeader is set. - if !f.passHost { - outReq.Host = req.Host + if requestURI, err := url.ParseRequestURI(outReq.RequestURI); err == nil { + if requestURI.RawPath != "" { + outReq.URL.Path = requestURI.RawPath + } else { + outReq.URL.Path = requestURI.Path + } + outReq.URL.RawQuery = requestURI.RawQuery } - // Overwrite close flag so we can keep persistent connection for the backend servers - outReq.Close = false + outReq.URL.Host = req.URL.Host outReq.Header = make(http.Header) + // gorilla websocket use this header to set the request.Host tested in checkSameOrigin + outReq.Header.Set("Host", outReq.Host) utils.CopyHeaders(outReq.Header, req.Header) + utils.RemoveHeaders(outReq.Header, WebsocketDialHeaders...) if f.rewriter != nil { f.rewriter.Rewrite(outReq) } - return outReq } diff --git a/forward/fwd_test.go b/forward/fwd_test.go index f8a9d672..62420074 100644 --- a/forward/fwd_test.go +++ b/forward/fwd_test.go @@ -2,7 +2,6 @@ package forward import ( "fmt" - "net" "net/http" "net/http/httptest" "strings" @@ -12,7 +11,6 @@ import ( "github.com/vulcand/oxy/testutils" "github.com/vulcand/oxy/utils" - "golang.org/x/net/websocket" . "gopkg.in/check.v1" ) @@ -279,83 +277,3 @@ func (s *FwdSuite) TestChunkedResponseConversion(c *C) { c.Assert(re.StatusCode, Equals, http.StatusOK) c.Assert(re.Header.Get("Content-Length"), Equals, fmt.Sprintf("%d", len("testtest1test2"))) } - -func (s *FwdSuite) TestDetectsWebsocketRequest(c *C) { - mux := http.NewServeMux() - mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { - conn.Write([]byte("ok")) - conn.Close() - })) - srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { - websocketRequest := IsWebsocketRequest(req) - c.Assert(websocketRequest, Equals, true) - mux.ServeHTTP(w, req) - }) - defer srv.Close() - - serverAddr := srv.Listener.Addr().String() - resp, err := sendWebsocketRequest(serverAddr, "/ws", "echo", c) - c.Assert(err, IsNil) - c.Assert(resp, Equals, "ok") -} - -func (s *FwdSuite) TestForwardsWebsocketTraffic(c *C) { - f, err := New() - c.Assert(err, IsNil) - - mux := http.NewServeMux() - mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { - conn.Write([]byte("ok")) - conn.Close() - })) - srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { - mux.ServeHTTP(w, req) - }) - defer srv.Close() - - proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { - path := req.URL.Path // keep the original path - // Set new backend URL - req.URL = testutils.ParseURI(srv.URL) - req.URL.Path = path - f.ServeHTTP(w, req) - }) - defer proxy.Close() - - proxyAddr := proxy.Listener.Addr().String() - resp, err := sendWebsocketRequest(proxyAddr, "/ws", "echo", c) - c.Assert(err, IsNil) - c.Assert(resp, Equals, "ok") -} - -const dialTimeout = time.Second - -func sendWebsocketRequest(serverAddr, path, data string, c *C) (received string, err error) { - client, err := net.DialTimeout("tcp", serverAddr, dialTimeout) - if err != nil { - return "", err - } - config := newWebsocketConfig(serverAddr, path) - conn, err := websocket.NewClient(config, client) - if err != nil { - return "", err - } - defer conn.Close() - if _, err := conn.Write([]byte(data)); err != nil { - return "", err - } - var msg = make([]byte, 512) - var n int - n, err = conn.Read(msg) - if err != nil { - return "", err - } - - received = string(msg[:n]) - return received, nil -} - -func newWebsocketConfig(serverAddr, path string) *websocket.Config { - config, _ := websocket.NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost") - return config -} diff --git a/forward/fwd_websocket_test.go b/forward/fwd_websocket_test.go new file mode 100644 index 00000000..b17cbee7 --- /dev/null +++ b/forward/fwd_websocket_test.go @@ -0,0 +1,500 @@ +package forward + +import ( + "bufio" + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "time" + + gorillawebsocket "github.com/gorilla/websocket" + "github.com/vulcand/oxy/testutils" + "golang.org/x/net/websocket" + . "gopkg.in/check.v1" +) + +func (s *FwdSuite) TestWebSocketEcho(c *C) { + f, err := New() + c.Assert(err, IsNil) + + mux := http.NewServeMux() + mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { + msg := make([]byte, 4) + conn.Read(msg) + c.Log(msg) + conn.Write(msg) + conn.Close() + })) + srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { + mux.ServeHTTP(w, req) + }) + defer srv.Close() + proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { + req.URL = testutils.ParseURI(srv.URL) + f.ServeHTTP(w, req) + }) + serverAddr := proxy.Listener.Addr().String() + c.Log(serverAddr) + headers := http.Header{} + webSocketURL := "ws://" + serverAddr + "/ws" + headers.Add("Origin", webSocketURL) + conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers) + if err != nil { + c.Errorf("Error [%s] during Dial with response: %+v", err, resp) + return + } + conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK")) + c.Log(conn.ReadMessage()) + +} + +func (s *FwdSuite) TestWebSocketServerWithoutCheckOrigin(c *C) { + f, err := New() + c.Assert(err, IsNil) + + upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(r *http.Request) bool { + return true + }} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer c.Close() + for { + mt, message, err := c.ReadMessage() + if err != nil { + break + } + err = c.WriteMessage(mt, message) + if err != nil { + break + } + } + })) + defer srv.Close() + + proxy := createProxyWithForwarder(f, srv.URL) + defer proxy.Close() + + proxyAddr := proxy.Listener.Addr().String() + resp, err := newWebsocketRequest( + withServer(proxyAddr), + withPath("/ws"), + withData("ok"), + withOrigin("http://127.0.0.2"), + ).send() + + c.Assert(err, IsNil) + c.Assert(resp, Equals, "ok") +} +func (s *FwdSuite) TestWebSocketRequestWithOrigin(c *C) { + f, err := New() + c.Assert(err, IsNil) + + upgrader := gorillawebsocket.Upgrader{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer c.Close() + for { + mt, message, err := c.ReadMessage() + if err != nil { + break + } + err = c.WriteMessage(mt, message) + if err != nil { + break + } + } + })) + defer srv.Close() + + proxy := createProxyWithForwarder(f, srv.URL) + defer proxy.Close() + + proxyAddr := proxy.Listener.Addr().String() + _, err = newWebsocketRequest( + withServer(proxyAddr), + withPath("/ws"), + withData("echo"), + withOrigin("http://127.0.0.2"), + ).send() + + c.Assert(err, NotNil) + c.Assert(err, ErrorMatches, "bad status") + + resp, err := newWebsocketRequest( + withServer(proxyAddr), + withPath("/ws"), + withData("ok"), + ).send() + + c.Assert(err, IsNil) + c.Assert(resp, Equals, "ok") +} + +func (s *FwdSuite) TestWebSocketRequestWithQueryParams(c *C) { + f, err := New() + c.Assert(err, IsNil) + + upgrader := gorillawebsocket.Upgrader{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + c.Assert(r.URL.Query().Get("query"), Equals, "test") + for { + mt, message, err := conn.ReadMessage() + if err != nil { + break + } + err = conn.WriteMessage(mt, message) + if err != nil { + break + } + } + })) + defer srv.Close() + + proxy := createProxyWithForwarder(f, srv.URL) + defer proxy.Close() + + proxyAddr := proxy.Listener.Addr().String() + + resp, err := newWebsocketRequest( + withServer(proxyAddr), + withPath("/ws?query=test"), + withData("ok"), + ).send() + + c.Assert(err, IsNil) + c.Assert(resp, Equals, "ok") +} + +func (s *FwdSuite) TestWebSocketRequestWithEncodedChar(c *C) { + f, err := New() + c.Assert(err, IsNil) + + upgrader := gorillawebsocket.Upgrader{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + c.Assert(r.URL.Path, Equals, "/%3A%2F%2F") + for { + mt, message, err := conn.ReadMessage() + if err != nil { + break + } + err = conn.WriteMessage(mt, message) + if err != nil { + break + } + } + })) + defer srv.Close() + + proxy := createProxyWithForwarder(f, srv.URL) + defer proxy.Close() + + proxyAddr := proxy.Listener.Addr().String() + + resp, err := newWebsocketRequest( + withServer(proxyAddr), + withPath("/%3A%2F%2F"), + withData("ok"), + ).send() + + c.Assert(err, IsNil) + c.Assert(resp, Equals, "ok") +} + +func (s *FwdSuite) TestDetectsWebSocketRequest(c *C) { + mux := http.NewServeMux() + mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { + conn.Write([]byte("ok")) + conn.Close() + })) + srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { + websocketRequest := IsWebsocketRequest(req) + c.Assert(websocketRequest, Equals, true) + mux.ServeHTTP(w, req) + }) + defer srv.Close() + + serverAddr := srv.Listener.Addr().String() + + resp, err := newWebsocketRequest( + withServer(serverAddr), + withPath("/ws"), + withData("echo"), + ).send() + + c.Assert(err, IsNil) + c.Assert(resp, Equals, "ok") +} + +func (s *FwdSuite) TestWebSocketUpgradeFailed(c *C) { + f, err := New() + c.Assert(err, IsNil) + + mux := http.NewServeMux() + mux.HandleFunc("/ws", func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(400) + }) + srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { + mux.ServeHTTP(w, req) + }) + defer srv.Close() + + proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { + path := req.URL.Path // keep the original path + + if path == "/ws" { + // Set new backend URL + req.URL = testutils.ParseURI(srv.URL) + req.URL.Path = path + websocketRequest := IsWebsocketRequest(req) + c.Assert(websocketRequest, Equals, true) + f.ServeHTTP(w, req) + } else { + w.WriteHeader(200) + } + }) + defer proxy.Close() + + proxyAddr := proxy.Listener.Addr().String() + conn, err := net.DialTimeout("tcp", proxyAddr, dialTimeout) + + c.Assert(err, IsNil) + defer conn.Close() + + req, err := http.NewRequest(http.MethodGet, "ws://127.0.0.1/ws", nil) + c.Assert(err, IsNil) + + req.Header.Add("upgrade", "websocket") + req.Header.Add("Connection", "upgrade") + + req.Write(conn) + + // First request works with 400 + br := bufio.NewReader(conn) + resp, err := http.ReadResponse(br, req) + + c.Assert(resp.StatusCode, Equals, 400) + + req, err = http.NewRequest(http.MethodGet, "ws://127.0.0.1/ws2", nil) + req.Header.Add("upgrade", "websocket") + req.Header.Add("Connection", "upgrade") + req.Write(conn) + + br = bufio.NewReader(conn) + resp, err = http.ReadResponse(br, req) + c.Assert(err, Equals, io.ErrUnexpectedEOF) +} + +func (s *FwdSuite) TestForwardsWebsocketTraffic(c *C) { + f, err := New() + c.Assert(err, IsNil) + + mux := http.NewServeMux() + mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { + conn.Write([]byte("ok")) + conn.Close() + })) + srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { + mux.ServeHTTP(w, req) + }) + defer srv.Close() + + proxy := createProxyWithForwarder(f, srv.URL) + defer proxy.Close() + + proxyAddr := proxy.Listener.Addr().String() + resp, err := newWebsocketRequest( + withServer(proxyAddr), + withPath("/ws"), + withData("echo"), + ).send() + + c.Assert(err, IsNil) + c.Assert(resp, Equals, "ok") +} + +func createTLSWebsocketServer() *httptest.Server { + upgrader := gorillawebsocket.Upgrader{} + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer conn.Close() + for { + mt, message, err := conn.ReadMessage() + if err != nil { + break + } + err = conn.WriteMessage(mt, message) + if err != nil { + break + } + } + })) + return srv +} + +func createProxyWithForwarder(forwarder *Forwarder, URL string) *httptest.Server { + return testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { + path := req.URL.Path // keep the original path + // Set new backend URL + req.URL = testutils.ParseURI(URL) + req.URL.Path = path + + forwarder.ServeHTTP(w, req) + }) +} + +func (s *FwdSuite) TestWebSocketTransferTLSConfig(c *C) { + srv := createTLSWebsocketServer() + defer srv.Close() + + forwarderWithoutTLSConfig, err := New() + c.Assert(err, IsNil) + + proxyWithoutTLSConfig := createProxyWithForwarder(forwarderWithoutTLSConfig, srv.URL) + defer proxyWithoutTLSConfig.Close() + + proxyAddr := proxyWithoutTLSConfig.Listener.Addr().String() + + _, err = newWebsocketRequest( + withServer(proxyAddr), + withPath("/ws"), + withData("ok"), + ).send() + + c.Assert(err, NotNil) + c.Assert(err, ErrorMatches, "bad status") + + transport := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + forwarderWithTLSConfig, err := New(RoundTripper(transport)) + c.Assert(err, IsNil) + + proxyWithTLSConfig := createProxyWithForwarder(forwarderWithTLSConfig, srv.URL) + defer proxyWithTLSConfig.Close() + + proxyAddr = proxyWithTLSConfig.Listener.Addr().String() + + resp, err := newWebsocketRequest( + withServer(proxyAddr), + withPath("/ws"), + withData("ok"), + ).send() + + c.Assert(err, IsNil) + c.Assert(resp, Equals, "ok") + + http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + + forwarderWithTLSConfigFromDefaultTransport, err := New() + c.Assert(err, IsNil) + + proxyWithTLSConfigFromDefaultTransport := createProxyWithForwarder(forwarderWithTLSConfigFromDefaultTransport, srv.URL) + defer proxyWithTLSConfig.Close() + + proxyAddr = proxyWithTLSConfigFromDefaultTransport.Listener.Addr().String() + + resp, err = newWebsocketRequest( + withServer(proxyAddr), + withPath("/ws"), + withData("ok"), + ).send() + + c.Assert(err, IsNil) + c.Assert(resp, Equals, "ok") +} + +const dialTimeout = time.Second + +type websocketRequestOpt func(w *websocketRequest) + +func withServer(server string) websocketRequestOpt { + return func(w *websocketRequest) { + w.ServerAddr = server + } +} + +func withPath(path string) websocketRequestOpt { + return func(w *websocketRequest) { + w.Path = path + } +} + +func withData(data string) websocketRequestOpt { + return func(w *websocketRequest) { + w.Data = data + } +} + +func withOrigin(origin string) websocketRequestOpt { + return func(w *websocketRequest) { + w.Origin = origin + } +} + +func newWebsocketRequest(opts ...websocketRequestOpt) *websocketRequest { + wsrequest := &websocketRequest{} + for _, opt := range opts { + opt(wsrequest) + } + if wsrequest.Origin == "" { + wsrequest.Origin = "http://" + wsrequest.ServerAddr + } + if wsrequest.Config == nil { + wsrequest.Config, _ = websocket.NewConfig(fmt.Sprintf("ws://%s%s", wsrequest.ServerAddr, wsrequest.Path), wsrequest.Origin) + } + return wsrequest +} + +type websocketRequest struct { + ServerAddr string + Path string + Data string + Origin string + Config *websocket.Config +} + +func (w *websocketRequest) send() (string, error) { + client, err := net.DialTimeout("tcp", w.ServerAddr, dialTimeout) + if err != nil { + return "", err + } + conn, err := websocket.NewClient(w.Config, client) + if err != nil { + return "", err + } + defer conn.Close() + if _, err := conn.Write([]byte(w.Data)); err != nil { + return "", err + } + var msg = make([]byte, 512) + var n int + n, err = conn.Read(msg) + if err != nil { + return "", err + } + + received := string(msg[:n]) + return received, nil +} diff --git a/forward/headers.go b/forward/headers.go index 0702ed8f..99a64678 100644 --- a/forward/headers.go +++ b/forward/headers.go @@ -1,19 +1,23 @@ package forward const ( - XForwardedProto = "X-Forwarded-Proto" - XForwardedFor = "X-Forwarded-For" - XForwardedHost = "X-Forwarded-Host" - XForwardedServer = "X-Forwarded-Server" - Connection = "Connection" - KeepAlive = "Keep-Alive" - ProxyAuthenticate = "Proxy-Authenticate" - ProxyAuthorization = "Proxy-Authorization" - Te = "Te" // canonicalized version of "TE" - Trailers = "Trailers" - TransferEncoding = "Transfer-Encoding" - Upgrade = "Upgrade" - ContentLength = "Content-Length" + XForwardedProto = "X-Forwarded-Proto" + XForwardedFor = "X-Forwarded-For" + XForwardedHost = "X-Forwarded-Host" + XForwardedServer = "X-Forwarded-Server" + Connection = "Connection" + KeepAlive = "Keep-Alive" + ProxyAuthenticate = "Proxy-Authenticate" + ProxyAuthorization = "Proxy-Authorization" + Te = "Te" // canonicalized version of "TE" + Trailers = "Trailers" + TransferEncoding = "Transfer-Encoding" + Upgrade = "Upgrade" + ContentLength = "Content-Length" + SecWebsocketKey = "Sec-Websocket-Key" + SecWebsocketVersion = "Sec-Websocket-Version" + SecWebsocketExtensions = "Sec-Websocket-Extensions" + SecWebsocketAccept = "Sec-Websocket-Accept" ) // Hop-by-hop headers. These are removed when sent to the backend. @@ -29,3 +33,18 @@ var HopHeaders = []string{ TransferEncoding, Upgrade, } + +var WebsocketDialHeaders = []string{ + Upgrade, + Connection, + SecWebsocketKey, + SecWebsocketVersion, + SecWebsocketExtensions, + SecWebsocketAccept, +} + +var WebsocketUpgradeHeaders = []string{ + Upgrade, + Connection, + SecWebsocketAccept, +}