diff --git a/forward/fwd.go b/forward/fwd.go index b020dfb5..3f09c1db 100644 --- a/forward/fwd.go +++ b/forward/fwd.go @@ -228,14 +228,35 @@ func (f *Forwarder) ServeHTTP(w http.ResponseWriter, req *http.Request) { } } +func (f *httpForwarder) getUrlFromRequest(req *http.Request) *url.URL { + // If the Request was created by Go via a real HTTP request, RequestURI will + // contain the original query string. If the Request was created in code, RequestURI + // will be empty, and we will use the URL object instead + u := req.URL + if req.RequestURI != "" { + parsedURL, err := url.ParseRequestURI(req.RequestURI) + if err == nil { + u = parsedURL + } else { + f.log.Warnf("vulcand/oxy/forward: error when parsing RequestURI: %s", err) + } + } + return u +} + // Modify the request to handle the target URL func (f *httpForwarder) modifyRequest(outReq *http.Request, target *url.URL) { outReq.URL = utils.CopyURL(outReq.URL) outReq.URL.Scheme = target.Scheme outReq.URL.Host = target.Host - outReq.URL.Opaque = outReq.RequestURI - // raw query is already included in RequestURI, so ignore it to avoid dupes - outReq.URL.RawQuery = "" + + u := f.getUrlFromRequest(outReq) + + outReq.URL.Path = u.Path + outReq.URL.RawPath = u.RawPath + outReq.URL.RawQuery = u.RawQuery + outReq.RequestURI = "" // Outgoing request should not have RequestURI + // Do not pass client Host header unless optsetter PassHostHeader is set. if !f.passHost { outReq.Host = target.Host @@ -352,14 +373,12 @@ func (f *httpForwarder) copyWebSocketRequest(req *http.Request) (outReq *http.Re outReq.URL.Scheme = "ws" } - if requestURI, err := url.ParseRequestURI(outReq.RequestURI); err == nil { - if requestURI.RawPath != "" { - outReq.URL.Path = requestURI.RawPath - } else { - outReq.URL.Path = requestURI.Path - } - outReq.URL.RawQuery = requestURI.RawQuery - } + u := f.getUrlFromRequest(outReq) + + outReq.URL.Path = u.Path + outReq.URL.RawPath = u.RawPath + outReq.URL.RawQuery = u.RawQuery + outReq.RequestURI = "" // Outgoing request should not have RequestURI outReq.URL.Host = req.URL.Host diff --git a/forward/fwd_test.go b/forward/fwd_test.go index 5a02baf8..dc3cbe2a 100644 --- a/forward/fwd_test.go +++ b/forward/fwd_test.go @@ -223,10 +223,10 @@ func (s *FwdSuite) TestCustomLogger(c *C) { c.Assert(re.StatusCode, Equals, http.StatusOK) } -func (s *FwdSuite) TestEscapedURL(c *C) { - var outURL string +func (s *FwdSuite) TestRouteForwarding(c *C) { + var outPath string srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { - outURL = req.RequestURI + outPath = req.RequestURI w.Write([]byte("hello")) }) defer srv.Close() @@ -240,16 +240,32 @@ func (s *FwdSuite) TestEscapedURL(c *C) { }) defer proxy.Close() - path := "/log/http%3A%2F%2Fwww.site.com%2Fsomething?a=b" + tests := []struct { + Path string + Query string + + ExpectedPath string + }{ + {"/hello", "", "/hello"}, + {"//hello", "", "//hello"}, + {"///hello", "", "///hello"}, + {"/hello", "abc=def&def=123", "/hello?abc=def&def=123"}, + {"/log/http%3A%2F%2Fwww.site.com%2Fsomething?a=b", "", "/log/http%3A%2F%2Fwww.site.com%2Fsomething?a=b"}, + } - request, err := http.NewRequest("GET", proxy.URL, nil) - parsed := testutils.ParseURI(proxy.URL) - parsed.Opaque = path - request.URL = parsed - re, err := http.DefaultClient.Do(request) - c.Assert(err, IsNil) - c.Assert(re.StatusCode, Equals, http.StatusOK) - c.Assert(outURL, Equals, path) + for _, test := range tests { + proxyURL := proxy.URL + test.Path + if test.Query != "" { + proxyURL = proxyURL + "?" + test.Query + } + request, err := http.NewRequest("GET", proxyURL, nil) + c.Assert(err, IsNil) + + re, err := http.DefaultClient.Do(request) + c.Assert(err, IsNil) + c.Assert(re.StatusCode, Equals, http.StatusOK) + c.Assert(outPath, Equals, test.ExpectedPath) + } } func (s *FwdSuite) TestForwardedProto(c *C) { diff --git a/forward/fwd_websocket_test.go b/forward/fwd_websocket_test.go index 53ca8a96..4cc854f5 100644 --- a/forward/fwd_websocket_test.go +++ b/forward/fwd_websocket_test.go @@ -190,7 +190,7 @@ func (s *FwdSuite) TestWebSocketRequestWithEncodedChar(c *C) { return } defer conn.Close() - c.Assert(r.URL.Path, Equals, "/%3A%2F%2F") + c.Assert(r.URL.EscapedPath(), Equals, "/%3A%2F%2F") for { mt, message, err := conn.ReadMessage() if err != nil {