Skip to content

Commit

Permalink
store: properly handle snappy compression continuations (#6602)
Browse files Browse the repository at this point in the history
Snappy operates at the byte level, so the stream can split a single varint
across two chunks. Thus, if the Decbuf reports an error, refill the buffer
from the next chunk and retry reading the varint. Added a repro test.

Closes #6545.

Signed-off-by: Giedrius Statkevičius <[email protected]>
  • Loading branch information
GiedriusS authored Aug 10, 2023
1 parent 84567ec commit d6a8f0b
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 23 deletions.
7 changes: 7 additions & 0 deletions pkg/store/6545postingsrepro

Large diffs are not rendered by default.

70 changes: 47 additions & 23 deletions pkg/store/postings_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,7 @@ func (it *streamedDiffVarintPostings) At() storage.SeriesRef {
return it.curSeries
}

func (it *streamedDiffVarintPostings) readNextChunk() bool {
if len(it.db.B) > 0 {
return true
}
func (it *streamedDiffVarintPostings) readNextChunk(remainder []byte) bool {
// Normal EOF.
if len(it.input) == 0 {
return false
Expand All @@ -255,13 +252,13 @@ func (it *streamedDiffVarintPostings) readNextChunk() bool {
it.err = fmt.Errorf("corrupted identifier")
return false
}
if string(it.input[:6]) != magicBody {
if string(it.input[:len(magicBody)]) != magicBody {
it.err = fmt.Errorf("got bad identifier %s", string(it.input[:6]))
return false
}
it.input = it.input[6:]
it.readSnappyIdentifier = true
return it.readNextChunk()
return it.readNextChunk(nil)
case chunkTypeCompressedData:
if !it.readSnappyIdentifier {
it.err = fmt.Errorf("missing magic snappy marker")
Expand All @@ -276,7 +273,6 @@ func (it *streamedDiffVarintPostings) readNextChunk() bool {
it.err = io.ErrUnexpectedEOF
return false
}
encodedBuf := it.input[:chunkLen]

if it.buf == nil {
if it.disablePooling {
Expand All @@ -291,6 +287,15 @@ func (it *streamedDiffVarintPostings) readNextChunk() bool {
}
}

encodedBuf := it.input[:chunkLen]

// NOTE(GiedriusS): we can probably optimize this better but this should be rare enough
// and not cause any problems.
if len(remainder) > 0 {
remainderCopy := make([]byte, 0, len(remainder))
remainderCopy = append(remainderCopy, remainder...)
remainder = remainderCopy
}
decoded, err := s2.Decode(it.buf, encodedBuf[checksumSize:])
if err != nil {
it.err = err
Expand All @@ -300,7 +305,11 @@ func (it *streamedDiffVarintPostings) readNextChunk() bool {
it.err = fmt.Errorf("mismatched checksum (got %v, expected %v)", crc(decoded), checksum)
return false
}
it.db.B = decoded
if len(remainder) > 0 {
it.db.B = append(remainder, decoded...)
} else {
it.db.B = decoded
}
case chunkTypeUncompressedData:
if !it.readSnappyIdentifier {
it.err = fmt.Errorf("missing magic snappy marker")
Expand All @@ -315,11 +324,25 @@ func (it *streamedDiffVarintPostings) readNextChunk() bool {
it.err = io.ErrUnexpectedEOF
return false
}
it.db.B = it.input[checksumSize:chunkLen]
if crc(it.db.B) != checksum {
it.err = fmt.Errorf("mismatched checksum (got %v, expected %v)", crc(it.db.B), checksum)
uncompressedData := it.input[checksumSize:chunkLen]
if crc(uncompressedData) != checksum {
it.err = fmt.Errorf("mismatched checksum (got %v, expected %v)", crc(uncompressedData), checksum)
return false
}

// NOTE(GiedriusS): we can probably optimize this better but this should be rare enough
// and not cause any problems.
if len(remainder) > 0 {
remainderCopy := make([]byte, 0, len(remainder))
remainderCopy = append(remainderCopy, remainder...)
remainder = remainderCopy
}

if len(remainder) > 0 {
it.db.B = append(remainder, uncompressedData...)
} else {
it.db.B = uncompressedData
}
default:
if chunkType <= 0x7f {
it.err = fmt.Errorf("unsupported chunk type %v", chunkType)
Expand All @@ -336,19 +359,21 @@ func (it *streamedDiffVarintPostings) readNextChunk() bool {
}

func (it *streamedDiffVarintPostings) Next() bool {
if !it.readNextChunk() {
return false
}
val := it.db.Uvarint()
if it.db.Err() != nil {
if it.db.Err() != io.EOF {
it.err = it.db.Err()
// Continue reading next chunks until there is at least binary.MaxVarintLen64.
// If we cannot add any more chunks then return false.
for {
val := it.db.Uvarint64()
if it.db.Err() != nil {
if !it.readNextChunk(it.db.B) {
return false
}
it.db.E = nil
continue
}
return false
}

it.curSeries = it.curSeries + storage.SeriesRef(val)
return true
it.curSeries = it.curSeries + storage.SeriesRef(val)
return true
}
}

func (it *streamedDiffVarintPostings) Err() error {
Expand Down Expand Up @@ -534,7 +559,6 @@ func snappyStreamedEncode(postingsLength int, diffVarintPostings []byte) ([]byte
if err != nil {
return nil, fmt.Errorf("creating snappy compressor: %w", err)
}

_, err = sw.Write(diffVarintPostings)
if err != nil {
return nil, err
Expand Down
46 changes: 46 additions & 0 deletions pkg/store/postings_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (
"bytes"
"context"
crand "crypto/rand"
"io"
"math"
"math/rand"
"os"
"sort"
"strconv"
"testing"
Expand Down Expand Up @@ -338,3 +340,47 @@ func FuzzSnappyStreamEncoding(f *testing.F) {
testutil.Ok(t, err)
})
}

// TestRegressionIssue6545 reproduces issue #6545: the snappy framed stream may
// split a varint across two chunks, and the streamed postings decoder must
// stitch the pieces together without corrupting any decoded series refs.
func TestRegressionIssue6545(t *testing.T) {
	const wantPostings = 114024

	rawPostings, err := os.ReadFile("6545postingsrepro")
	testutil.Ok(t, err)

	// Decode the raw diff-varint postings once to establish the expected refs.
	expected := []storage.SeriesRef{}
	it := newDiffVarintPostings(rawPostings, nil)
	for it.Next() {
		expected = append(expected, it.At())
	}
	testutil.Ok(t, it.Err())
	testutil.Equals(t, wantPostings, len(expected))

	dataToCache, err := snappyStreamedEncode(wantPostings, rawPostings)
	testutil.Ok(t, err)

	// Check that the original decompressor works well.
	roundTripped, err := io.ReadAll(s2.NewReader(bytes.NewBuffer(dataToCache[3:])))
	testutil.Ok(t, err)
	testutil.Equals(t, roundTripped, rawPostings)

	// The round-tripped bytes must still decode to the same number of postings.
	count := 0
	it = newDiffVarintPostings(roundTripped, nil)
	for it.Next() {
		count++
	}
	testutil.Equals(t, wantPostings, count)

	// Finally, the streamed decoder must yield exactly the same series refs.
	decoded, err := decodePostings(dataToCache)
	testutil.Ok(t, err)

	n := 0
	for decoded.Next() {
		testutil.Equals(t, uint64(expected[n]), uint64(decoded.At()))
		n++
	}
	testutil.Ok(t, decoded.Err())
	testutil.Equals(t, wantPostings, n)
}

0 comments on commit d6a8f0b

Please sign in to comment.