From 97e30b594d9cd4b46237519c035fca2ed750a171 Mon Sep 17 00:00:00 2001 From: Ted Xu Date: Fri, 22 Mar 2024 18:14:09 +0800 Subject: [PATCH 1/3] feat: adding binlog streaming writer Signed-off-by: Ted Xu --- internal/storage/binlog_iterator_test.go | 38 +- internal/storage/serde.go | 850 ++++++++++++++++++++--- internal/storage/serde_test.go | 241 ++++--- 3 files changed, 921 insertions(+), 208 deletions(-) diff --git a/internal/storage/binlog_iterator_test.go b/internal/storage/binlog_iterator_test.go index 98a213d1b62d3..ad65134b88091 100644 --- a/internal/storage/binlog_iterator_test.go +++ b/internal/storage/binlog_iterator_test.go @@ -22,13 +22,14 @@ import ( "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/testutils" ) -func generateTestData(num int) ([]*Blob, error) { +func generateTestSchema() *schemapb.CollectionSchema { schema := &schemapb.CollectionSchema{Fields: []*schemapb.FieldSchema{ {FieldID: common.TimeStampField, Name: "ts", DataType: schemapb.DataType_Int64}, {FieldID: common.RowIDField, Name: "rowid", DataType: schemapb.DataType_Int64}, @@ -43,13 +44,28 @@ func generateTestData(num int) ([]*Blob, error) { {FieldID: 18, Name: "array", DataType: schemapb.DataType_Array}, {FieldID: 19, Name: "string", DataType: schemapb.DataType_JSON}, {FieldID: 101, Name: "int32", DataType: schemapb.DataType_Int32}, - {FieldID: 102, Name: "floatVector", DataType: schemapb.DataType_FloatVector}, - {FieldID: 103, Name: "binaryVector", DataType: schemapb.DataType_BinaryVector}, - {FieldID: 104, Name: "float16Vector", DataType: schemapb.DataType_Float16Vector}, - {FieldID: 105, Name: "bf16Vector", DataType: schemapb.DataType_BFloat16Vector}, - {FieldID: 106, Name: "sparseFloatVector", DataType: schemapb.DataType_SparseFloatVector}, + {FieldID: 102, Name: "floatVector", DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "8"}, + }}, + {FieldID: 103, Name: "binaryVector", DataType: schemapb.DataType_BinaryVector, TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "8"}, + }}, + {FieldID: 104, Name: "float16Vector", DataType: schemapb.DataType_Float16Vector, TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "8"}, + }}, + {FieldID: 105, Name: "bf16Vector", DataType: schemapb.DataType_BFloat16Vector, TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "8"}, + }}, + {FieldID: 106, Name: "sparseFloatVector", DataType: schemapb.DataType_SparseFloatVector, TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "28433"}, + }}, }} - insertCodec := NewInsertCodecWithSchema(&etcdpb.CollectionMeta{ID: 1, Schema: schema}) + + return schema +} + +func generateTestData(num int) ([]*Blob, error) { + insertCodec := NewInsertCodecWithSchema(&etcdpb.CollectionMeta{ID: 1, Schema: generateTestSchema()}) var ( field0 []int64 @@ -105,7 +121,7 @@ func generateTestData(num int) ([]*Blob, error) { field102 = append(field102, f102...) 
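+		// A dim-8 binary vector packs one bit per dimension, so each row is a
+		// single 0xff byte; the float16/bfloat16 vectors below need two bytes
+		// per dimension, hence the 16-byte rows for dim 8.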
field103 = append(field103, 0xff) - f104 := make([]byte, 8) + f104 := make([]byte, 16) for j := range f104 { f104[j] = byte(i) } @@ -140,11 +156,11 @@ func generateTestData(num int) ([]*Blob, error) { }, 104: &Float16VectorFieldData{ Data: field104, - Dim: 4, + Dim: 8, }, 105: &BFloat16VectorFieldData{ Data: field105, - Dim: 4, + Dim: 8, }, 106: &SparseFloatVectorFieldData{ SparseFloatArray: schemapb.SparseFloatArray{ @@ -165,7 +181,7 @@ func assertTestData(t *testing.T, i int, value *Value) { f102[j] = float32(i) } - f104 := make([]byte, 8) + f104 := make([]byte, 16) for j := range f104 { f104[j] = byte(i) } diff --git a/internal/storage/serde.go b/internal/storage/serde.go index b00b0dedcbdbd..4e0657b6c9cde 100644 --- a/internal/storage/serde.go +++ b/internal/storage/serde.go @@ -17,19 +17,25 @@ package storage import ( + "bytes" + "encoding/binary" "fmt" "io" + "math" "sort" "strconv" "github.com/apache/arrow/go/v12/arrow" "github.com/apache/arrow/go/v12/arrow/array" + "github.com/apache/arrow/go/v12/arrow/memory" + "github.com/apache/arrow/go/v12/parquet/pqarrow" "github.com/cockroachdb/errors" "github.com/golang/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/metautil" + "github.com/milvus-io/milvus/pkg/util/typeutil" ) type Record interface { @@ -45,6 +51,16 @@ type RecordReader interface { Close() } +type RecordWriter interface { + Write(r Record) error + Close() +} + +type ( + Serializer[T any] func([]T) (Record, uint64, error) + Deserializer[T any] func(Record, []T) error +) + // compositeRecord is a record being composed of multiple records, in which each only have 1 column type compositeRecord struct { recs map[FieldID]arrow.Record @@ -137,7 +153,9 @@ func (crr *compositeRecordReader) Next() error { recs: make(map[FieldID]arrow.Record, len(crr.rrs)), schema: make(map[FieldID]schemapb.DataType, len(crr.rrs)), } - crr.iterateNextBatch() + if err := crr.iterateNextBatch(); err != nil { + return err + } } composeRecord := func() bool { @@ -172,10 +190,373 @@ func (crr *compositeRecordReader) Record() Record { func (crr *compositeRecordReader) Close() { for _, close := range crr.closers { - close() + if close != nil { + close() + } } } +type serdeEntry struct { + arrowType func(int) arrow.DataType + deserialize func(arrow.Array, int) (any, bool) + serialize func(array.Builder, any) bool + sizeof func(any) uint64 +} + +var serdeMap = func() map[schemapb.DataType]serdeEntry { + m := make(map[schemapb.DataType]serdeEntry) + m[schemapb.DataType_Bool] = serdeEntry{ + func(i int) arrow.DataType { + return arrow.FixedWidthTypes.Boolean + }, + func(a arrow.Array, i int) (any, bool) { + if arr, ok := a.(*array.Boolean); ok && i < arr.Len() { + return arr.Value(i), true + } + return nil, false + }, + func(b array.Builder, v any) bool { + if builder, ok := b.(*array.BooleanBuilder); ok { + if v, ok := v.(bool); ok { + builder.Append(v) + return true + } + } + return false + }, + func(any) uint64 { + return 1 + }, + } + m[schemapb.DataType_Int8] = serdeEntry{ + func(i int) arrow.DataType { + return arrow.PrimitiveTypes.Int8 + }, + func(a arrow.Array, i int) (any, bool) { + if arr, ok := a.(*array.Int8); ok && i < arr.Len() { + return arr.Value(i), true + } + return nil, false + }, + func(b array.Builder, v any) bool { + if builder, ok := b.(*array.Int8Builder); ok { + if v, ok := v.(int8); ok { + builder.Append(v) + return true + } + } + return false + }, + func(any) uint64 { + return 1 + }, + 
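+		// Every serdeMap entry follows this four-part shape: arrowType(dim)
+		// picks the Arrow representation, deserialize/serialize convert one
+		// cell, and sizeof drives the writer's WrittenMemorySize accounting.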
} + m[schemapb.DataType_Int16] = serdeEntry{ + func(i int) arrow.DataType { + return arrow.PrimitiveTypes.Int16 + }, + func(a arrow.Array, i int) (any, bool) { + if arr, ok := a.(*array.Int16); ok && i < arr.Len() { + return arr.Value(i), true + } + return nil, false + }, + func(b array.Builder, v any) bool { + if builder, ok := b.(*array.Int16Builder); ok { + if v, ok := v.(int16); ok { + builder.Append(v) + return true + } + } + return false + }, + func(any) uint64 { + return 2 + }, + } + m[schemapb.DataType_Int32] = serdeEntry{ + func(i int) arrow.DataType { + return arrow.PrimitiveTypes.Int32 + }, + func(a arrow.Array, i int) (any, bool) { + if arr, ok := a.(*array.Int32); ok && i < arr.Len() { + return arr.Value(i), true + } + return nil, false + }, + func(b array.Builder, v any) bool { + if builder, ok := b.(*array.Int32Builder); ok { + if v, ok := v.(int32); ok { + builder.Append(v) + return true + } + } + return false + }, + func(any) uint64 { + return 4 + }, + } + m[schemapb.DataType_Int64] = serdeEntry{ + func(i int) arrow.DataType { + return arrow.PrimitiveTypes.Int64 + }, + func(a arrow.Array, i int) (any, bool) { + if arr, ok := a.(*array.Int64); ok && i < arr.Len() { + return arr.Value(i), true + } + return nil, false + }, + func(b array.Builder, v any) bool { + if builder, ok := b.(*array.Int64Builder); ok { + if v, ok := v.(int64); ok { + builder.Append(v) + return true + } + } + return false + }, + func(any) uint64 { + return 8 + }, + } + m[schemapb.DataType_Float] = serdeEntry{ + func(i int) arrow.DataType { + return arrow.PrimitiveTypes.Float32 + }, + func(a arrow.Array, i int) (any, bool) { + if arr, ok := a.(*array.Float32); ok && i < arr.Len() { + return arr.Value(i), true + } + return nil, false + }, + func(b array.Builder, v any) bool { + if builder, ok := b.(*array.Float32Builder); ok { + if v, ok := v.(float32); ok { + builder.Append(v) + return true + } + } + return false + }, + func(any) uint64 { + return 4 + }, + } + m[schemapb.DataType_Double] = serdeEntry{ + func(i int) arrow.DataType { + return arrow.PrimitiveTypes.Float64 + }, + func(a arrow.Array, i int) (any, bool) { + if arr, ok := a.(*array.Float64); ok && i < arr.Len() { + return arr.Value(i), true + } + return nil, false + }, + func(b array.Builder, v any) bool { + if builder, ok := b.(*array.Float64Builder); ok { + if v, ok := v.(float64); ok { + builder.Append(v) + return true + } + } + return false + }, + func(any) uint64 { + return 8 + }, + } + stringEntry := serdeEntry{ + func(i int) arrow.DataType { + return arrow.BinaryTypes.String + }, + func(a arrow.Array, i int) (any, bool) { + if arr, ok := a.(*array.String); ok && i < arr.Len() { + return arr.Value(i), true + } + return nil, false + }, + func(b array.Builder, v any) bool { + if builder, ok := b.(*array.StringBuilder); ok { + if v, ok := v.(string); ok { + builder.Append(v) + return true + } + } + return false + }, + func(v any) uint64 { + return uint64(len(v.(string))) + }, + } + + m[schemapb.DataType_VarChar] = stringEntry + m[schemapb.DataType_String] = stringEntry + m[schemapb.DataType_Array] = serdeEntry{ + func(i int) arrow.DataType { + return arrow.BinaryTypes.Binary + }, + func(a arrow.Array, i int) (any, bool) { + if arr, ok := a.(*array.Binary); ok && i < arr.Len() { + v := &schemapb.ScalarField{} + if err := proto.Unmarshal(arr.Value(i), v); err == nil { + return v, true + } + } + return nil, false + }, + func(b array.Builder, v any) bool { + if builder, ok := b.(*array.BinaryBuilder); ok { + if vv, ok := 
v.(*schemapb.ScalarField); ok { + if bytes, err := proto.Marshal(vv); err == nil { + builder.Append(bytes) + return true + } + } + } + return false + }, + func(v any) uint64 { + return uint64(v.(*schemapb.ScalarField).XXX_Size()) + }, + } + + sizeOfBytes := func(v any) uint64 { + return uint64(len(v.([]byte))) + } + m[schemapb.DataType_JSON] = serdeEntry{ + func(i int) arrow.DataType { + return arrow.BinaryTypes.Binary + }, + func(a arrow.Array, i int) (any, bool) { + if arr, ok := a.(*array.Binary); ok && i < arr.Len() { + return arr.Value(i), true + } + return nil, false + }, + func(b array.Builder, v any) bool { + if builder, ok := b.(*array.BinaryBuilder); ok { + if v, ok := v.([]byte); ok { + builder.Append(v) + return true + } + } + return false + }, + sizeOfBytes, + } + + fixedSizeDeserializer := func(a arrow.Array, i int) (any, bool) { + if arr, ok := a.(*array.FixedSizeBinary); ok && i < arr.Len() { + return arr.Value(i), true + } + return nil, false + } + fixedSizeSerializer := func(b array.Builder, v any) bool { + if builder, ok := b.(*array.FixedSizeBinaryBuilder); ok { + if v, ok := v.([]byte); ok { + builder.Append(v) + return true + } + } + return false + } + + m[schemapb.DataType_BinaryVector] = serdeEntry{ + func(i int) arrow.DataType { + return &arrow.FixedSizeBinaryType{ByteWidth: (i + 7) / 8} + }, + fixedSizeDeserializer, + fixedSizeSerializer, + sizeOfBytes, + } + m[schemapb.DataType_Float16Vector] = serdeEntry{ + func(i int) arrow.DataType { + return &arrow.FixedSizeBinaryType{ByteWidth: i * 2} + }, + fixedSizeDeserializer, + fixedSizeSerializer, + sizeOfBytes, + } + m[schemapb.DataType_BFloat16Vector] = serdeEntry{ + func(i int) arrow.DataType { + return &arrow.FixedSizeBinaryType{ByteWidth: i * 2} + }, + fixedSizeDeserializer, + fixedSizeSerializer, + sizeOfBytes, + } + m[schemapb.DataType_FloatVector] = serdeEntry{ + func(i int) arrow.DataType { + return &arrow.FixedSizeBinaryType{ByteWidth: i * 4} + }, + func(a arrow.Array, i int) (any, bool) { + if arr, ok := a.(*array.FixedSizeBinary); ok && i < arr.Len() { + return arrow.Float32Traits.CastFromBytes(arr.Value(i)), true + } + return nil, false + }, + func(b array.Builder, v any) bool { + if builder, ok := b.(*array.FixedSizeBinaryBuilder); ok { + if vv, ok := v.([]float32); ok { + dim := len(vv) + byteLength := dim * 4 + bytesData := make([]byte, byteLength) + for i, vec := range vv { + bytes := math.Float32bits(vec) + common.Endian.PutUint32(bytesData[i*4:], bytes) + } + builder.Append(bytesData) + return true + } + } + return false + }, + func(v any) uint64 { + return uint64(len(v.([]float32)) * 4) + }, + } + m[schemapb.DataType_SparseFloatVector] = serdeEntry{ + func(i int) arrow.DataType { + return arrow.BinaryTypes.Binary + }, + func(a arrow.Array, i int) (any, bool) { + if arr, ok := a.(*array.Binary); ok && i < arr.Len() { + + value := arr.Value(i) + return value, true + // fieldData := &SparseFloatVectorFieldData{} + // if len(value)%8 != 0 { + // return nil, false + // } + + // fieldData.Contents = append(fieldData.Contents, value) + // rowDim := typeutil.SparseFloatRowDim(value) + // if rowDim > fieldData.Dim { + // fieldData.Dim = rowDim + // } + + // return fieldData, true + } + return nil, false + }, + func(b array.Builder, v any) bool { + if builder, ok := b.(*array.BinaryBuilder); ok { + if vv, ok := v.(*SparseFloatVectorFieldData); ok { + length := len(vv.SparseFloatArray.Contents) + for i := 0; i < length; i++ { + builder.Append(vv.SparseFloatArray.Contents[i]) + } + return true + } + } + return 
false + }, + sizeOfBytes, + } + return m +}() + func parseBlobKey(bolbKey string) (colId FieldID, logId UniqueID) { if _, _, _, colId, logId, ok := metautil.ParseInsertLogPath(bolbKey); ok { return colId, logId @@ -224,7 +605,7 @@ func newCompositeRecordReader(blobs []*Blob) (*compositeRecordReader, error) { type DeserializeReader[T any] struct { rr RecordReader - deserializer func(Record, []T) error + deserializer Deserializer[T] rec Record values []T pos int @@ -265,113 +646,13 @@ func (deser *DeserializeReader[T]) Close() { } } -func NewDeserializeReader[T any](rr RecordReader, deserializer func(Record, []T) error) *DeserializeReader[T] { +func NewDeserializeReader[T any](rr RecordReader, deserializer Deserializer[T]) *DeserializeReader[T] { return &DeserializeReader[T]{ rr: rr, deserializer: deserializer, } } -func deserializeCell(col arrow.Array, dataType schemapb.DataType, i int) (interface{}, bool) { - switch dataType { - case schemapb.DataType_Bool: - arr, ok := col.(*array.Boolean) - if !ok { - return nil, false - } - return arr.Value(i), true - - case schemapb.DataType_Int8: - arr, ok := col.(*array.Int8) - if !ok { - return nil, false - } - return arr.Value(i), true - - case schemapb.DataType_Int16: - arr, ok := col.(*array.Int16) - if !ok { - return nil, false - } - return arr.Value(i), true - - case schemapb.DataType_Int32: - arr, ok := col.(*array.Int32) - if !ok { - return nil, false - } - return arr.Value(i), true - - case schemapb.DataType_Int64: - arr, ok := col.(*array.Int64) - if !ok { - return nil, false - } - return arr.Value(i), true - - case schemapb.DataType_Float: - arr, ok := col.(*array.Float32) - if !ok { - return nil, false - } - return arr.Value(i), true - - case schemapb.DataType_Double: - arr, ok := col.(*array.Float64) - if !ok { - return nil, false - } - return arr.Value(i), true - - case schemapb.DataType_String, schemapb.DataType_VarChar: - arr, ok := col.(*array.String) - if !ok { - return nil, false - } - return arr.Value(i), true - - case schemapb.DataType_Array: - arr, ok := col.(*array.Binary) - if !ok { - return nil, false - } - v := &schemapb.ScalarField{} - if err := proto.Unmarshal(arr.Value(i), v); err != nil { - return nil, false - } - return v, true - - case schemapb.DataType_JSON: - arr, ok := col.(*array.Binary) - if !ok { - return nil, false - } - return arr.Value(i), true - - case schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector: - arr, ok := col.(*array.FixedSizeBinary) - if !ok { - return nil, false - } - return arr.Value(i), true - - case schemapb.DataType_FloatVector: - arr, ok := col.(*array.FixedSizeBinary) - if !ok { - return nil, false - } - return arrow.Float32Traits.CastFromBytes(arr.Value(i)), true - case schemapb.DataType_SparseFloatVector: - arr, ok := col.(*array.Binary) - if !ok { - return nil, false - } - return arr.Value(i), true - default: - panic(fmt.Sprintf("unsupported type %s", dataType)) - } -} - func NewBinlogDeserializeReader(blobs []*Blob, PKfieldID UniqueID) (*DeserializeReader[*Value], error) { reader, err := newCompositeRecordReader(blobs) if err != nil { @@ -391,7 +672,7 @@ func NewBinlogDeserializeReader(blobs []*Blob, PKfieldID UniqueID) (*Deserialize m := value.Value.(map[FieldID]interface{}) for j, dt := range r.Schema() { - d, ok := deserializeCell(r.Column(j), dt, i) + d, ok := serdeMap[dt].deserialize(r.Column(j), i) if ok { m[j] = d // TODO: avoid memory copy here. 
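+					// Note: variable-length byte values produced here are views
+					// into the Arrow record's buffers; they must be copied if
+					// they need to outlive the record they came from.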
} else { @@ -417,3 +698,362 @@ func NewBinlogDeserializeReader(blobs []*Blob, PKfieldID UniqueID) (*Deserialize return nil }), nil } + +// selectiveRecord is a Record that only contains a single field, reusing existing Record. +type selectiveRecord struct { + Record + + r Record + selectedFieldId FieldID + + schema map[FieldID]schemapb.DataType +} + +func (r *selectiveRecord) Schema() map[FieldID]schemapb.DataType { + return r.schema +} + +func (r *selectiveRecord) Column(i FieldID) arrow.Array { + if i == r.selectedFieldId { + return r.r.Column(i) + } + return nil +} + +func (r *selectiveRecord) Len() int { + return r.r.Len() +} + +func (r *selectiveRecord) Release() { + // do nothing. +} + +func newSelectiveRecord(r Record, selectedFieldId FieldID) *selectiveRecord { + dt, ok := r.Schema()[selectedFieldId] + if !ok { + return nil + } + schema := make(map[FieldID]schemapb.DataType, 1) + schema[selectedFieldId] = dt + return &selectiveRecord{ + r: r, + selectedFieldId: selectedFieldId, + schema: schema, + } +} + +type compositeRecordWriter struct { + RecordWriter + writers map[FieldID]RecordWriter +} + +func (crw *compositeRecordWriter) Write(r Record) error { + if len(r.Schema()) != len(crw.writers) { + return fmt.Errorf("schema length mismatch %d, expected %d", len(r.Schema()), len(crw.writers)) + } + for fieldId, w := range crw.writers { + sr := newSelectiveRecord(r, fieldId) + if err := w.Write(sr); err != nil { + return err + } + } + return nil +} + +func (crw *compositeRecordWriter) Close() { + if crw != nil { + for _, w := range crw.writers { + if w != nil { + w.Close() + } + } + } +} + +func newCompositeRecordWriter(writers map[FieldID]RecordWriter) *compositeRecordWriter { + return &compositeRecordWriter{ + writers: writers, + } +} + +type singleFieldRecordWriter struct { + RecordWriter + fw *pqarrow.FileWriter + fieldId FieldID + + grouped bool +} + +func (sfw *singleFieldRecordWriter) Write(r Record) error { + if !sfw.grouped { + sfw.grouped = true + sfw.fw.NewRowGroup() + } + // TODO: adding row group support by calling fw.NewRowGroup() + a := r.Column(sfw.fieldId) + return sfw.fw.WriteColumnData(a) +} + +func (sfw *singleFieldRecordWriter) Close() { + sfw.fw.Close() +} + +func newSingleFieldRecordWriter(fieldId FieldID, field arrow.Field, writer io.Writer) (*singleFieldRecordWriter, error) { + schema := arrow.NewSchema([]arrow.Field{field}, nil) + fw, err := pqarrow.NewFileWriter(schema, writer, nil, pqarrow.DefaultWriterProps()) + if err != nil { + return nil, err + } + return &singleFieldRecordWriter{ + fw: fw, + fieldId: fieldId, + }, nil +} + +type SerializeWriter[T any] struct { + rw RecordWriter + serializer Serializer[T] + batchSize int + + buffer []T + pos int + writtenMemorySize uint64 +} + +func (sw *SerializeWriter[T]) Flush() error { + buf := sw.buffer[:sw.pos] + r, size, err := sw.serializer(buf) + if err != nil { + return err + } + if err := sw.rw.Write(r); err != nil { + return err + } + r.Release() + sw.pos = 0 + sw.writtenMemorySize += size + return nil +} + +func (sw *SerializeWriter[T]) Write(value T) error { + if sw.buffer == nil { + sw.buffer = make([]T, sw.batchSize) + } + sw.buffer[sw.pos] = value + sw.pos++ + if sw.pos == sw.batchSize { + if err := sw.Flush(); err != nil { + return err + } + } + return nil +} + +func (sw *SerializeWriter[T]) WrittenMemorySize() uint64 { + return sw.writtenMemorySize +} + +func (sw *SerializeWriter[T]) Close() { + sw.Flush() + sw.rw.Close() +} + +func NewSerializeRecordWriter[T any](rw RecordWriter, serializer 
Serializer[T], batchSize int) *SerializeWriter[T] { + return &SerializeWriter[T]{ + rw: rw, + serializer: serializer, + batchSize: batchSize, + } +} + +type simpleArrowRecord struct { + Record + + r arrow.Record + schema map[FieldID]schemapb.DataType + + field2Col map[FieldID]int +} + +func (sr *simpleArrowRecord) Schema() map[FieldID]schemapb.DataType { + return sr.schema +} + +func (sr *simpleArrowRecord) Column(i FieldID) arrow.Array { + colIdx, ok := sr.field2Col[i] + if !ok { + panic("no such field") + } + return sr.r.Column(colIdx) +} + +func (sr *simpleArrowRecord) Len() int { + return int(sr.r.NumRows()) +} + +func (sr *simpleArrowRecord) Release() { + sr.r.Release() +} + +func newSimpleArrowRecord(r arrow.Record, schema map[FieldID]schemapb.DataType, field2Col map[FieldID]int) *simpleArrowRecord { + return &simpleArrowRecord{ + r: r, + schema: schema, + field2Col: field2Col, + } +} + +type BinlogStreamWriter struct { + collectionID UniqueID + partitionID UniqueID + segmentID UniqueID + fieldSchema *schemapb.FieldSchema + + memorySize int // To be updated on the fly + + buf bytes.Buffer + rw RecordWriter +} + +func (bsw *BinlogStreamWriter) GetRecordWriter() (RecordWriter, error) { + if bsw.rw != nil { + return bsw.rw, nil + } + + fid := bsw.fieldSchema.FieldID + dim, _ := typeutil.GetDim(bsw.fieldSchema) + rw, err := newSingleFieldRecordWriter(fid, arrow.Field{ + Name: strconv.Itoa(int(fid)), + Type: serdeMap[bsw.fieldSchema.DataType].arrowType(int(dim)), + }, &bsw.buf) + if err != nil { + return nil, err + } + bsw.rw = rw + return rw, nil +} + +func (bsw *BinlogStreamWriter) Finalize() (*Blob, error) { + if bsw.rw == nil { + return nil, io.ErrUnexpectedEOF + } + bsw.rw.Close() + + var b bytes.Buffer + if err := bsw.writeBinlogHeaders(&b); err != nil { + return nil, err + } + if _, err := b.Write(bsw.buf.Bytes()); err != nil { + return nil, err + } + return &Blob{ + Key: strconv.Itoa(int(bsw.fieldSchema.FieldID)), + Value: b.Bytes(), + }, nil +} + +func (bsw *BinlogStreamWriter) writeBinlogHeaders(w io.Writer) error { + // Write magic number + if err := binary.Write(w, common.Endian, MagicNumber); err != nil { + return err + } + // Write descriptor + de := newDescriptorEvent() + de.PayloadDataType = bsw.fieldSchema.DataType + de.CollectionID = bsw.collectionID + de.PartitionID = bsw.partitionID + de.SegmentID = bsw.segmentID + de.FieldID = bsw.fieldSchema.FieldID + de.StartTimestamp = 0 + de.EndTimestamp = 0 + de.descriptorEventData.AddExtra(originalSizeKey, strconv.Itoa(bsw.memorySize)) // FIXME: enable original size + if err := de.Write(w); err != nil { + return err + } + // Write event header + eh := newEventHeader(InsertEventType) + // Write event data + ev := newInsertEventData() + ev.StartTimestamp = 1 // Fixme: enable start/end timestamp + ev.EndTimestamp = 1 + eh.EventLength = int32(bsw.buf.Len()) + eh.GetMemoryUsageInBytes() + int32(binary.Size(ev)) + // eh.NextPosition = eh.EventLength + w.Offset() + if err := eh.Write(w); err != nil { + return err + } + if err := ev.WriteEventData(w); err != nil { + return err + } + return nil +} + +func NewBinlogStreamWriters(collectionID, partitionID, segmentID UniqueID, + schema []*schemapb.FieldSchema, +) map[FieldID]*BinlogStreamWriter { + bws := make(map[FieldID]*BinlogStreamWriter, len(schema)) + for _, f := range schema { + bws[f.FieldID] = &BinlogStreamWriter{ + collectionID: collectionID, + partitionID: partitionID, + segmentID: segmentID, + fieldSchema: f, + } + } + return bws +} + +func NewBinlogSerializeWriter(schema 
*schemapb.CollectionSchema, partitionID, segmentID UniqueID, + writers map[FieldID]*BinlogStreamWriter, batchSize int, +) (*SerializeWriter[*Value], error) { + rws := make(map[FieldID]RecordWriter, len(writers)) + for fid := range writers { + w := writers[fid] + rw, err := w.GetRecordWriter() + if err != nil { + return nil, err + } + rws[fid] = rw + } + compositeRecordWriter := newCompositeRecordWriter(rws) + return NewSerializeRecordWriter[*Value](compositeRecordWriter, func(v []*Value) (Record, uint64, error) { + builders := make(map[FieldID]array.Builder, len(schema.Fields)) + types := make(map[FieldID]schemapb.DataType, len(schema.Fields)) + for _, f := range schema.Fields { + dim, _ := typeutil.GetDim(f) + builders[f.FieldID] = array.NewBuilder(memory.DefaultAllocator, serdeMap[f.DataType].arrowType(int(dim))) + types[f.FieldID] = f.DataType + } + + var memorySize uint64 + for _, vv := range v { + m := vv.Value.(map[FieldID]any) + + for fid, e := range m { + typeEntry, ok := serdeMap[types[fid]] + if !ok { + panic("unknown type") + } + ok = typeEntry.serialize(builders[fid], e) + if !ok { + return nil, 0, errors.New(fmt.Sprintf("unexpected type %s", types[fid])) + } + memorySize += typeEntry.sizeof(e) + } + } + arrays := make([]arrow.Array, len(types)) + fields := make([]arrow.Field, len(types)) + field2Col := make(map[FieldID]int, len(types)) + i := 0 + for fid, builder := range builders { + arrays[i] = builder.NewArray() + builder.Release() + fields[i] = arrow.Field{ + Name: strconv.Itoa(int(fid)), + Type: arrays[i].DataType(), + } + field2Col[fid] = i + i++ + } + return newSimpleArrowRecord(array.NewRecord(arrow.NewSchema(fields, nil), arrays, -1), types, field2Col), memorySize, nil + }, batchSize), nil +} diff --git a/internal/storage/serde_test.go b/internal/storage/serde_test.go index 0d0bfb9ee6d37..7d99d853fbe7f 100644 --- a/internal/storage/serde_test.go +++ b/internal/storage/serde_test.go @@ -17,6 +17,8 @@ package storage import ( + "bytes" + "context" "io" "reflect" "testing" @@ -24,6 +26,8 @@ import ( "github.com/apache/arrow/go/v12/arrow" "github.com/apache/arrow/go/v12/arrow/array" "github.com/apache/arrow/go/v12/arrow/memory" + "github.com/apache/arrow/go/v12/parquet/file" + "github.com/apache/arrow/go/v12/parquet/pqarrow" "github.com/stretchr/testify/assert" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -46,23 +50,18 @@ func TestBinlogDeserializeReader(t *testing.T) { }) t.Run("test deserialize", func(t *testing.T) { - len := 3 - blobs, err := generateTestData(len) + size := 3 + blobs, err := generateTestData(size) assert.NoError(t, err) reader, err := NewBinlogDeserializeReader(blobs, common.RowIDField) assert.NoError(t, err) defer reader.Close() - for i := 1; i <= len; i++ { + for i := 1; i <= size; i++ { err = reader.Next() assert.NoError(t, err) value := reader.Value() - - f102 := make([]float32, 8) - for j := range f102 { - f102[j] = float32(i) - } assertTestData(t, i, value) } @@ -71,61 +70,115 @@ func TestBinlogDeserializeReader(t *testing.T) { }) } -func Test_deserializeCell(t *testing.T) { - onelinerArray := func(dtype arrow.DataType, payload interface{}) arrow.Array { - mem := memory.DefaultAllocator - - switch dtype.ID() { - case arrow.BOOL: - builder := array.NewBooleanBuilder(mem) - builder.Append(payload.(bool)) - return builder.NewBooleanArray() - case arrow.INT8: - builder := array.NewInt8Builder(mem) - builder.Append(payload.(int8)) - return builder.NewInt8Array() - case arrow.INT16: - builder := array.NewInt16Builder(mem) - 
builder.Append(payload.(int16)) - return builder.NewInt16Array() - case arrow.INT32: - builder := array.NewInt32Builder(mem) - builder.Append(payload.(int32)) - return builder.NewInt32Array() - case arrow.INT64: - builder := array.NewInt64Builder(mem) - builder.Append(payload.(int64)) - return builder.NewInt64Array() - case arrow.FLOAT32: - builder := array.NewFloat32Builder(mem) - builder.Append(payload.(float32)) - return builder.NewFloat32Array() - case arrow.FLOAT64: - builder := array.NewFloat64Builder(mem) - builder.Append(payload.(float64)) - return builder.NewFloat64Array() - case arrow.STRING: - builder := array.NewStringBuilder(mem) - builder.Append(payload.(string)) - return builder.NewStringArray() - case arrow.BINARY: - builder := array.NewBinaryBuilder(mem, arrow.BinaryTypes.Binary) - builder.Append(payload.([]byte)) - return builder.NewBinaryArray() - case arrow.FIXED_SIZE_BINARY: - typ := dtype.(*arrow.FixedSizeBinaryType) - builder := array.NewFixedSizeBinaryBuilder(mem, typ) - builder.Append(payload.([]byte)) - return builder.NewFixedSizeBinaryArray() +func TestBinlogStreamWriter(t *testing.T) { + t.Run("test write", func(t *testing.T) { + size := 3 + + field := arrow.Field{Name: "bool", Type: arrow.FixedWidthTypes.Boolean} + var w bytes.Buffer + rw, err := newSingleFieldRecordWriter(1, field, &w) + assert.NoError(t, err) + + builder := array.NewBooleanBuilder(memory.DefaultAllocator) + builder.AppendValues([]bool{true, false, true}, nil) + arr := builder.NewArray() + defer arr.Release() + ar := array.NewRecord( + arrow.NewSchema( + []arrow.Field{field}, + nil, + ), + []arrow.Array{arr}, + int64(size), + ) + r := newSimpleArrowRecord(ar, map[FieldID]schemapb.DataType{1: schemapb.DataType_Bool}, map[FieldID]int{1: 0}) + defer r.Release() + err = rw.Write(r) + assert.NoError(t, err) + rw.Close() + + reader, err := file.NewParquetReader(bytes.NewReader(w.Bytes())) + assert.NoError(t, err) + arrowReader, err := pqarrow.NewFileReader(reader, pqarrow.ArrowReadProperties{BatchSize: 1024}, memory.DefaultAllocator) + assert.NoError(t, err) + rr, err := arrowReader.GetRecordReader(context.Background(), nil, nil) + assert.NoError(t, err) + defer rr.Release() + ok := rr.Next() + assert.True(t, ok) + rec := rr.Record() + defer rec.Release() + assert.Equal(t, int64(size), rec.NumRows()) + ok = rr.Next() + assert.False(t, ok) + }) +} + +func TestBinlogSerializeWriter(t *testing.T) { + t.Run("test empty data", func(t *testing.T) { + reader, err := NewBinlogDeserializeReader(nil, common.RowIDField) + assert.NoError(t, err) + defer reader.Close() + err = reader.Next() + assert.Equal(t, io.EOF, err) + }) + + t.Run("test serialize", func(t *testing.T) { + size := 3 + blobs, err := generateTestData(size) + assert.NoError(t, err) + reader, err := NewBinlogDeserializeReader(blobs, common.RowIDField) + assert.NoError(t, err) + defer reader.Close() + + schema := generateTestSchema() + // Copy write the generated data + writers := NewBinlogStreamWriters(0, 0, 0, schema.Fields) + writer, err := NewBinlogSerializeWriter(schema, 0, 0, writers, 1024) + assert.NoError(t, err) + + for i := 1; i <= size; i++ { + err = reader.Next() + assert.NoError(t, err) + + value := reader.Value() + assertTestData(t, i, value) + writer.Write(value) } - return nil - } + err = reader.Next() + assert.Equal(t, io.EOF, err) + writer.Close() + + // Read from the written data + newblobs := make([]*Blob, len(writers)) + i := 0 + for _, w := range writers { + blob, err := w.Finalize() + assert.NoError(t, err) + 
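+			// Finalize flushes the parquet payload and prepends the binlog
+			// descriptor and event headers, yielding one blob per field.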
assert.NotNil(t, blob) + newblobs[i] = blob + i++ + } + // assert.Equal(t, blobs[0].Value, newblobs[0].Value) + reader, err = NewBinlogDeserializeReader(blobs, common.RowIDField) + assert.NoError(t, err) + defer reader.Close() + for i := 1; i <= size; i++ { + err = reader.Next() + assert.NoError(t, err, i) + + value := reader.Value() + assertTestData(t, i, value) + writer.Write(value) + } + }) +} +func TestSerDe(t *testing.T) { type args struct { - col arrow.Array - dataType schemapb.DataType - i int + dt schemapb.DataType + v any } tests := []struct { name string @@ -133,45 +186,49 @@ func Test_deserializeCell(t *testing.T) { want interface{} want1 bool }{ - {"test bool", args{col: onelinerArray(arrow.FixedWidthTypes.Boolean, true), dataType: schemapb.DataType_Bool, i: 0}, true, true}, - {"test bool negative", args{col: onelinerArray(arrow.Null, nil), dataType: schemapb.DataType_Bool, i: 0}, nil, false}, - {"test int8", args{col: onelinerArray(arrow.PrimitiveTypes.Int8, int8(1)), dataType: schemapb.DataType_Int8, i: 0}, int8(1), true}, - {"test int8 negative", args{col: onelinerArray(arrow.Null, nil), dataType: schemapb.DataType_Int8, i: 0}, nil, false}, - {"test int16", args{col: onelinerArray(arrow.PrimitiveTypes.Int16, int16(1)), dataType: schemapb.DataType_Int16, i: 0}, int16(1), true}, - {"test int16 negative", args{col: onelinerArray(arrow.Null, nil), dataType: schemapb.DataType_Int16, i: 0}, nil, false}, - {"test int32", args{col: onelinerArray(arrow.PrimitiveTypes.Int32, int32(1)), dataType: schemapb.DataType_Int32, i: 0}, int32(1), true}, - {"test int32 negative", args{col: onelinerArray(arrow.Null, nil), dataType: schemapb.DataType_Int32, i: 0}, nil, false}, - {"test int64", args{col: onelinerArray(arrow.PrimitiveTypes.Int64, int64(1)), dataType: schemapb.DataType_Int64, i: 0}, int64(1), true}, - {"test int64 negative", args{col: onelinerArray(arrow.Null, nil), dataType: schemapb.DataType_Int64, i: 0}, nil, false}, - {"test float32", args{col: onelinerArray(arrow.PrimitiveTypes.Float32, float32(1)), dataType: schemapb.DataType_Float, i: 0}, float32(1), true}, - {"test float32 negative", args{col: onelinerArray(arrow.Null, nil), dataType: schemapb.DataType_Float, i: 0}, nil, false}, - {"test float64", args{col: onelinerArray(arrow.PrimitiveTypes.Float64, float64(1)), dataType: schemapb.DataType_Double, i: 0}, float64(1), true}, - {"test float64 negative", args{col: onelinerArray(arrow.Null, nil), dataType: schemapb.DataType_Double, i: 0}, nil, false}, - {"test string", args{col: onelinerArray(arrow.BinaryTypes.String, "test"), dataType: schemapb.DataType_String, i: 0}, "test", true}, - {"test string negative", args{col: onelinerArray(arrow.Null, nil), dataType: schemapb.DataType_String, i: 0}, nil, false}, - {"test varchar", args{col: onelinerArray(arrow.BinaryTypes.String, "test"), dataType: schemapb.DataType_VarChar, i: 0}, "test", true}, - {"test varchar negative", args{col: onelinerArray(arrow.Null, nil), dataType: schemapb.DataType_VarChar, i: 0}, nil, false}, - {"test array negative", args{col: onelinerArray(arrow.BinaryTypes.Binary, []byte("{}")), dataType: schemapb.DataType_Array, i: 0}, nil, false}, - {"test array negative null", args{col: onelinerArray(arrow.Null, nil), dataType: schemapb.DataType_Array, i: 0}, nil, false}, - {"test json", args{col: onelinerArray(arrow.BinaryTypes.Binary, []byte("{}")), dataType: schemapb.DataType_JSON, i: 0}, []byte("{}"), true}, - {"test json negative", args{col: onelinerArray(arrow.Null, nil), dataType: schemapb.DataType_JSON, i: 0}, 
nil, false}, - {"test float vector", args{col: onelinerArray(&arrow.FixedSizeBinaryType{ByteWidth: 4}, []byte{0, 0, 0, 0}), dataType: schemapb.DataType_FloatVector, i: 0}, []float32{0.0}, true}, - {"test float vector negative", args{col: onelinerArray(arrow.Null, nil), dataType: schemapb.DataType_FloatVector, i: 0}, nil, false}, - {"test bool vector", args{col: onelinerArray(&arrow.FixedSizeBinaryType{ByteWidth: 4}, []byte("test")), dataType: schemapb.DataType_BinaryVector, i: 0}, []byte("test"), true}, - {"test float16 vector", args{col: onelinerArray(&arrow.FixedSizeBinaryType{ByteWidth: 4}, []byte("test")), dataType: schemapb.DataType_Float16Vector, i: 0}, []byte("test"), true}, - {"test bfloat16 vector", args{col: onelinerArray(&arrow.FixedSizeBinaryType{ByteWidth: 4}, []byte("test")), dataType: schemapb.DataType_BFloat16Vector, i: 0}, []byte("test"), true}, - {"test bfloat16 vector negative", args{col: onelinerArray(arrow.Null, nil), dataType: schemapb.DataType_BFloat16Vector, i: 0}, nil, false}, - {"test sparse float vector", args{col: onelinerArray(arrow.BinaryTypes.Binary, []byte("1234test")), dataType: schemapb.DataType_SparseFloatVector, i: 0}, []byte("1234test"), true}, - {"test sparse float vector negative", args{col: onelinerArray(arrow.Null, nil), dataType: schemapb.DataType_SparseFloatVector, i: 0}, nil, false}, + {"test bool", args{dt: schemapb.DataType_Bool, v: true}, true, true}, + {"test bool negative", args{dt: schemapb.DataType_Bool, v: nil}, nil, false}, + {"test int8", args{dt: schemapb.DataType_Int8, v: int8(1)}, int8(1), true}, + {"test int8 negative", args{dt: schemapb.DataType_Int8, v: nil}, nil, false}, + {"test int16", args{dt: schemapb.DataType_Int16, v: int16(1)}, int16(1), true}, + {"test int16 negative", args{dt: schemapb.DataType_Int16, v: nil}, nil, false}, + {"test int32", args{dt: schemapb.DataType_Int32, v: int32(1)}, int32(1), true}, + {"test int32 negative", args{dt: schemapb.DataType_Int32, v: nil}, nil, false}, + {"test int64", args{dt: schemapb.DataType_Int64, v: int64(1)}, int64(1), true}, + {"test int64 negative", args{dt: schemapb.DataType_Int64, v: nil}, nil, false}, + {"test float32", args{dt: schemapb.DataType_Float, v: float32(1)}, float32(1), true}, + {"test float32 negative", args{dt: schemapb.DataType_Float, v: nil}, nil, false}, + {"test float64", args{dt: schemapb.DataType_Double, v: float64(1)}, float64(1), true}, + {"test float64 negative", args{dt: schemapb.DataType_Double, v: nil}, nil, false}, + {"test string", args{dt: schemapb.DataType_String, v: "test"}, "test", true}, + {"test string negative", args{dt: schemapb.DataType_String, v: nil}, nil, false}, + {"test varchar", args{dt: schemapb.DataType_VarChar, v: "test"}, "test", true}, + {"test varchar negative", args{dt: schemapb.DataType_VarChar, v: nil}, nil, false}, + {"test array negative", args{dt: schemapb.DataType_Array, v: "{}"}, nil, false}, + {"test array negative null", args{dt: schemapb.DataType_Array, v: nil}, nil, false}, + {"test json", args{dt: schemapb.DataType_JSON, v: []byte("{}")}, []byte("{}"), true}, + {"test json negative", args{dt: schemapb.DataType_JSON, v: nil}, nil, false}, + {"test float vector", args{dt: schemapb.DataType_FloatVector, v: []float32{1.0}}, []float32{1.0}, true}, + {"test float vector negative", args{dt: schemapb.DataType_FloatVector, v: nil}, nil, false}, + {"test bool vector", args{dt: schemapb.DataType_BinaryVector, v: []byte{0xff}}, []byte{0xff}, true}, + {"test float16 vector", args{dt: schemapb.DataType_Float16Vector, v: 
[]byte{0xff, 0xff}}, []byte{0xff, 0xff}, true}, + {"test bfloat16 vector", args{dt: schemapb.DataType_BFloat16Vector, v: []byte{0xff, 0xff}}, []byte{0xff, 0xff}, true}, + {"test bfloat16 vector negative", args{dt: schemapb.DataType_BFloat16Vector, v: nil}, nil, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, got1 := deserializeCell(tt.args.col, tt.args.dataType, tt.args.i) + dt := tt.args.dt + v := tt.args.v + builder := array.NewBuilder(memory.DefaultAllocator, serdeMap[dt].arrowType(1)) + serdeMap[dt].serialize(builder, v) + // assert.True(t, ok) + a := builder.NewArray() + got, got1 := serdeMap[dt].deserialize(a, 0) if !reflect.DeepEqual(got, tt.want) { - t.Errorf("deserializeCell() got = %v, want %v", got, tt.want) + t.Errorf("deserialize() got = %v, want %v", got, tt.want) } if got1 != tt.want1 { - t.Errorf("deserializeCell() got1 = %v, want %v", got1, tt.want1) + t.Errorf("deserialize() got1 = %v, want %v", got1, tt.want1) } }) } From dea149cf3fc30de55280687db93ef4eb758f85ba Mon Sep 17 00:00:00 2001 From: Ted Xu Date: Tue, 26 Mar 2024 10:35:29 +0800 Subject: [PATCH 2/3] Improve test coverage Signed-off-by: Ted Xu --- internal/storage/serde.go | 52 ++++++---------------------------- internal/storage/serde_test.go | 5 ++-- 2 files changed, 12 insertions(+), 45 deletions(-) diff --git a/internal/storage/serde.go b/internal/storage/serde.go index 4e0657b6c9cde..dcbe72da252d2 100644 --- a/internal/storage/serde.go +++ b/internal/storage/serde.go @@ -424,7 +424,8 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { sizeOfBytes := func(v any) uint64 { return uint64(len(v.([]byte))) } - m[schemapb.DataType_JSON] = serdeEntry{ + + byteEntry := serdeEntry{ func(i int) arrow.DataType { return arrow.BinaryTypes.Binary }, @@ -446,6 +447,8 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { sizeOfBytes, } + m[schemapb.DataType_JSON] = byteEntry + fixedSizeDeserializer := func(a arrow.Array, i int) (any, bool) { if arr, ok := a.(*array.FixedSizeBinary); ok && i < arr.Len() { return arr.Value(i), true @@ -516,44 +519,7 @@ var serdeMap = func() map[schemapb.DataType]serdeEntry { return uint64(len(v.([]float32)) * 4) }, } - m[schemapb.DataType_SparseFloatVector] = serdeEntry{ - func(i int) arrow.DataType { - return arrow.BinaryTypes.Binary - }, - func(a arrow.Array, i int) (any, bool) { - if arr, ok := a.(*array.Binary); ok && i < arr.Len() { - - value := arr.Value(i) - return value, true - // fieldData := &SparseFloatVectorFieldData{} - // if len(value)%8 != 0 { - // return nil, false - // } - - // fieldData.Contents = append(fieldData.Contents, value) - // rowDim := typeutil.SparseFloatRowDim(value) - // if rowDim > fieldData.Dim { - // fieldData.Dim = rowDim - // } - - // return fieldData, true - } - return nil, false - }, - func(b array.Builder, v any) bool { - if builder, ok := b.(*array.BinaryBuilder); ok { - if vv, ok := v.(*SparseFloatVectorFieldData); ok { - length := len(vv.SparseFloatArray.Contents) - for i := 0; i < length; i++ { - builder.Append(vv.SparseFloatArray.Contents[i]) - } - return true - } - } - return false - }, - sizeOfBytes, - } + m[schemapb.DataType_SparseFloatVector] = byteEntry return m }() @@ -853,9 +819,9 @@ func (sw *SerializeWriter[T]) WrittenMemorySize() uint64 { return sw.writtenMemorySize } -func (sw *SerializeWriter[T]) Close() { - sw.Flush() +func (sw *SerializeWriter[T]) Close() error { sw.rw.Close() + return sw.Flush() } func NewSerializeRecordWriter[T any](rw RecordWriter, serializer Serializer[T], batchSize int) 
*SerializeWriter[T] { @@ -1035,7 +1001,7 @@ func NewBinlogSerializeWriter(schema *schemapb.CollectionSchema, partitionID, se } ok = typeEntry.serialize(builders[fid], e) if !ok { - return nil, 0, errors.New(fmt.Sprintf("unexpected type %s", types[fid])) + return nil, 0, errors.New(fmt.Sprintf("serialize error on type %s", types[fid])) } memorySize += typeEntry.sizeof(e) } @@ -1054,6 +1020,6 @@ func NewBinlogSerializeWriter(schema *schemapb.CollectionSchema, partitionID, se field2Col[fid] = i i++ } - return newSimpleArrowRecord(array.NewRecord(arrow.NewSchema(fields, nil), arrays, -1), types, field2Col), memorySize, nil + return newSimpleArrowRecord(array.NewRecord(arrow.NewSchema(fields, nil), arrays, int64(len(v))), types, field2Col), memorySize, nil }, batchSize), nil } diff --git a/internal/storage/serde_test.go b/internal/storage/serde_test.go index 7d99d853fbe7f..17a10e3a2104e 100644 --- a/internal/storage/serde_test.go +++ b/internal/storage/serde_test.go @@ -148,7 +148,9 @@ func TestBinlogSerializeWriter(t *testing.T) { err = reader.Next() assert.Equal(t, io.EOF, err) - writer.Close() + err = writer.Close() + assert.NoError(t, err) + assert.True(t, writer.WrittenMemorySize() >= 429) // Read from the written data newblobs := make([]*Blob, len(writers)) @@ -170,7 +172,6 @@ func TestBinlogSerializeWriter(t *testing.T) { value := reader.Value() assertTestData(t, i, value) - writer.Write(value) } }) } From 5c0e0c87505214e1e8db9befb225c63490304587 Mon Sep 17 00:00:00 2001 From: Ted Xu Date: Tue, 9 Apr 2024 18:43:43 +0800 Subject: [PATCH 3/3] Remove unnecessary interface declarations Signed-off-by: Ted Xu --- internal/storage/serde.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/internal/storage/serde.go b/internal/storage/serde.go index dcbe72da252d2..7a64edf79b653 100644 --- a/internal/storage/serde.go +++ b/internal/storage/serde.go @@ -88,8 +88,9 @@ func (r *compositeRecord) Schema() map[FieldID]schemapb.DataType { return r.schema } +var _ RecordReader = (*compositeRecordReader)(nil) + type compositeRecordReader struct { - RecordReader blobs [][]*Blob blobPos int @@ -665,10 +666,10 @@ func NewBinlogDeserializeReader(blobs []*Blob, PKfieldID UniqueID) (*Deserialize }), nil } +var _ Record = (*selectiveRecord)(nil) + // selectiveRecord is a Record that only contains a single field, reusing existing Record. type selectiveRecord struct { - Record - r Record selectedFieldId FieldID @@ -708,8 +709,9 @@ func newSelectiveRecord(r Record, selectedFieldId FieldID) *selectiveRecord { } } +var _ RecordWriter = (*compositeRecordWriter)(nil) + type compositeRecordWriter struct { - RecordWriter writers map[FieldID]RecordWriter } @@ -742,8 +744,9 @@ func newCompositeRecordWriter(writers map[FieldID]RecordWriter) *compositeRecord } } +var _ RecordWriter = (*singleFieldRecordWriter)(nil) + type singleFieldRecordWriter struct { - RecordWriter fw *pqarrow.FileWriter fieldId FieldID
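+	// End-to-end usage of the new writer API, as exercised by the tests
+	// above (a minimal sketch; the IDs and batch size are placeholders):
+	//
+	//	writers := NewBinlogStreamWriters(collectionID, partitionID, segmentID, schema.Fields)
+	//	writer, err := NewBinlogSerializeWriter(schema, partitionID, segmentID, writers, 1024)
+	//	// write each *Value, then Close to flush the tail batch
+	//	for _, v := range values {
+	//		writer.Write(v)
+	//	}
+	//	writer.Close()
+	//	for _, w := range writers {
+	//		blob, _ := w.Finalize() // binlog headers + parquet payload
+	//		_ = blob
+	//	}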