diff --git a/go/mysql/binlog_event_compression.go b/go/mysql/binlog_event_compression.go index 378698bc64b..1cb38d5cb16 100644 --- a/go/mysql/binlog_event_compression.go +++ b/go/mysql/binlog_event_compression.go @@ -25,7 +25,6 @@ import ( "github.com/klauspost/compress/zstd" "vitess.io/vitess/go/stats" - "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/vterrors" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" @@ -89,26 +88,9 @@ var ( // allocations and GC overhead so this pool allows us to handle // concurrent cases better while still scaling to 0 when there's no // usage. - statefulDecoderPool sync.Pool + statefulDecoderPool = &decoderPool{} ) -func init() { - var err error - statelessDecoder, err = zstd.NewReader(nil, zstd.WithDecoderConcurrency(0)) - if err != nil { // Should only happen e.g. due to ENOMEM - log.Errorf("Error creating stateless decoder: %v", err) - } - statefulDecoderPool = sync.Pool{ - New: func() any { - d, err := zstd.NewReader(nil, zstd.WithDecoderMaxMemory(zstdInMemoryDecompressorMaxSize)) - if err != nil { // Should only happen e.g. due to ENOMEM - log.Errorf("Error creating stateful decoder: %v", err) - } - return d - }, - } -} - type TransactionPayload struct { size uint64 compressionType uint64 @@ -304,12 +286,9 @@ func (tp *TransactionPayload) decompress() error { // larger payloads. if tp.uncompressedSize > zstdInMemoryDecompressorMaxSize { in := bytes.NewReader(tp.payload) - streamDecoder := statefulDecoderPool.Get().(*zstd.Decoder) - if streamDecoder == nil { - return vterrors.New(vtrpcpb.Code_INTERNAL, "failed to create stateful stream decoder") - } - if err := streamDecoder.Reset(in); err != nil { - return vterrors.Wrap(err, "error resetting stateful stream decoder") + streamDecoder, err := statefulDecoderPool.Get(in) + if err != nil { + return err } compressedTrxPayloadsUsingStream.Add(1) tp.reader = streamDecoder @@ -317,8 +296,12 @@ func (tp *TransactionPayload) decompress() error { } // Process smaller payloads using only in-memory buffers. - if statelessDecoder == nil { // Should never happen - return vterrors.New(vtrpcpb.Code_INTERNAL, "failed to create stateless decoder") + if statelessDecoder == nil { // Should only need to be done once + var err error + statelessDecoder, err = zstd.NewReader(nil, zstd.WithDecoderConcurrency(0)) + if err != nil { // Should only happen e.g. due to ENOMEM + return vterrors.Wrap(err, "failed to create stateless decoder") + } } decompressedBytes := make([]byte, 0, tp.uncompressedSize) // Perform a single pre-allocation decompressedBytes, err := statelessDecoder.DecodeAll(tp.payload, decompressedBytes[:0]) @@ -340,11 +323,8 @@ func (tp *TransactionPayload) decompress() error { func (tp *TransactionPayload) Close() { switch reader := tp.reader.(type) { case *zstd.Decoder: - if err := reader.Reset(nil); err == nil || err == io.EOF { - readersPool.Put(reader) - } + statefulDecoderPool.Put(reader) default: - reader = nil } tp.iterator = nil } @@ -368,3 +348,38 @@ func (tp *TransactionPayload) GetNextEvent() (BinlogEvent, error) { //func (tp *TransactionPayload) Events() iter.Seq[BinlogEvent] { // return tp.iterator //} + +// decoderPool manages a sync.Pool of *zstd.Decoders. +type decoderPool struct { + pool sync.Pool +} + +// Get gets a pooled OR new *zstd.Decoder. +func (dp *decoderPool) Get(reader io.Reader) (*zstd.Decoder, error) { + var ( + decoder *zstd.Decoder + ok bool + ) + if pooled := dp.pool.Get(); pooled != nil { + decoder, ok = pooled.(*zstd.Decoder) + if !ok { + return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] expected *zstd.Decoder but got %T", pooled) + } + } else { + d, err := zstd.NewReader(nil, zstd.WithDecoderMaxMemory(zstdInMemoryDecompressorMaxSize)) + if err != nil { // Should only happen e.g. due to ENOMEM + return nil, vterrors.Wrap(err, "failed to create stateful stream decoder") + } + decoder = d + } + if err := decoder.Reset(reader); err != nil { + return nil, vterrors.Wrap(err, "error resetting stateful stream decoder") + } + return decoder, nil +} + +func (dp *decoderPool) Put(decoder *zstd.Decoder) { + if err := decoder.Reset(nil); err == nil || err == io.EOF { + dp.pool.Put(decoder) + } +} diff --git a/go/mysql/binlog_event_compression_test.go b/go/mysql/binlog_event_compression_test.go new file mode 100644 index 00000000000..4c6418a1aa0 --- /dev/null +++ b/go/mysql/binlog_event_compression_test.go @@ -0,0 +1,75 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mysql + +import ( + "bytes" + "io" + "testing" + + "github.com/klauspost/compress/zstd" + "github.com/stretchr/testify/require" +) + +func TestDecoderPool(t *testing.T) { + validateDecoder := func(t *testing.T, err error, decoder *zstd.Decoder) { + require.NoError(t, err) + require.NotNil(t, decoder) + require.IsType(t, &zstd.Decoder{}, decoder) + } + tests := []struct { + name string + reader io.Reader + wantErr bool + }{ + { + name: "happy path", + reader: bytes.NewReader([]byte{0x68, 0x61, 0x70, 0x70, 0x79}), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // It's not guaranteed that we get the same decoder back from the pool + // that we just put in, so we use a loop and ensure that it worked at + // least one of the times. Without doing this the test would be flaky. + poolingUsed := false + + for i := 0; i < 20; i++ { + decoder, err := statefulDecoderPool.Get(tt.reader) + validateDecoder(t, err, decoder) + statefulDecoderPool.Put(decoder) + + decoder2, err := statefulDecoderPool.Get(tt.reader) + validateDecoder(t, err, decoder2) + if decoder2 == decoder { + poolingUsed = true + } + statefulDecoderPool.Put(decoder2) + + decoder3, err := statefulDecoderPool.Get(tt.reader) + validateDecoder(t, err, decoder3) + if decoder3 == decoder || decoder3 == decoder2 { + poolingUsed = true + } + statefulDecoderPool.Put(decoder3) + } + + require.True(t, poolingUsed) + }) + } +}