Skip to content

Commit

Permalink
support response body stream (#1414)
Browse files Browse the repository at this point in the history
* feat: support response body stream

* style: add header interface

* Update http.go

Co-authored-by: Erik Dubbelboer <[email protected]>

* feat: support request、response、client stream

* fix: close reader bug

---------

Co-authored-by: fanhaodong.516 <[email protected]>
Co-authored-by: Erik Dubbelboer <[email protected]>
  • Loading branch information
3 people authored Apr 5, 2023
1 parent 239cce4 commit 6b958c2
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 13 deletions.
29 changes: 27 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,9 @@ type Client struct {
// Connection pool strategy. Can be either LIFO or FIFO (default).
ConnPoolStrategy ConnPoolStrategyType

// StreamResponseBody enables response body streaming
StreamResponseBody bool

// ConfigureClient configures the fasthttp.HostClient.
ConfigureClient func(hc *HostClient) error

Expand Down Expand Up @@ -521,6 +524,7 @@ func (c *Client) Do(req *Request, resp *Response) error {
MaxConnWaitTimeout: c.MaxConnWaitTimeout,
RetryIf: c.RetryIf,
ConnPoolStrategy: c.ConnPoolStrategy,
StreamResponseBody: c.StreamResponseBody,
clientReaderPool: &c.readerPool,
clientWriterPool: &c.writerPool,
}
Expand Down Expand Up @@ -795,6 +799,9 @@ type HostClient struct {
// Connection pool strategy. Can be either LIFO or FIFO (default).
ConnPoolStrategy ConnPoolStrategyType

// StreamResponseBody enables response body streaming
StreamResponseBody bool

lastUseTime uint32

connsLock sync.Mutex
Expand Down Expand Up @@ -1331,8 +1338,10 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)

// backing up SkipBody in case it was set explicitly
customSkipBody := resp.SkipBody
customStreamBody := resp.StreamBody || c.StreamResponseBody
resp.Reset()
resp.SkipBody = customSkipBody
resp.StreamBody = customStreamBody

req.URI().DisablePathNormalizing = c.DisablePathNormalizing

Expand Down Expand Up @@ -1442,12 +1451,28 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)
return retry, err
}

if resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST {
closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST
if customStreamBody && resp.bodyStream != nil {
rbs := resp.bodyStream
resp.bodyStream = newCloseReader(rbs, func() error {
if r, ok := rbs.(*requestStream); ok {
releaseRequestStream(r)
}
if closeConn {
c.closeConn(cc)
} else {
c.releaseConn(cc)
}
return nil
})
return false, nil
}

if closeConn {
c.closeConn(cc)
} else {
c.releaseConn(cc)
}

return false, nil
}

Expand Down
66 changes: 62 additions & 4 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ type Response struct {
// Relevant for bodyStream only.
ImmediateHeaderFlush bool

// StreamBody enables response body streaming.
// Use SetBodyStream to set the body stream.
StreamBody bool

bodyStream io.Reader
w responseBodyWriter
body *bytebufferpool.ByteBuffer
Expand Down Expand Up @@ -293,6 +297,47 @@ func (resp *Response) BodyWriter() io.Writer {
return &resp.w
}

// BodyStream returns io.Reader
//
// You must CloseBodyStream or ReleaseRequest after you use it.
func (req *Request) BodyStream() io.Reader {
return req.bodyStream
}

func (req *Request) CloseBodyStream() error {
return req.closeBodyStream()
}

// BodyStream returns io.Reader
//
// You must CloseBodyStream or ReleaseResponse after you use it.
func (resp *Response) BodyStream() io.Reader {
return resp.bodyStream
}

func (resp *Response) CloseBodyStream() error {
return resp.closeBodyStream()
}

type closeReader struct {
io.Reader
closeFunc func() error
}

func newCloseReader(r io.Reader, closeFunc func() error) io.ReadCloser {
if r == nil {
panic(`BUG: reader is nil`)
}
return &closeReader{Reader: r, closeFunc: closeFunc}
}

func (c *closeReader) Close() error {
if c.closeFunc == nil {
return nil
}
return c.closeFunc()
}

// BodyWriter returns writer for populating request body.
func (req *Request) BodyWriter() io.Writer {
req.w.r = req
Expand Down Expand Up @@ -1068,6 +1113,7 @@ func (resp *Response) Reset() {
resp.raddr = nil
resp.laddr = nil
resp.ImmediateHeaderFlush = false
resp.StreamBody = false
}

func (resp *Response) resetSkipHeader() {
Expand Down Expand Up @@ -1360,7 +1406,7 @@ func (resp *Response) ReadLimitBody(r *bufio.Reader, maxBodySize int) error {
}
}

if resp.Header.ContentLength() == -1 {
if resp.Header.ContentLength() == -1 && !resp.StreamBody {
err = resp.Header.ReadTrailer(r)
if err != nil && err != io.EOF {
if isConnectionReset(err) {
Expand All @@ -1383,14 +1429,23 @@ func (resp *Response) ReadBody(r *bufio.Reader, maxBodySize int) (err error) {
contentLength := resp.Header.ContentLength()
if contentLength >= 0 {
bodyBuf.B, err = readBody(r, contentLength, maxBodySize, bodyBuf.B)

if err == ErrBodyTooLarge && resp.StreamBody {
resp.bodyStream = acquireRequestStream(bodyBuf, r, &resp.Header)
err = nil
}
} else if contentLength == -1 {
bodyBuf.B, err = readBodyChunked(r, maxBodySize, bodyBuf.B)

if resp.StreamBody {
resp.bodyStream = acquireRequestStream(bodyBuf, r, &resp.Header)
} else {
bodyBuf.B, err = readBodyChunked(r, maxBodySize, bodyBuf.B)
}
} else {
bodyBuf.B, err = readBodyIdentity(r, maxBodySize, bodyBuf.B)
resp.Header.SetContentLength(len(bodyBuf.B))
}
if err == nil && resp.StreamBody && resp.bodyStream == nil {
resp.bodyStream = bytes.NewReader(bodyBuf.B)
}
return err
}

Expand Down Expand Up @@ -1952,6 +2007,9 @@ func (resp *Response) closeBodyStream() error {
if bsc, ok := resp.bodyStream.(io.Closer); ok {
err = bsc.Close()
}
if bsr, ok := resp.bodyStream.(*requestStream); ok {
releaseRequestStream(bsr)
}
resp.bodyStream = nil
return err
}
Expand Down
119 changes: 119 additions & 0 deletions http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2946,6 +2946,125 @@ func TestResponseBodyStreamErrorOnPanicDuringClose(t *testing.T) {
}
}

func TestResponseBodyStream(t *testing.T) {
t.Parallel()
chunkedResp := "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "6\r\n123456\r\n" + "7\r\n1234567\r\n" + "0\r\n\r\n"
simpleResp := "HTTP/1.1 200 OK\r\n" + "Content-Length: 9\r\n" + "\r\n" + "123456789"
t.Run("read chunked response", func(t *testing.T) {
response := AcquireResponse()
response.StreamBody = true
if err := response.Read(bufio.NewReader(bytes.NewBufferString(chunkedResp))); err != nil {
t.Fatalf("parse response find err: %v", err)
}
defer func() {
if err := response.closeBodyStream(); err != nil {
t.Fatalf("close body stream err: %v", err)
}
}()
body, err := io.ReadAll(response.bodyStream)
if err != nil {
t.Fatalf("read body stream err: %v", err)
}
if string(body) != "1234561234567" {
t.Fatalf("unexpected body content, got: %#v, want: %#v", string(body), "1234561234567")
}
})
t.Run("read simple response", func(t *testing.T) {
resp := AcquireResponse()
resp.StreamBody = true
err := resp.ReadLimitBody(bufio.NewReader(bytes.NewBufferString(simpleResp)), 8)
if err != nil {
t.Fatalf("read limit body err: %v", err)
}
body := resp.BodyStream()
defer func() {
if err := resp.CloseBodyStream(); err != nil {
t.Fatalf("close body stream err: %v", err)
}
}()
content, err := io.ReadAll(body)
if err != nil {
t.Fatalf("read limit body err: %v", err)
}
if string(content) != "123456789" {
t.Fatalf("unexpected body content, got: %#v, want: %#v", string(content), "123456789")
}
})
t.Run("http client", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
if request.URL.Query().Get("chunked") == "true" {
for x := 0; x < 10; x++ {
time.Sleep(time.Millisecond)
writer.Write([]byte(strconv.Itoa(x))) //nolint:errcheck
writer.(http.Flusher).Flush()
}
return
}
writer.Write([]byte(`hello world`)) //nolint:errcheck
}))

defer server.Close()

t.Run("normal request", func(t *testing.T) {
client := Client{StreamResponseBody: true}
resp := AcquireResponse()
request := AcquireRequest()
request.SetRequestURI(server.URL)
if err := client.Do(request, resp); err != nil {
t.Fatal(err)
}
stream := resp.BodyStream()
defer func() {
ReleaseResponse(resp)
}()
content, _ := io.ReadAll(stream)
if string(content) != "hello world" {
t.Fatalf("unexpected body content, got: %#v, want: %#v", string(content), "hello world")
}
})

t.Run("limit response body size size", func(t *testing.T) {
client := Client{StreamResponseBody: true, MaxResponseBodySize: 20}
resp := AcquireResponse()
request := AcquireRequest()
request.SetRequestURI(server.URL)
if err := client.Do(request, resp); err != nil {
t.Fatal(err)
}
stream := resp.BodyStream()
defer func() {
if err := resp.CloseBodyStream(); err != nil {
t.Fatalf("close stream err: %v", err)
}
}()
content, _ := io.ReadAll(stream)
if string(content) != "hello world" {
t.Fatalf("unexpected body content, got: %#v, want: %#v", string(content), "hello world")
}
})

t.Run("chunked", func(t *testing.T) {
client := Client{StreamResponseBody: true}
resp := AcquireResponse()
request := AcquireRequest()
request.SetRequestURI(server.URL + "?chunked=true")
if err := client.Do(request, resp); err != nil {
t.Fatal(err)
}
stream := resp.BodyStream()
defer func() {
if err := resp.CloseBodyStream(); err != nil {
t.Fatalf("close stream err: %v", err)
}
}()
content, _ := io.ReadAll(stream)
if string(content) != "0123456789" {
t.Fatalf("unexpected body content, got: %#v, want: %#v", string(content), "0123456789")
}
})
})
}

func TestRequestMultipartFormPipeEmptyFormField(t *testing.T) {
t.Parallel()

Expand Down
19 changes: 12 additions & 7 deletions streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,13 @@ import (
"github.com/valyala/bytebufferpool"
)

type headerInterface interface {
ContentLength() int
ReadTrailer(r *bufio.Reader) error
}

type requestStream struct {
header *RequestHeader
header headerInterface
prefetchedBytes *bytes.Reader
reader *bufio.Reader
totalBytesRead int
Expand All @@ -22,7 +27,7 @@ func (rs *requestStream) Read(p []byte) (int, error) {
n int
err error
)
if rs.header.contentLength == -1 {
if rs.header.ContentLength() == -1 {
if rs.chunkLeft == 0 {
chunkSize, err := parseChunkSize(rs.reader)
if err != nil {
Expand Down Expand Up @@ -52,7 +57,7 @@ func (rs *requestStream) Read(p []byte) (int, error) {
}
return n, err
}
if rs.totalBytesRead == rs.header.contentLength {
if rs.totalBytesRead == rs.header.ContentLength() {
return 0, io.EOF
}
prefetchedSize := int(rs.prefetchedBytes.Size())
Expand All @@ -63,12 +68,12 @@ func (rs *requestStream) Read(p []byte) (int, error) {
}
n, err := rs.prefetchedBytes.Read(p)
rs.totalBytesRead += n
if n == rs.header.contentLength {
if n == rs.header.ContentLength() {
return n, io.EOF
}
return n, err
} else {
left := rs.header.contentLength - rs.totalBytesRead
left := rs.header.ContentLength() - rs.totalBytesRead
if len(p) > left {
p = p[:left]
}
Expand All @@ -79,13 +84,13 @@ func (rs *requestStream) Read(p []byte) (int, error) {
}
}

if rs.totalBytesRead == rs.header.contentLength {
if rs.totalBytesRead == rs.header.ContentLength() {
err = io.EOF
}
return n, err
}

func acquireRequestStream(b *bytebufferpool.ByteBuffer, r *bufio.Reader, h *RequestHeader) *requestStream {
func acquireRequestStream(b *bytebufferpool.ByteBuffer, r *bufio.Reader, h headerInterface) *requestStream {
rs := requestStreamPool.Get().(*requestStream)
rs.prefetchedBytes = bytes.NewReader(b.B)
rs.reader = r
Expand Down

0 comments on commit 6b958c2

Please sign in to comment.