diff --git a/middleware/compress.go b/middleware/compress.go index 9e5f61069..cbe29fc32 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -2,6 +2,7 @@ package middleware import ( "bufio" + "bytes" "compress/gzip" "io" "net" @@ -21,12 +22,30 @@ type ( // Gzip compression level. // Optional. Default value -1. Level int `yaml:"level"` + + // Length threshold before gzip compression is applied. + // Optional. Default value 0. + // + // Most of the time you will not need to change the default. Compressing + // a short response might increase the transmitted data because of the + // gzip format overhead. Compressing the response will also consume CPU + // and time on the server and the client (for decompressing). Depending on + // your use case such a threshold might be useful. + // + // See also: + // https://webmasters.stackexchange.com/questions/31750/what-is-recommended-minimum-object-size-for-gzip-performance-benefits + MinLength int } gzipResponseWriter struct { io.Writer http.ResponseWriter - wroteBody bool + wroteHeader bool + wroteBody bool + minLength int + minLengthExceeded bool + buffer *bytes.Buffer + code int } ) @@ -37,8 +56,9 @@ const ( var ( // DefaultGzipConfig is the default Gzip middleware config. DefaultGzipConfig = GzipConfig{ - Skipper: DefaultSkipper, - Level: -1, + Skipper: DefaultSkipper, + Level: -1, + MinLength: 0, } ) @@ -58,8 +78,12 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { if config.Level == 0 { config.Level = DefaultGzipConfig.Level } + if config.MinLength < 0 { + config.MinLength = DefaultGzipConfig.MinLength + } pool := gzipCompressPool(config) + bpool := bufferPool() return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -70,7 +94,6 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { res := c.Response() res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding) if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) { - res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 i := pool.Get() w, ok := i.(*gzip.Writer) if !ok { @@ -78,7 +101,11 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { } rw := res.Writer w.Reset(rw) - grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw} + + buf := bpool.Get().(*bytes.Buffer) + buf.Reset() + + grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf} defer func() { if !grw.wroteBody { if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme { @@ -89,8 +116,17 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { // See issue #424, #407. res.Writer = rw w.Reset(io.Discard) + } else if !grw.minLengthExceeded { + // Write uncompressed response + res.Writer = rw + if grw.wroteHeader { + grw.ResponseWriter.WriteHeader(grw.code) + } + grw.buffer.WriteTo(rw) + w.Reset(io.Discard) } w.Close() + bpool.Put(buf) pool.Put(w) }() res.Writer = grw @@ -102,7 +138,11 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { func (w *gzipResponseWriter) WriteHeader(code int) { w.Header().Del(echo.HeaderContentLength) // Issue #444 - w.ResponseWriter.WriteHeader(code) + + w.wroteHeader = true + + // Delay writing of the header until we know if we'll actually compress the response + w.code = code } func (w *gzipResponseWriter) Write(b []byte) (int, error) { @@ -110,10 +150,40 @@ func (w *gzipResponseWriter) Write(b []byte) (int, error) { w.Header().Set(echo.HeaderContentType, http.DetectContentType(b)) } w.wroteBody = true + + if !w.minLengthExceeded { + n, err := w.buffer.Write(b) + + if w.buffer.Len() >= w.minLength { + w.minLengthExceeded = true + + // The minimum length is exceeded, add Content-Encoding header and write the header + w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 + if w.wroteHeader { + w.ResponseWriter.WriteHeader(w.code) + } + + return w.Writer.Write(w.buffer.Bytes()) + } + + return n, err + } + return w.Writer.Write(b) } func (w *gzipResponseWriter) Flush() { + if !w.minLengthExceeded { + // Enforce compression because we will not know how much more data will come + w.minLengthExceeded = true + w.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 + if w.wroteHeader { + w.ResponseWriter.WriteHeader(w.code) + } + + w.Writer.Write(w.buffer.Bytes()) + } + w.Writer.(*gzip.Writer).Flush() if flusher, ok := w.ResponseWriter.(http.Flusher); ok { flusher.Flush() @@ -142,3 +212,12 @@ func gzipCompressPool(config GzipConfig) sync.Pool { }, } } + +func bufferPool() sync.Pool { + return sync.Pool{ + New: func() interface{} { + b := &bytes.Buffer{} + return b + }, + } +} diff --git a/middleware/compress_test.go b/middleware/compress_test.go index 714548e8b..e43e2d633 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -88,6 +88,123 @@ func TestGzip(t *testing.T) { assert.Equal(t, "test", buf.String()) } +func TestGzipWithMinLength(t *testing.T) { + assert := assert.New(t) + + e := echo.New() + // Minimal response length + e.Use(GzipWithConfig(GzipConfig{MinLength: 10})) + e.GET("/", func(c echo.Context) error { + c.Response().Write([]byte("foobarfoobar")) + return nil + }) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + r, err := gzip.NewReader(rec.Body) + if assert.NoError(err) { + buf := new(bytes.Buffer) + defer r.Close() + buf.ReadFrom(r) + assert.Equal("foobarfoobar", buf.String()) + } +} + +func TestGzipWithMinLengthTooShort(t *testing.T) { + assert := assert.New(t) + + e := echo.New() + // Minimal response length + e.Use(GzipWithConfig(GzipConfig{MinLength: 10})) + e.GET("/", func(c echo.Context) error { + c.Response().Write([]byte("test")) + return nil + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal("", rec.Header().Get(echo.HeaderContentEncoding)) + assert.Contains(rec.Body.String(), "test") +} + +func TestGzipWithMinLengthChunked(t *testing.T) { + assert := assert.New(t) + + e := echo.New() + + // Gzip chunked + chunkBuf := make([]byte, 5) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + + var r *gzip.Reader = nil + + c := e.NewContext(req, rec) + GzipWithConfig(GzipConfig{MinLength: 10})(func(c echo.Context) error { + c.Response().Header().Set("Content-Type", "text/event-stream") + c.Response().Header().Set("Transfer-Encoding", "chunked") + + // Write and flush the first part of the data + c.Response().Write([]byte("test\n")) + c.Response().Flush() + + // Read the first part of the data + assert.True(rec.Flushed) + assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + + var err error + r, err = gzip.NewReader(rec.Body) + assert.NoError(err) + + _, err = io.ReadFull(r, chunkBuf) + assert.NoError(err) + assert.Equal("test\n", string(chunkBuf)) + + // Write and flush the second part of the data + c.Response().Write([]byte("test\n")) + c.Response().Flush() + + _, err = io.ReadFull(r, chunkBuf) + assert.NoError(err) + assert.Equal("test\n", string(chunkBuf)) + + // Write the final part of the data and return + c.Response().Write([]byte("test")) + return nil + })(c) + + assert.NotNil(r) + + buf := new(bytes.Buffer) + + buf.ReadFrom(r) + assert.Equal("test", buf.String()) + + r.Close() +} + +func TestGzipWithMinLengthNoContent(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := GzipWithConfig(GzipConfig{MinLength: 10})(func(c echo.Context) error { + return c.NoContent(http.StatusNoContent) + }) + if assert.NoError(t, h(c)) { + assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) + assert.Equal(t, 0, len(rec.Body.Bytes())) + } +} + func TestGzipNoContent(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil)