Skip to content

Commit

Permalink
zstd: Fix extra CRC written with multiple Close calls
Browse files Browse the repository at this point in the history
Fixes #1016
  • Loading branch information
klauspost committed Oct 8, 2024
1 parent dbd6c38 commit a452cbb
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
8 changes: 7 additions & 1 deletion zstd/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package zstd

import (
"crypto/rand"
"errors"
"fmt"
"io"
"math"
Expand Down Expand Up @@ -417,7 +418,7 @@ func (e *Encoder) Flush() error {
// The Encoder can still be re-used after calling this.
func (e *Encoder) Close() error {
s := &e.state
if s.encoder == nil {
if s.encoder == nil || errors.Is(s.err, ErrDecoderClosed) {
return nil
}
err := e.nextBlock(true)
Expand Down Expand Up @@ -459,6 +460,11 @@ func (e *Encoder) Close() error {
}
_, s.err = s.w.Write(frame)
}
if s.err == nil {
s.err = ErrDecoderClosed
return nil
}

return s.err
}

Expand Down
6 changes: 5 additions & 1 deletion zstd/encoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,13 +278,17 @@ func TestEncoderRegression(t *testing.T) {
if err != nil {
t.Error(err)
}
err = enc.Close()
if err != nil {
t.Error(err)
}
encoded = dst.Bytes()
if len(encoded) > enc.MaxEncodedSize(len(in)) {
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(in), len(encoded), enc.MaxEncodedSize(len(in)))
}
got, err = dec.DecodeAll(encoded, make([]byte, 0, len(in)/2))
if err != nil {
t.Logf("error: %v\nwant: %v\ngot: %v", err, in, got)
t.Logf("error: %v\nwant: %v\ngot: %v", err, len(in), len(got))
t.Error(err)
}
})
Expand Down
4 changes: 4 additions & 0 deletions zstd/zstd.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ var (
// Close has been called.
ErrDecoderClosed = errors.New("decoder used after Close")

// ErrEncoderClosed will be returned if the Encoder was used after
// Close has been called.
ErrEncoderClosed = errors.New("encoder used after Close")

// ErrDecoderNilInput is returned when a nil Reader was provided
// and an operation other than Reset/DecodeAll/Close was attempted.
ErrDecoderNilInput = errors.New("nil input provided as reader")
Expand Down

0 comments on commit a452cbb

Please sign in to comment.