From 6b958c2c222bcf715691e3789d7dc09474241121 Mon Sep 17 00:00:00 2001 From: Anthony-Dong Date: Thu, 6 Apr 2023 00:56:31 +0800 Subject: [PATCH] support response body stream (#1414) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: support response body stream * style: add header interface * Update http.go Co-authored-by: Erik Dubbelboer * feat: support request、response、client stream * fix: close reader bug --------- Co-authored-by: fanhaodong.516 Co-authored-by: Erik Dubbelboer --- client.go | 29 ++++++++++++- http.go | 66 ++++++++++++++++++++++++++-- http_test.go | 119 +++++++++++++++++++++++++++++++++++++++++++++++++++ streaming.go | 19 +++++--- 4 files changed, 220 insertions(+), 13 deletions(-) diff --git a/client.go b/client.go index f0ce67f3c0..b3e26ccd06 100644 --- a/client.go +++ b/client.go @@ -297,6 +297,9 @@ type Client struct { // Connection pool strategy. Can be either LIFO or FIFO (default). ConnPoolStrategy ConnPoolStrategyType + // StreamResponseBody enables response body streaming + StreamResponseBody bool + // ConfigureClient configures the fasthttp.HostClient. ConfigureClient func(hc *HostClient) error @@ -521,6 +524,7 @@ func (c *Client) Do(req *Request, resp *Response) error { MaxConnWaitTimeout: c.MaxConnWaitTimeout, RetryIf: c.RetryIf, ConnPoolStrategy: c.ConnPoolStrategy, + StreamResponseBody: c.StreamResponseBody, clientReaderPool: &c.readerPool, clientWriterPool: &c.writerPool, } @@ -795,6 +799,9 @@ type HostClient struct { // Connection pool strategy. Can be either LIFO or FIFO (default). ConnPoolStrategy ConnPoolStrategyType + // StreamResponseBody enables response body streaming + StreamResponseBody bool + lastUseTime uint32 connsLock sync.Mutex @@ -1331,8 +1338,10 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error) // backing up SkipBody in case it was set explicitly customSkipBody := resp.SkipBody + customStreamBody := resp.StreamBody || c.StreamResponseBody resp.Reset() resp.SkipBody = customSkipBody + resp.StreamBody = customStreamBody req.URI().DisablePathNormalizing = c.DisablePathNormalizing @@ -1442,12 +1451,28 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error) return retry, err } - if resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST { + closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST + if customStreamBody && resp.bodyStream != nil { + rbs := resp.bodyStream + resp.bodyStream = newCloseReader(rbs, func() error { + if r, ok := rbs.(*requestStream); ok { + releaseRequestStream(r) + } + if closeConn { + c.closeConn(cc) + } else { + c.releaseConn(cc) + } + return nil + }) + return false, nil + } + + if closeConn { c.closeConn(cc) } else { c.releaseConn(cc) } - return false, nil } diff --git a/http.go b/http.go index b20d6658f7..0508857d9b 100644 --- a/http.go +++ b/http.go @@ -92,6 +92,10 @@ type Response struct { // Relevant for bodyStream only. ImmediateHeaderFlush bool + // StreamBody enables response body streaming. + // Use SetBodyStream to set the body stream. + StreamBody bool + bodyStream io.Reader w responseBodyWriter body *bytebufferpool.ByteBuffer @@ -293,6 +297,47 @@ func (resp *Response) BodyWriter() io.Writer { return &resp.w } +// BodyStream returns io.Reader +// +// You must CloseBodyStream or ReleaseRequest after you use it. +func (req *Request) BodyStream() io.Reader { + return req.bodyStream +} + +func (req *Request) CloseBodyStream() error { + return req.closeBodyStream() +} + +// BodyStream returns io.Reader +// +// You must CloseBodyStream or ReleaseResponse after you use it. +func (resp *Response) BodyStream() io.Reader { + return resp.bodyStream +} + +func (resp *Response) CloseBodyStream() error { + return resp.closeBodyStream() +} + +type closeReader struct { + io.Reader + closeFunc func() error +} + +func newCloseReader(r io.Reader, closeFunc func() error) io.ReadCloser { + if r == nil { + panic(`BUG: reader is nil`) + } + return &closeReader{Reader: r, closeFunc: closeFunc} +} + +func (c *closeReader) Close() error { + if c.closeFunc == nil { + return nil + } + return c.closeFunc() +} + // BodyWriter returns writer for populating request body. func (req *Request) BodyWriter() io.Writer { req.w.r = req @@ -1068,6 +1113,7 @@ func (resp *Response) Reset() { resp.raddr = nil resp.laddr = nil resp.ImmediateHeaderFlush = false + resp.StreamBody = false } func (resp *Response) resetSkipHeader() { @@ -1360,7 +1406,7 @@ func (resp *Response) ReadLimitBody(r *bufio.Reader, maxBodySize int) error { } } - if resp.Header.ContentLength() == -1 { + if resp.Header.ContentLength() == -1 && !resp.StreamBody { err = resp.Header.ReadTrailer(r) if err != nil && err != io.EOF { if isConnectionReset(err) { @@ -1383,14 +1429,23 @@ func (resp *Response) ReadBody(r *bufio.Reader, maxBodySize int) (err error) { contentLength := resp.Header.ContentLength() if contentLength >= 0 { bodyBuf.B, err = readBody(r, contentLength, maxBodySize, bodyBuf.B) - + if err == ErrBodyTooLarge && resp.StreamBody { + resp.bodyStream = acquireRequestStream(bodyBuf, r, &resp.Header) + err = nil + } } else if contentLength == -1 { - bodyBuf.B, err = readBodyChunked(r, maxBodySize, bodyBuf.B) - + if resp.StreamBody { + resp.bodyStream = acquireRequestStream(bodyBuf, r, &resp.Header) + } else { + bodyBuf.B, err = readBodyChunked(r, maxBodySize, bodyBuf.B) + } } else { bodyBuf.B, err = readBodyIdentity(r, maxBodySize, bodyBuf.B) resp.Header.SetContentLength(len(bodyBuf.B)) } + if err == nil && resp.StreamBody && resp.bodyStream == nil { + resp.bodyStream = bytes.NewReader(bodyBuf.B) + } return err } @@ -1952,6 +2007,9 @@ func (resp *Response) closeBodyStream() error { if bsc, ok := resp.bodyStream.(io.Closer); ok { err = bsc.Close() } + if bsr, ok := resp.bodyStream.(*requestStream); ok { + releaseRequestStream(bsr) + } resp.bodyStream = nil return err } diff --git a/http_test.go b/http_test.go index ebbfa82986..32194a84ca 100644 --- a/http_test.go +++ b/http_test.go @@ -2946,6 +2946,125 @@ func TestResponseBodyStreamErrorOnPanicDuringClose(t *testing.T) { } } +func TestResponseBodyStream(t *testing.T) { + t.Parallel() + chunkedResp := "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "6\r\n123456\r\n" + "7\r\n1234567\r\n" + "0\r\n\r\n" + simpleResp := "HTTP/1.1 200 OK\r\n" + "Content-Length: 9\r\n" + "\r\n" + "123456789" + t.Run("read chunked response", func(t *testing.T) { + response := AcquireResponse() + response.StreamBody = true + if err := response.Read(bufio.NewReader(bytes.NewBufferString(chunkedResp))); err != nil { + t.Fatalf("parse response find err: %v", err) + } + defer func() { + if err := response.closeBodyStream(); err != nil { + t.Fatalf("close body stream err: %v", err) + } + }() + body, err := io.ReadAll(response.bodyStream) + if err != nil { + t.Fatalf("read body stream err: %v", err) + } + if string(body) != "1234561234567" { + t.Fatalf("unexpected body content, got: %#v, want: %#v", string(body), "1234561234567") + } + }) + t.Run("read simple response", func(t *testing.T) { + resp := AcquireResponse() + resp.StreamBody = true + err := resp.ReadLimitBody(bufio.NewReader(bytes.NewBufferString(simpleResp)), 8) + if err != nil { + t.Fatalf("read limit body err: %v", err) + } + body := resp.BodyStream() + defer func() { + if err := resp.CloseBodyStream(); err != nil { + t.Fatalf("close body stream err: %v", err) + } + }() + content, err := io.ReadAll(body) + if err != nil { + t.Fatalf("read limit body err: %v", err) + } + if string(content) != "123456789" { + t.Fatalf("unexpected body content, got: %#v, want: %#v", string(content), "123456789") + } + }) + t.Run("http client", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + if request.URL.Query().Get("chunked") == "true" { + for x := 0; x < 10; x++ { + time.Sleep(time.Millisecond) + writer.Write([]byte(strconv.Itoa(x))) //nolint:errcheck + writer.(http.Flusher).Flush() + } + return + } + writer.Write([]byte(`hello world`)) //nolint:errcheck + })) + + defer server.Close() + + t.Run("normal request", func(t *testing.T) { + client := Client{StreamResponseBody: true} + resp := AcquireResponse() + request := AcquireRequest() + request.SetRequestURI(server.URL) + if err := client.Do(request, resp); err != nil { + t.Fatal(err) + } + stream := resp.BodyStream() + defer func() { + ReleaseResponse(resp) + }() + content, _ := io.ReadAll(stream) + if string(content) != "hello world" { + t.Fatalf("unexpected body content, got: %#v, want: %#v", string(content), "hello world") + } + }) + + t.Run("limit response body size size", func(t *testing.T) { + client := Client{StreamResponseBody: true, MaxResponseBodySize: 20} + resp := AcquireResponse() + request := AcquireRequest() + request.SetRequestURI(server.URL) + if err := client.Do(request, resp); err != nil { + t.Fatal(err) + } + stream := resp.BodyStream() + defer func() { + if err := resp.CloseBodyStream(); err != nil { + t.Fatalf("close stream err: %v", err) + } + }() + content, _ := io.ReadAll(stream) + if string(content) != "hello world" { + t.Fatalf("unexpected body content, got: %#v, want: %#v", string(content), "hello world") + } + }) + + t.Run("chunked", func(t *testing.T) { + client := Client{StreamResponseBody: true} + resp := AcquireResponse() + request := AcquireRequest() + request.SetRequestURI(server.URL + "?chunked=true") + if err := client.Do(request, resp); err != nil { + t.Fatal(err) + } + stream := resp.BodyStream() + defer func() { + if err := resp.CloseBodyStream(); err != nil { + t.Fatalf("close stream err: %v", err) + } + }() + content, _ := io.ReadAll(stream) + if string(content) != "0123456789" { + t.Fatalf("unexpected body content, got: %#v, want: %#v", string(content), "0123456789") + } + }) + }) +} + func TestRequestMultipartFormPipeEmptyFormField(t *testing.T) { t.Parallel() diff --git a/streaming.go b/streaming.go index fc04916d5e..a0a374bec0 100644 --- a/streaming.go +++ b/streaming.go @@ -9,8 +9,13 @@ import ( "github.com/valyala/bytebufferpool" ) +type headerInterface interface { + ContentLength() int + ReadTrailer(r *bufio.Reader) error +} + type requestStream struct { - header *RequestHeader + header headerInterface prefetchedBytes *bytes.Reader reader *bufio.Reader totalBytesRead int @@ -22,7 +27,7 @@ func (rs *requestStream) Read(p []byte) (int, error) { n int err error ) - if rs.header.contentLength == -1 { + if rs.header.ContentLength() == -1 { if rs.chunkLeft == 0 { chunkSize, err := parseChunkSize(rs.reader) if err != nil { @@ -52,7 +57,7 @@ func (rs *requestStream) Read(p []byte) (int, error) { } return n, err } - if rs.totalBytesRead == rs.header.contentLength { + if rs.totalBytesRead == rs.header.ContentLength() { return 0, io.EOF } prefetchedSize := int(rs.prefetchedBytes.Size()) @@ -63,12 +68,12 @@ func (rs *requestStream) Read(p []byte) (int, error) { } n, err := rs.prefetchedBytes.Read(p) rs.totalBytesRead += n - if n == rs.header.contentLength { + if n == rs.header.ContentLength() { return n, io.EOF } return n, err } else { - left := rs.header.contentLength - rs.totalBytesRead + left := rs.header.ContentLength() - rs.totalBytesRead if len(p) > left { p = p[:left] } @@ -79,13 +84,13 @@ func (rs *requestStream) Read(p []byte) (int, error) { } } - if rs.totalBytesRead == rs.header.contentLength { + if rs.totalBytesRead == rs.header.ContentLength() { err = io.EOF } return n, err } -func acquireRequestStream(b *bytebufferpool.ByteBuffer, r *bufio.Reader, h *RequestHeader) *requestStream { +func acquireRequestStream(b *bytebufferpool.ByteBuffer, r *bufio.Reader, h headerInterface) *requestStream { rs := requestStreamPool.Get().(*requestStream) rs.prefetchedBytes = bytes.NewReader(b.B) rs.reader = r