From 9559b037e79ad673c71f6ef7c732c00949014cd2 Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Wed, 26 Oct 2022 14:55:23 +0200 Subject: [PATCH] gzhttp: Always delete `HeaderNoCompression` (#683) * gzhttp: Always delete `HeaderNoCompression` Also when it cannot be gzipped. * Also remove header when starting to write --- gzhttp/compress.go | 78 +++++++++++++++++++++++++++++++++++++---- gzhttp/compress_test.go | 66 +++++++++++++++++++++++++++++++++- 2 files changed, 137 insertions(+), 7 deletions(-) diff --git a/gzhttp/compress.go b/gzhttp/compress.go index 91503436f5..da742dd8f9 100644 --- a/gzhttp/compress.go +++ b/gzhttp/compress.go @@ -100,12 +100,13 @@ func (w *GzipResponseWriter) Write(b []byte) (int, error) { } w.buf = append(w.buf, b[:toAdd]...) remain := b[toAdd:] + hdr := w.Header() // Only continue if they didn't already choose an encoding or a known unhandled content length or type. - if len(w.Header()[HeaderNoCompression]) == 0 && w.Header().Get(contentEncoding) == "" && w.Header().Get(contentRange) == "" { + if len(hdr[HeaderNoCompression]) == 0 && hdr.Get(contentEncoding) == "" && hdr.Get(contentRange) == "" { // Check more expensive parts now. - cl, _ := atoi(w.Header().Get(contentLength)) - ct := w.Header().Get(contentType) + cl, _ := atoi(hdr.Get(contentLength)) + ct := hdr.Get(contentType) if cl == 0 || cl >= w.minSize && (ct == "" || w.contentTypeFilter(ct)) { // If the current buffer is less than minSize and a Content-Length isn't set, then wait until we have more data. if len(w.buf) < w.minSize && cl == 0 { @@ -121,8 +122,8 @@ func (w *GzipResponseWriter) Write(b []byte) (int, error) { // Handles the intended case of setting a nil Content-Type (as for http/server or http/fs) // Set the header only if the key does not exist - if _, ok := w.Header()[contentType]; w.setContentType && !ok { - w.Header().Set(contentType, ct) + if _, ok := hdr[contentType]; w.setContentType && !ok { + hdr.Set(contentType, ct) } // If the Content-Type is acceptable to GZIP, initialize the GZIP writer. @@ -388,7 +389,8 @@ func NewWrapper(opts ...option) (func(http.Handler) http.HandlerFunc, error) { h.ServeHTTP(gw, r) } } else { - h.ServeHTTP(w, r) + h.ServeHTTP(newNoCompressResponseWriter(w), r) + w.Header().Del(HeaderNoCompression) } } }, nil @@ -743,3 +745,67 @@ func atoi(s string) (int, bool) { i64, err := strconv.ParseInt(s, 10, 0) return int(i64), err == nil } + +// newNoCompressResponseWriter will return a response writer that +// cleans up compression artifacts. +// Depending on whether http.Hijacker is supported the returned will as well. +func newNoCompressResponseWriter(w http.ResponseWriter) http.ResponseWriter { + n := &noCompressResponseWriter{hw: w} + if hj, ok := w.(http.Hijacker); ok { + x := struct { + http.ResponseWriter + http.Hijacker + http.Flusher + }{ + ResponseWriter: n, + Hijacker: hj, + Flusher: n, + } + return x + } + + return n +} + +// noCompressResponseWriter filters out HeaderNoCompression. +type noCompressResponseWriter struct { + hw http.ResponseWriter + hdrCleaned bool +} + +func (n *noCompressResponseWriter) CloseNotify() <-chan bool { + if cn, ok := n.hw.(http.CloseNotifier); ok { + return cn.CloseNotify() + } + return nil +} + +func (n *noCompressResponseWriter) Flush() { + if !n.hdrCleaned { + n.hw.Header().Del(HeaderNoCompression) + n.hdrCleaned = true + } + if f, ok := n.hw.(http.Flusher); ok { + f.Flush() + } +} + +func (n *noCompressResponseWriter) Header() http.Header { + return n.hw.Header() +} + +func (n *noCompressResponseWriter) Write(bytes []byte) (int, error) { + if !n.hdrCleaned { + n.hw.Header().Del(HeaderNoCompression) + n.hdrCleaned = true + } + return n.hw.Write(bytes) +} + +func (n *noCompressResponseWriter) WriteHeader(statusCode int) { + if !n.hdrCleaned { + n.hw.Header().Del(HeaderNoCompression) + n.hdrCleaned = true + } + n.hw.WriteHeader(statusCode) +} diff --git a/gzhttp/compress_test.go b/gzhttp/compress_test.go index cfeb350f29..262ee52495 100644 --- a/gzhttp/compress_test.go +++ b/gzhttp/compress_test.go @@ -748,9 +748,9 @@ func TestContentTypes(t *testing.T) { }) t.Run("disable-"+tt.name, func(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", tt.contentType) w.Header().Set(HeaderNoCompression, "plz") + w.WriteHeader(http.StatusOK) w.Write(testBody) }) @@ -765,6 +765,70 @@ func TestContentTypes(t *testing.T) { assertEqual(t, 200, res.StatusCode) assertNotEqual(t, "gzip", res.Header.Get("Content-Encoding")) + _, ok := res.Header[HeaderNoCompression] + assertEqual(t, false, ok) + }) + t.Run("head-req"+tt.name, func(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", tt.contentType) + w.Header().Set(HeaderNoCompression, "plz") + w.WriteHeader(http.StatusOK) + }) + + wrapper, err := NewWrapper(ContentTypes(tt.acceptedContentTypes)) + assertNil(t, err) + + req, _ := http.NewRequest("HEAD", "/whatever", nil) + req.Header.Set("Accept-Encoding", "gzip") + resp := httptest.NewRecorder() + wrapper(handler).ServeHTTP(resp, req) + res := resp.Result() + + assertEqual(t, 200, res.StatusCode) + assertNotEqual(t, "gzip", res.Header.Get("Content-Encoding")) + _, ok := res.Header[HeaderNoCompression] + assertEqual(t, false, ok) + }) + t.Run("head-req-no-ok"+tt.name, func(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", tt.contentType) + w.Header().Set(HeaderNoCompression, "plz") + }) + + wrapper, err := NewWrapper(ContentTypes(tt.acceptedContentTypes)) + assertNil(t, err) + + req, _ := http.NewRequest("HEAD", "/whatever", nil) + req.Header.Set("Accept-Encoding", "gzip") + resp := httptest.NewRecorder() + wrapper(handler).ServeHTTP(resp, req) + res := resp.Result() + + assertEqual(t, 200, res.StatusCode) + assertNotEqual(t, "gzip", res.Header.Get("Content-Encoding")) + _, ok := res.Header[HeaderNoCompression] + assertEqual(t, false, ok) + }) + t.Run("req-no-ok-write"+tt.name, func(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", tt.contentType) + w.Header().Set(HeaderNoCompression, "plz") + w.Write(testBody) + }) + + wrapper, err := NewWrapper(ContentTypes(tt.acceptedContentTypes)) + assertNil(t, err) + + req, _ := http.NewRequest("GET", "/whatever", nil) + req.Header.Set("Accept-Encoding", "") + resp := httptest.NewRecorder() + wrapper(handler).ServeHTTP(resp, req) + res := resp.Result() + + assertEqual(t, 200, res.StatusCode) + assertNotEqual(t, "gzip", res.Header.Get("Content-Encoding")) + _, ok := res.Header[HeaderNoCompression] + assertEqual(t, false, ok) }) } }