From f3845e5b31ab9cb6a5ed927577d7f9207175fc07 Mon Sep 17 00:00:00 2001 From: Brandon Bennett Date: Thu, 25 Jan 2024 10:01:35 -0700 Subject: [PATCH] transport: fix infinate loop on chunkReader.Close() --- transport/frame.go | 14 +++++++++++--- transport/frame_test.go | 20 ++++++++++++++++---- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/transport/frame.go b/transport/frame.go index b5da749..979d090 100644 --- a/transport/frame.go +++ b/transport/frame.go @@ -180,6 +180,9 @@ func (r *chunkReader) readHeader() error { if _, err := r.r.Discard(2); err != nil { return err } + // not stricly needed but it is the responsibility of this function to + // update chunkLeft. + r.chunkLeft = 0 return io.EOF } @@ -213,7 +216,7 @@ func (r *chunkReader) Read(p []byte) (int, error) { return 0, ErrInvalidIO } - // still reading existing chunk + // done with existing chunck so grab the next one if r.chunkLeft <= 0 { if err := r.readHeader(); err != nil { return 0, err @@ -234,7 +237,7 @@ func (r *chunkReader) ReadByte() (byte, error) { return 0, ErrInvalidIO } - // still reading existing chunk + // done with existing chunck so grab the next one if r.chunkLeft <= 0 { if err := r.readHeader(); err != nil { return 0, err @@ -255,8 +258,11 @@ func (r *chunkReader) Close() error { // poison the reader so that it can no longer be used defer func() { r.r = nil }() + // read all remaining chunks until we get to the end of the frame. for { if r.chunkLeft <= 0 { + // readHeader return io.EOF when it encounter the end-of-frame + // marker ("\n##\n") err := r.readHeader() switch err { case nil: @@ -268,9 +274,11 @@ func (r *chunkReader) Close() error { } } - if _, err := r.r.Discard(r.chunkLeft); err != nil { + n, err := r.r.Discard(r.chunkLeft) + if err != nil { return err } + r.chunkLeft -= n } } diff --git a/transport/frame_test.go b/transport/frame_test.go index 7dd2e0c..3cac2a2 100644 --- a/transport/frame_test.go +++ b/transport/frame_test.go @@ -76,8 +76,9 @@ var chunkedTests = []struct { func TestChunkReaderReadByte(t *testing.T) { for _, tc := range chunkedTests { t.Run(tc.name, func(t *testing.T) { - r := bufio.NewReader(bytes.NewReader(tc.input)) - cr := &chunkReader{r: r} + r := &chunkReader{ + r: bufio.NewReader(bytes.NewReader(tc.input)), + } buf := make([]byte, 8192) @@ -87,7 +88,7 @@ func TestChunkReaderReadByte(t *testing.T) { err error ) for { - b, err = cr.ReadByte() + b, err = r.ReadByte() if err != nil { break } @@ -97,9 +98,12 @@ func TestChunkReaderReadByte(t *testing.T) { buf = buf[:n] if err != io.EOF { - assert.Equal(t, err, tc.err) + assert.Equal(t, tc.err, err) } assert.Equal(t, tc.want, buf) + + // TODO: validate the return error + r.Close() }) } } @@ -114,6 +118,9 @@ func TestChunkReaderRead(t *testing.T) { got, err := io.ReadAll(r) assert.Equal(t, tc.err, err) assert.Equal(t, tc.want, got) + + // TODO: validate the return error + r.Close() }) } } @@ -264,6 +271,9 @@ func TestEOMReadByte(t *testing.T) { } assert.Equal(t, tc.want, buf) + + // TODO: validate the return error + r.Close() }) } } @@ -277,6 +287,8 @@ func TestEOMRead(t *testing.T) { got, err := io.ReadAll(r) assert.Equal(t, err, tc.err) assert.Equal(t, tc.want, got) + // TODO: validate the return error + r.Close() }) } }