Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support response body stream #1414

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1320,8 +1320,10 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)

// backing up SkipBody in case it was set explicitly
customSkipBody := resp.SkipBody
customStreamBody := resp.StreamBody
resp.Reset()
resp.SkipBody = customSkipBody
resp.StreamBody = customStreamBody

req.URI().DisablePathNormalizing = c.DisablePathNormalizing

Expand Down Expand Up @@ -1426,12 +1428,28 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)
return retry, err
}

if resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST {
closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST
if customStreamBody && resp.bodyStream != nil {
rbs := resp.bodyStream
resp.bodyStream = newCloseReader(rbs, func() error {
if r, ok := rbs.(*requestStream); ok {
releaseRequestStream(r)
}
if closeConn {
c.closeConn(cc)
} else {
c.releaseConn(cc)
}
return nil
})
return false, nil
}

if closeConn {
c.closeConn(cc)
} else {
c.releaseConn(cc)
}

return false, nil
}

Expand Down
59 changes: 55 additions & 4 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ type Response struct {
// Relevant for bodyStream only.
ImmediateHeaderFlush bool

// StreamBody enables response body streaming.
// Response.BodyStream() get response body.
Anthony-Dong marked this conversation as resolved.
Show resolved Hide resolved
StreamBody bool

bodyStream io.Reader
w responseBodyWriter
body *bytebufferpool.ByteBuffer
Expand Down Expand Up @@ -293,6 +297,40 @@ func (resp *Response) BodyWriter() io.Writer {
return &resp.w
}

// BodyStream returns io.Reader
//
// You must close it after you use it.
func (resp *Response) BodyStream() io.ReadCloser {
if resp.bodyStream == nil {
resp.bodyStream = bytes.NewReader(resp.Body())
}
return newCloseReader(resp.bodyStream, resp.closeBodyStream)
}
erikdubbelboer marked this conversation as resolved.
Show resolved Hide resolved

type closeReader struct {
io.Reader
closeFunc func() error
closeOnce sync.Once
err error
}

func newCloseReader(r io.Reader, closeFunc func() error) io.ReadCloser {
if r == nil {
panic(`BUG: reader is nil`)
}
return &closeReader{Reader: r, closeFunc: closeFunc}
}

func (c *closeReader) Close() error {
if c.closeFunc == nil {
return nil
}
c.closeOnce.Do(func() {
c.err = c.closeFunc()
})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this sync.Once needed here? If Close can be called multiple times, shouldn't closeFunc then also not be able to be called multiple times?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if SetBodyStream is true, the clientConn does not immediately close. need to close it manually. close clientConn does not need to be executed multiple times.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, sync.Pool has the same object and clientConn is closed multiple times

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I understand. Can you show how Close would be called multiple times?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought it wouldn't be good to return an io.ReadCloser directly, so I added the CloseBodyStream method @erikdubbelboer

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did some investigating into this issue, #1504, and it seems it is cause by Close being called multiple times. So I understand the sync.Once now. But I'm wondering if we should fix the double close instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perIPConn#Close() method is also wrapped in sync.Once ? It also seems reasonable and easier to deal with

Copy link
Contributor Author

@Anthony-Dong Anthony-Dong Mar 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whether the client end is used out needs the same processing, I think it depends on whether there is concurrent execution of closeBodyStream. If there is, the same problem will indeed occur, but this is an unreasonable use behavior, not a framework defect. @erikdubbelboer

return c.err
}

// BodyWriter returns writer for populating request body.
func (req *Request) BodyWriter() io.Writer {
req.w.r = req
Expand Down Expand Up @@ -1067,6 +1105,7 @@ func (resp *Response) Reset() {
resp.raddr = nil
resp.laddr = nil
resp.ImmediateHeaderFlush = false
resp.StreamBody = false
}

func (resp *Response) resetSkipHeader() {
Expand Down Expand Up @@ -1359,7 +1398,7 @@ func (resp *Response) ReadLimitBody(r *bufio.Reader, maxBodySize int) error {
}
}

if resp.Header.ContentLength() == -1 {
if resp.Header.ContentLength() == -1 && !resp.StreamBody {
err = resp.Header.ReadTrailer(r)
if err != nil && err != io.EOF {
if isConnectionReset(err) {
Expand All @@ -1382,14 +1421,23 @@ func (resp *Response) ReadBody(r *bufio.Reader, maxBodySize int) (err error) {
contentLength := resp.Header.ContentLength()
if contentLength >= 0 {
bodyBuf.B, err = readBody(r, contentLength, maxBodySize, bodyBuf.B)

if err == ErrBodyTooLarge && resp.StreamBody {
resp.bodyStream = acquireRequestStream(bodyBuf, r, &resp.Header)
err = nil
}
} else if contentLength == -1 {
bodyBuf.B, err = readBodyChunked(r, maxBodySize, bodyBuf.B)

if resp.StreamBody {
resp.bodyStream = acquireRequestStream(bodyBuf, r, &resp.Header)
} else {
bodyBuf.B, err = readBodyChunked(r, maxBodySize, bodyBuf.B)
}
} else {
bodyBuf.B, err = readBodyIdentity(r, maxBodySize, bodyBuf.B)
resp.Header.SetContentLength(len(bodyBuf.B))
}
if resp.StreamBody && resp.bodyStream == nil {
resp.bodyStream = bytes.NewBuffer(bodyBuf.B)
}
return err
}

Expand Down Expand Up @@ -1951,6 +1999,9 @@ func (resp *Response) closeBodyStream() error {
if bsc, ok := resp.bodyStream.(io.Closer); ok {
err = bsc.Close()
}
if bsr, ok := resp.bodyStream.(*requestStream); ok {
releaseRequestStream(bsr)
}
resp.bodyStream = nil
return err
}
Expand Down
100 changes: 100 additions & 0 deletions http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2950,6 +2950,106 @@ func TestResponseBodyStreamErrorOnPanicDuringClose(t *testing.T) {
}
}

func TestResponseBodyStream(t *testing.T) {
t.Parallel()
chunkedResp := "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "\r\n" + "6\r\n123456\r\n" + "7\r\n1234567\r\n" + "0\r\n\r\n"
simpleResp := "HTTP/1.1 200 OK\r\n" + "Content-Length: 9\r\n" + "\r\n" + "123456789"
t.Run("read chunked response", func(t *testing.T) {
response := AcquireResponse()
response.StreamBody = true
if err := response.Read(bufio.NewReader(bytes.NewBufferString(chunkedResp))); err != nil {
t.Fatalf("parse response find err: %v", err)
}
defer func() {
if err := response.closeBodyStream(); err != nil {
t.Fatalf("close body stream err: %v", err)
}
}()
body, err := io.ReadAll(response.bodyStream)
if err != nil {
t.Fatalf("read body stream err: %v", err)
}
if string(body) != "1234561234567" {
t.Fatalf("unexpected body content, got: %#v, want: %#v", string(body), "1234561234567")
}
})
t.Run("read simple response", func(t *testing.T) {
resp := AcquireResponse()
resp.StreamBody = true
err := resp.ReadLimitBody(bufio.NewReader(bytes.NewBufferString(simpleResp)), 8)
if err != nil {
t.Fatalf("read limit body err: %v", err)
}
body := resp.BodyStream()
defer func() {
if err := body.Close(); err != nil {
t.Fatalf("close body stream err: %v", err)
}
}()
content, err := io.ReadAll(body)
if err != nil {
t.Fatalf("read limit body err: %v", err)
}
if string(content) != "123456789" {
t.Fatalf("unexpected body content, got: %#v, want: %#v", string(content), "123456789")
}
})
t.Run("http client", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
if request.URL.Query().Get("chunked") == "true" {
for x := 0; x < 10; x++ {
time.Sleep(time.Millisecond)
writer.Write([]byte(strconv.Itoa(x))) //nolint:errcheck
writer.(http.Flusher).Flush()
}
return
}
writer.Write([]byte(`hello world`)) //nolint:errcheck
}))

defer server.Close()
t.Run("simple-max size", func(t *testing.T) {
resp := AcquireResponse()
resp.StreamBody = true
request := AcquireRequest()
request.SetRequestURI(server.URL)
if err := (&Client{MaxResponseBodySize: 5}).Do(request, resp); err != nil {
t.Fatal(err)
}
stream := resp.BodyStream()
defer func() {
if err := stream.Close(); err != nil {
t.Fatalf("close stream err: %v", err)
}
}()
content, _ := io.ReadAll(stream)
if string(content) != "hello world" {
t.Fatalf("unexpected body content, got: %#v, want: %#v", string(content), "hello world")
}
})

t.Run("chunked", func(t *testing.T) {
resp := AcquireResponse()
resp.StreamBody = true
request := AcquireRequest()
request.SetRequestURI(server.URL + "?chunked=true")
if err := Do(request, resp); err != nil {
t.Fatal(err)
}
stream := resp.BodyStream()
defer func() {
if err := stream.Close(); err != nil {
t.Fatalf("close stream err: %v", err)
}
}()
content, _ := io.ReadAll(stream)
if string(content) != "0123456789" {
t.Fatalf("unexpected body content, got: %#v, want: %#v", string(content), "0123456789")
}
})
})
}

func TestRequestMultipartFormPipeEmptyFormField(t *testing.T) {
t.Parallel()

Expand Down
20 changes: 13 additions & 7 deletions streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ import (
)

type requestStream struct {
header *RequestHeader
header interface {
ContentLength() int
ReadTrailer(r *bufio.Reader) error
}
Anthony-Dong marked this conversation as resolved.
Show resolved Hide resolved
prefetchedBytes *bytes.Reader
reader *bufio.Reader
totalBytesRead int
Expand All @@ -22,7 +25,7 @@ func (rs *requestStream) Read(p []byte) (int, error) {
n int
err error
)
if rs.header.contentLength == -1 {
if rs.header.ContentLength() == -1 {
if rs.chunkLeft == 0 {
chunkSize, err := parseChunkSize(rs.reader)
if err != nil {
Expand Down Expand Up @@ -52,7 +55,7 @@ func (rs *requestStream) Read(p []byte) (int, error) {
}
return n, err
}
if rs.totalBytesRead == rs.header.contentLength {
if rs.totalBytesRead == rs.header.ContentLength() {
return 0, io.EOF
}
prefetchedSize := int(rs.prefetchedBytes.Size())
Expand All @@ -63,12 +66,12 @@ func (rs *requestStream) Read(p []byte) (int, error) {
}
n, err := rs.prefetchedBytes.Read(p)
rs.totalBytesRead += n
if n == rs.header.contentLength {
if n == rs.header.ContentLength() {
return n, io.EOF
}
return n, err
} else {
left := rs.header.contentLength - rs.totalBytesRead
left := rs.header.ContentLength() - rs.totalBytesRead
if len(p) > left {
p = p[:left]
}
Expand All @@ -79,13 +82,16 @@ func (rs *requestStream) Read(p []byte) (int, error) {
}
}

if rs.totalBytesRead == rs.header.contentLength {
if rs.totalBytesRead == rs.header.ContentLength() {
err = io.EOF
}
return n, err
}

func acquireRequestStream(b *bytebufferpool.ByteBuffer, r *bufio.Reader, h *RequestHeader) *requestStream {
func acquireRequestStream(b *bytebufferpool.ByteBuffer, r *bufio.Reader, h interface {
ContentLength() int
ReadTrailer(r *bufio.Reader) error
}) *requestStream {
rs := requestStreamPool.Get().(*requestStream)
rs.prefetchedBytes = bytes.NewReader(b.B)
rs.reader = r
Expand Down