diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go index 44cc2f9d9a6..5424da4d0d0 100644 --- a/modules/caddyhttp/reverseproxy/reverseproxy.go +++ b/modules/caddyhttp/reverseproxy/reverseproxy.go @@ -17,6 +17,8 @@ package reverseproxy import ( "bytes" "context" + "crypto/rand" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -398,6 +400,23 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht return caddyhttp.Error(http.StatusInternalServerError, fmt.Errorf("preparing request for upstream round-trip: %v", err)) } + // websocket over http2, assuming backend doesn't support this, the request will be modifided to http1.1 upgrade + // TODO: once we can reliably detect backend support this, it can be removed for those backends + if r.ProtoMajor == 2 && r.Method == http.MethodConnect && r.Header.Get(":protocol") != "" { + clonedReq.Header.Del(":protocol") + // keep the body for later use. http1.1 upgrade uses http.NoBody + caddyhttp.SetVar(clonedReq.Context(), "h2_websocket_body", clonedReq.Body) + clonedReq.Body = http.NoBody + clonedReq.Method = http.MethodGet + clonedReq.Header.Set("Upgrade", r.Header.Get(":protocol")) + clonedReq.Header.Set("Connection", "Upgrade") + key := make([]byte, 16) + _, randErr := rand.Read(key) + if randErr != nil { + return randErr + } + clonedReq.Header.Set("Sec-Websocket-Key", base64.StdEncoding.EncodeToString(key)) + } // we will need the original headers and Host value if // header operations are configured; this is so that each diff --git a/modules/caddyhttp/reverseproxy/streaming.go b/modules/caddyhttp/reverseproxy/streaming.go index a696ac7fbed..6f2c3390074 100644 --- a/modules/caddyhttp/reverseproxy/streaming.go +++ b/modules/caddyhttp/reverseproxy/streaming.go @@ -19,9 +19,11 @@ package reverseproxy import ( + "bufio" "context" "errors" "fmt" + "github.com/caddyserver/caddy/v2/modules/caddyhttp" "io" weakrand "math/rand" "mime" @@ -34,6 +36,24 @@ import ( "golang.org/x/net/http/httpguts" ) +type h2ReadWriteCloser struct { + io.ReadCloser + http.ResponseWriter +} + +func (rwc h2ReadWriteCloser) Write(p []byte) (n int, err error) { + n, err = rwc.ResponseWriter.Write(p) + if err != nil { + return 0, err + } + + err = http.NewResponseController(rwc.ResponseWriter).Flush() + if err != nil { + return 0, err + } + return n, nil +} + func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, rw http.ResponseWriter, req *http.Request, res *http.Response) { reqUpType := upgradeType(req.Header) resUpType := upgradeType(res.Header) @@ -61,20 +81,45 @@ func (h *Handler) handleUpgradeResponse(logger *zap.Logger, wg *sync.WaitGroup, // write header first, response headers should not be counted in size // like the rest of handler chain. copyHeader(rw.Header(), res.Header) - rw.WriteHeader(res.StatusCode) - logger.Debug("upgrading connection") + var ( + conn io.ReadWriteCloser + brw *bufio.ReadWriter + ) + // websocket over http2, assuming backend doesn't support this, the request will be modifided to http1.1 upgrade + // TODO: once we can reliably detect backend support this, it can be removed for those backends + if body, ok := caddyhttp.GetVar(req.Context(), "h2_websocket_body").(io.ReadCloser); ok { + req.Body = body + rw.Header().Del("Upgrade") + rw.Header().Del("Connection") + rw.Header().Del("Sec-Websocket-Accept") + rw.WriteHeader(http.StatusOK) + + flushErr := http.NewResponseController(rw).Flush() + if flushErr != nil { + h.logger.Error("failed to flush http2 websocket response", zap.Error(flushErr)) + return + } + conn = h2ReadWriteCloser{req.Body, rw} + // bufio is not needed, use minimal buffer + brw = bufio.NewReadWriter(bufio.NewReaderSize(conn, 1), bufio.NewWriterSize(conn, 1)) + } else { + rw.WriteHeader(res.StatusCode) + + logger.Debug("upgrading connection") - //nolint:bodyclose - conn, brw, hijackErr := http.NewResponseController(rw).Hijack() - if errors.Is(hijackErr, http.ErrNotSupported) { - h.logger.Error("can't switch protocols using non-Hijacker ResponseWriter", zap.String("type", fmt.Sprintf("%T", rw))) - return - } + var hijackErr error + //nolint:bodyclose + conn, brw, hijackErr = http.NewResponseController(rw).Hijack() + if errors.Is(hijackErr, http.ErrNotSupported) { + h.logger.Error("can't switch protocols using non-Hijacker ResponseWriter", zap.String("type", fmt.Sprintf("%T", rw))) + return + } - if hijackErr != nil { - h.logger.Error("hijack failed on protocol switch", zap.Error(hijackErr)) - return + if hijackErr != nil { + h.logger.Error("hijack failed on protocol switch", zap.Error(hijackErr)) + return + } } // adopted from https://github.com/golang/go/commit/8bcf2834afdf6a1f7937390903a41518715ef6f5