From 0d0bbfee5a8dd12a82e442d3cbb11e56726dd06e Mon Sep 17 00:00:00 2001 From: AutumnSun Date: Sun, 2 Jul 2023 18:40:26 +0800 Subject: [PATCH] Auto add 'Vary' header after compression (#1585) * Auto add 'Vary' header after compression Add config `SetAddVaryHeaderForCompression` to enable 'Vary: Accept-Encoding' header when compression is used. * feat: always set the Vary header * create and use `ResponseHeader.AddVaryBytes` * not export 'AddVaryBytes' --- header.go | 12 +++++ header_test.go | 62 ++++++++++++++++++++++ http.go | 3 ++ server_test.go | 141 +++++++++++++++++++++++++++++++++++++++++++++++++ strings.go | 1 + 5 files changed, 219 insertions(+) diff --git a/header.go b/header.go index 5665e79e59..ca9062f6cd 100644 --- a/header.go +++ b/header.go @@ -344,6 +344,18 @@ func (h *ResponseHeader) SetContentEncodingBytes(contentEncoding []byte) { h.contentEncoding = append(h.contentEncoding[:0], contentEncoding...) } +// addVaryBytes add value to the 'Vary' header if it's not included +func (h *ResponseHeader) addVaryBytes(value []byte) { + v := h.peek(strVary) + if len(v) == 0 { + // 'Vary' is not set + h.SetBytesV(HeaderVary, value) + } else if !bytes.Contains(v, value) { + // 'Vary' is set and not contains target value + h.SetBytesV(HeaderVary, append(append(v, ','), value...)) + } // else: 'Vary' is set and contains target value +} + // Server returns Server header value. func (h *ResponseHeader) Server() []byte { return h.server diff --git a/header_test.go b/header_test.go index 9f9fd351bd..c2b8b18e74 100644 --- a/header_test.go +++ b/header_test.go @@ -3007,3 +3007,65 @@ func TestResponseHeader_Keys(t *testing.T) { t.Fatalf("Unexpected value %q. Expected %q", actualTrailerKeys, expectedTrailerKeys) } } + +func TestAddVaryHeader(t *testing.T) { + t.Parallel() + + var h ResponseHeader + + h.addVaryBytes([]byte("Accept-Encoding")) + got := string(h.Peek("Vary")) + expected := "Accept-Encoding" + if got != expected { + t.Errorf("expected %q got %q", expected, got) + } + + var buf bytes.Buffer + h.WriteTo(&buf) //nolint:errcheck + + if n := strings.Count(buf.String(), "Vary: "); n != 1 { + t.Errorf("Vary occurred %d times", n) + } +} + +func TestAddVaryHeaderExisting(t *testing.T) { + t.Parallel() + + var h ResponseHeader + + h.Set("Vary", "Accept") + h.addVaryBytes([]byte("Accept-Encoding")) + got := string(h.Peek("Vary")) + expected := "Accept,Accept-Encoding" + if got != expected { + t.Errorf("expected %q got %q", expected, got) + } + + var buf bytes.Buffer + h.WriteTo(&buf) //nolint:errcheck + + if n := strings.Count(buf.String(), "Vary: "); n != 1 { + t.Errorf("Vary occurred %d times", n) + } +} + +func TestAddVaryHeaderExistingAcceptEncoding(t *testing.T) { + t.Parallel() + + var h ResponseHeader + + h.Set("Vary", "Accept-Encoding") + h.addVaryBytes([]byte("Accept-Encoding")) + got := string(h.Peek("Vary")) + expected := "Accept-Encoding" + if got != expected { + t.Errorf("expected %q got %q", expected, got) + } + + var buf bytes.Buffer + h.WriteTo(&buf) //nolint:errcheck + + if n := strings.Count(buf.String(), "Vary: "); n != 1 { + t.Errorf("Vary occurred %d times", n) + } +} diff --git a/http.go b/http.go index ffb02c8b7e..5d8dc93477 100644 --- a/http.go +++ b/http.go @@ -1723,6 +1723,7 @@ func (resp *Response) brotliBody(level int) error { resp.bodyRaw = nil } resp.Header.SetContentEncodingBytes(strBr) + resp.Header.addVaryBytes(strAcceptEncoding) return nil } @@ -1778,6 +1779,7 @@ func (resp *Response) gzipBody(level int) error { resp.bodyRaw = nil } resp.Header.SetContentEncodingBytes(strGzip) + resp.Header.addVaryBytes(strAcceptEncoding) return nil } @@ -1833,6 +1835,7 @@ func (resp *Response) deflateBody(level int) error { resp.bodyRaw = nil } resp.Header.SetContentEncodingBytes(strDeflate) + resp.Header.addVaryBytes(strAcceptEncoding) return nil } diff --git a/server_test.go b/server_test.go index e22da4edfc..67ca8a6d8a 100644 --- a/server_test.go +++ b/server_test.go @@ -2035,6 +2035,147 @@ func TestCompressHandler(t *testing.T) { } } +func TestCompressHandlerVary(t *testing.T) { + t.Parallel() + + expectedBody := string(createFixedBody(2e4)) + + h := CompressHandlerBrotliLevel(func(ctx *RequestCtx) { + ctx.WriteString(expectedBody) //nolint:errcheck + }, CompressBrotliBestSpeed, CompressBestSpeed) + + var ctx RequestCtx + var resp Response + + // verify uncompressed response + h(&ctx) + s := ctx.Response.String() + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %v", err) + } + ce := resp.Header.ContentEncoding() + if string(ce) != "" { + t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "") + } + vary := resp.Header.Peek("Vary") + if string(vary) != "" { + t.Fatalf("unexpected Vary: %q. Expecting %q", vary, "") + } + body := resp.Body() + if string(body) != expectedBody { + t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) + } + + // verify gzip-compressed response + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.Set("Accept-Encoding", "gzip, deflate, sdhc") + + h(&ctx) + s = ctx.Response.String() + br = bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %v", err) + } + ce = resp.Header.ContentEncoding() + if string(ce) != "gzip" { + t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "gzip") + } + vary = resp.Header.Peek("Vary") + if string(vary) != "Accept-Encoding" { + t.Fatalf("unexpected Vary: %q. Expecting %q", vary, "Accept-Encoding") + } + body, err := resp.BodyGunzip() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != expectedBody { + t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) + } + + // an attempt to compress already compressed response + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.Set("Accept-Encoding", "gzip, deflate, sdhc") + hh := CompressHandler(h) + hh(&ctx) + s = ctx.Response.String() + br = bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %v", err) + } + ce = resp.Header.ContentEncoding() + if string(ce) != "gzip" { + t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "gzip") + } + vary = resp.Header.Peek("Vary") + if string(vary) != "Accept-Encoding" { + t.Fatalf("unexpected Vary: %q. Expecting %q", vary, "Accept-Encoding") + } + body, err = resp.BodyGunzip() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != expectedBody { + t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) + } + + // verify deflate-compressed response + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.Set(HeaderAcceptEncoding, "foobar, deflate, sdhc") + + h(&ctx) + s = ctx.Response.String() + br = bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %v", err) + } + ce = resp.Header.ContentEncoding() + if string(ce) != "deflate" { + t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "deflate") + } + vary = resp.Header.Peek("Vary") + if string(vary) != "Accept-Encoding" { + t.Fatalf("unexpected Vary: %q. Expecting %q", vary, "Accept-Encoding") + } + body, err = resp.BodyInflate() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != expectedBody { + t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) + } + + // verify br-compressed response + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.Set(HeaderAcceptEncoding, "gzip, deflate, br") + + h(&ctx) + s = ctx.Response.String() + br = bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %v", err) + } + ce = resp.Header.ContentEncoding() + if string(ce) != "br" { + t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "br") + } + vary = resp.Header.Peek("Vary") + if string(vary) != "Accept-Encoding" { + t.Fatalf("unexpected Vary: %q. Expecting %q", vary, "Accept-Encoding") + } + body, err = resp.BodyUnbrotli() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if string(body) != expectedBody { + t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) + } +} + func TestRequestCtxWriteString(t *testing.T) { t.Parallel() diff --git a/strings.go b/strings.go index 0e201a161d..3cec8ed0e1 100644 --- a/strings.go +++ b/strings.go @@ -57,6 +57,7 @@ var ( strProxyAuthenticate = []byte(HeaderProxyAuthenticate) strProxyAuthorization = []byte(HeaderProxyAuthorization) strWWWAuthenticate = []byte(HeaderWWWAuthenticate) + strVary = []byte(HeaderVary) strCookieExpires = []byte("expires") strCookieDomain = []byte("domain")