Skip to content

Commit

Permalink
zstd: Fix extra CRC written with multiple Close calls (#1017)
Browse files Browse the repository at this point in the history
* zstd: Fix extra CRC written with multiple Close calls
* Also check write/flush after close.

Fixes #1016
  • Loading branch information
klauspost authored Oct 8, 2024
1 parent dbd6c38 commit 72cd4a9
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 4 deletions.
26 changes: 23 additions & 3 deletions zstd/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package zstd

import (
"crypto/rand"
"errors"
"fmt"
"io"
"math"
Expand Down Expand Up @@ -149,6 +150,9 @@ func (e *Encoder) ResetContentSize(w io.Writer, size int64) {
// and write CRC if requested.
func (e *Encoder) Write(p []byte) (n int, err error) {
s := &e.state
if s.eofWritten {
return 0, ErrEncoderClosed
}
for len(p) > 0 {
if len(p)+len(s.filling) < e.o.blockSize {
if e.o.crc {
Expand Down Expand Up @@ -288,6 +292,9 @@ func (e *Encoder) nextBlock(final bool) error {
s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
s.nInput += int64(len(s.current))
s.wg.Add(1)
if final {
s.eofWritten = true
}
go func(src []byte) {
if debugEncoder {
println("Adding block,", len(src), "bytes, final:", final)
Expand All @@ -303,9 +310,6 @@ func (e *Encoder) nextBlock(final bool) error {
blk := enc.Block()
enc.Encode(blk, src)
blk.last = final
if final {
s.eofWritten = true
}
// Wait for pending writes.
s.wWg.Wait()
if s.writeErr != nil {
Expand Down Expand Up @@ -401,12 +405,20 @@ func (e *Encoder) Flush() error {
if len(s.filling) > 0 {
err := e.nextBlock(false)
if err != nil {
// Ignore Flush after Close.
if errors.Is(s.err, ErrEncoderClosed) {
return nil
}
return err
}
}
s.wg.Wait()
s.wWg.Wait()
if s.err != nil {
// Ignore Flush after Close.
if errors.Is(s.err, ErrEncoderClosed) {
return nil
}
return s.err
}
return s.writeErr
Expand All @@ -422,6 +434,9 @@ func (e *Encoder) Close() error {
}
err := e.nextBlock(true)
if err != nil {
if errors.Is(s.err, ErrEncoderClosed) {
return nil
}
return err
}
if s.frameContentSize > 0 {
Expand Down Expand Up @@ -459,6 +474,11 @@ func (e *Encoder) Close() error {
}
_, s.err = s.w.Write(frame)
}
if s.err == nil {
s.err = ErrEncoderClosed
return nil
}

return s.err
}

Expand Down
11 changes: 10 additions & 1 deletion zstd/encoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package zstd

import (
"bytes"
"errors"
"fmt"
"io"
"math/rand"
Expand Down Expand Up @@ -278,13 +279,21 @@ func TestEncoderRegression(t *testing.T) {
if err != nil {
t.Error(err)
}
err = enc.Close()
if err != nil {
t.Error(err)
}
_, err = enc.Write([]byte{1, 2, 3, 4})
if !errors.Is(err, ErrEncoderClosed) {
t.Errorf("unexpected error: %v", err)
}
encoded = dst.Bytes()
if len(encoded) > enc.MaxEncodedSize(len(in)) {
t.Errorf("max encoded size for %v: got: %d, want max: %d", len(in), len(encoded), enc.MaxEncodedSize(len(in)))
}
got, err = dec.DecodeAll(encoded, make([]byte, 0, len(in)/2))
if err != nil {
t.Logf("error: %v\nwant: %v\ngot: %v", err, in, got)
t.Logf("error: %v\nwant: %v\ngot: %v", err, len(in), len(got))
t.Error(err)
}
})
Expand Down
4 changes: 4 additions & 0 deletions zstd/zstd.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ var (
// Close has been called.
ErrDecoderClosed = errors.New("decoder used after Close")

// ErrEncoderClosed will be returned if the Encoder was used after
// Close has been called.
ErrEncoderClosed = errors.New("encoder used after Close")

// ErrDecoderNilInput is returned when a nil Reader was provided
// and an operation other than Reset/DecodeAll/Close was attempted.
ErrDecoderNilInput = errors.New("nil input provided as reader")
Expand Down

0 comments on commit 72cd4a9

Please sign in to comment.