From a68e0b3214a363283a691a088253d6e56307e1ac Mon Sep 17 00:00:00 2001 From: efron Date: Tue, 7 May 2024 17:07:32 -0700 Subject: [PATCH] feat: add GinGzipOrBrotliBodies --- compressmw/compressmw_test.go | 58 ++++++++++++++++++++++++++++++++- compressmw/gincompat.go | 60 +++++++++++++++++++++++++++++++++++ go.mod | 1 + go.sum | 2 ++ 4 files changed, 120 insertions(+), 1 deletion(-) diff --git a/compressmw/compressmw_test.go b/compressmw/compressmw_test.go index f3ab7b4..c0a4728 100644 --- a/compressmw/compressmw_test.go +++ b/compressmw/compressmw_test.go @@ -10,6 +10,7 @@ import ( "strings" "testing" + "github.com/andybalholm/brotli" "github.com/gin-gonic/gin" "github.com/runpod/rpcompress/compressmw" ) @@ -125,6 +126,61 @@ func TestServerAccept(t *testing.T) { } } +func TestGinGzipOrBrotliBodies(t *testing.T) { + for _, tt := range []struct { + encoding string + read func(io.Reader) (string, error) + }{ + { + encoding: "gzip", + read: func(r io.Reader) (string, error) { + gzipR, err := gzip.NewReader(r) + if err != nil { + return "", err + } + b, err := io.ReadAll(gzipR) + return string(b), err + }, + }, + { + encoding: "br", + read: func(r io.Reader) (string, error) { + b, err := io.ReadAll(brotli.NewReader(r)) + return string(b), err + }, + }, + } { + t.Run(tt.encoding, func(t *testing.T) { + const wantBody = "" + req, err := http.NewRequest("POST", "/foo", strings.NewReader(wantBody)) + if err != nil { + t.Fatal(err) + } + // say we can accept brotli + req.Header.Set("Accept-Encoding", tt.encoding) + + router := gin.New() + router.Use(compressmw.GinGzipOrBrotliBodies) // set up the router to use the middleware when it sees "br" in the Accept-Encoding header + router.POST("/foo", func(c *gin.Context) { + io.Copy(c.Writer, c.Request.Body) + }) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Errorf("got status %d, want %d", rec.Code, http.StatusOK) + } + if rec.Header().Get("Content-Encoding") != tt.encoding { + t.Errorf("got Content-Encoding %q, want %q", rec.Header().Get("Content-Encoding"), "br") + } + if got, err := tt.read(rec.Body); err != nil { + t.Errorf("error reading response body: %v", err) + } else if got != wantBody { + t.Errorf("got %q, want %q", got, wantBody) + } + }) + } +} + // implementation of TestGinGzipBodies per-level func testGinGzipBodies(t *testing.T, lvl int) { router := gin.New() @@ -212,7 +268,7 @@ func testGzipRoundTrip(t *testing.T, lvl int) { const want = "" s := httptest.NewServer(handler) t.Cleanup(s.Close) - client := &http.Client{Transport: compressmw.ClientGzipBody(http.DefaultTransport, gzip.DefaultCompression)} + client := &http.Client{Transport: compressmw.ClientGzipBody(http.DefaultTransport, lvl)} req, err := http.NewRequest("POST", s.URL+"/foo", strings.NewReader(want)) if err != nil { t.Fatal(err) diff --git a/compressmw/gincompat.go b/compressmw/gincompat.go index 88cb229..2b77cb1 100644 --- a/compressmw/gincompat.go +++ b/compressmw/gincompat.go @@ -4,12 +4,18 @@ package compressmw import ( "bufio" + "io" "net" "net/http" + "github.com/andybalholm/brotli" "github.com/gin-gonic/gin" ) +func GinBrotliOrGzip(c *gin.Context) { + brotli.HTTPCompressor(c.Writer, c.Request) +} + func GinAcceptGzip(c *gin.Context) { i := hasGzipAt(c.Request.Header.Values("Content-Encoding")) if i == -1 { @@ -27,6 +33,13 @@ func GinAcceptGzip(c *gin.Context) { c.Next() } +func GinGzipOrBrotliBodies(c *gin.Context) { + wc := brotli.HTTPCompressor(c.Writer, c.Request) + defer wc.Close() + c.Writer = &ginCompatGzipOrBrotliWriter{ginResponseWriter: c.Writer, compressWriter: wc} + c.Next() +} + // GinGzipBodies is a gin.HandlerFunc that compresses the response body with gzip if the client accepts it. Level is in the range 1(gzip.BestSpeed) to 9(gzip.BestCompression). 0 or -1 default to 6. func GinGzipBodies(lvl int) gin.HandlerFunc { lvl = checkgziplevel(lvl) @@ -50,6 +63,53 @@ func GinGzipBodies(lvl int) gin.HandlerFunc { } } +type ginCompatGzipOrBrotliWriter struct { + ginResponseWriter gin.ResponseWriter + compressWriter io.WriteCloser + status int +} + +var _ gin.ResponseWriter = (*ginCompatGzipOrBrotliWriter)(nil) + +func (g *ginCompatGzipOrBrotliWriter) Flush() { g.ginResponseWriter.Flush() } +func (g *ginCompatGzipOrBrotliWriter) Pusher() http.Pusher { return g.ginResponseWriter.Pusher() } +func (g *ginCompatGzipOrBrotliWriter) Header() http.Header { return g.ginResponseWriter.Header() } +func (g *ginCompatGzipOrBrotliWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return g.ginResponseWriter.Hijack() +} + +func (g *ginCompatGzipOrBrotliWriter) Status() int { + if g.status != 0 { + return g.status + } + return g.ginResponseWriter.Status() +} +func (g *ginCompatGzipOrBrotliWriter) WriteString(s string) (int, error) { return g.Write([]byte(s)) } +func (g *ginCompatGzipOrBrotliWriter) Written() bool { return g.ginResponseWriter.Written() } +func (g *ginCompatGzipOrBrotliWriter) Size() int { return g.ginResponseWriter.Size() } +func (g *ginCompatGzipOrBrotliWriter) CloseNotify() <-chan bool { + return g.ginResponseWriter.CloseNotify() +} + +func (g *ginCompatGzipOrBrotliWriter) WriteHeader(code int) { + g.status = code + g.ginResponseWriter.WriteHeader(code) +} + +func (g *ginCompatGzipOrBrotliWriter) WriteHeaderNow() { + if g.status == 0 { + g.status = http.StatusOK + } + g.ginResponseWriter.WriteHeader(g.status) +} + +func (g *ginCompatGzipOrBrotliWriter) Write(data []byte) (int, error) { + if g.status == 0 { + g.status = http.StatusOK + } + return g.compressWriter.Write(data) +} + // ginCompatGzipWriter implements all 10 billion methods of gin.ResponseWriter // in order to write a simple middleware. // I _strongly_ dislike gin, but it's what we already use... diff --git a/go.mod b/go.mod index dcfcbd4..db8e64d 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.21.6 require github.com/gin-gonic/gin v1.9.1 require ( + github.com/andybalholm/brotli v1.1.0 // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect diff --git a/go.sum b/go.sum index 1a77fa1..6997855 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= +github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=