Skip to content

Commit

Permalink
*: Add Vector data type (#54635)
Browse files Browse the repository at this point in the history
ref #54245
  • Loading branch information
EricZequan authored Aug 2, 2024
1 parent 1acb8f7 commit 5389de9
Show file tree
Hide file tree
Showing 69 changed files with 2,237 additions and 117 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ require (
github.com/jfcg/sixb v1.3.8 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/jonboulle/clockwork v0.4.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/json-iterator/go v1.1.12
github.com/klauspost/cpuid v1.3.1 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/kr/text v0.2.0 // indirect
Expand Down
5 changes: 3 additions & 2 deletions pkg/ddl/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1636,7 +1636,7 @@ func getDefaultValue(ctx exprctx.BuildContext, col *table.Column, option *ast.Co
}

if v.Kind() == types.KindBinaryLiteral || v.Kind() == types.KindMysqlBit {
if types.IsTypeBlob(tp) || tp == mysql.TypeJSON {
if types.IsTypeBlob(tp) || tp == mysql.TypeJSON || tp == mysql.TypeTiDBVectorFloat32 {
// BLOB/TEXT/JSON column cannot have a default value.
// Skip the unnecessary decode procedure.
return v.GetString(), false, err
Expand Down Expand Up @@ -3483,7 +3483,8 @@ func checkPartitionByList(ctx sessionctx.Context, tbInfo *model.TableInfo) error

func isValidKeyPartitionColType(fieldType types.FieldType) bool {
switch fieldType.GetType() {
case mysql.TypeBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeJSON, mysql.TypeGeometry, mysql.TypeTiDBVectorFloat32:
case mysql.TypeBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeJSON, mysql.TypeGeometry,
mysql.TypeTiDBVectorFloat32:
return false
default:
return true
Expand Down
8 changes: 8 additions & 0 deletions pkg/ddl/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,14 @@ func checkIndexColumn(ctx sessionctx.Context, col *model.ColumnInfo, indexColumn
return errors.Trace(dbterror.ErrJSONUsedAsKey.GenWithStackByArgs(col.Name.O))
}

// Vector column cannot index, for now.
if col.FieldType.GetType() == mysql.TypeTiDBVectorFloat32 {
if col.Hidden {
return errors.Errorf("Cannot create an expression index on a function that returns a VECTOR value")
}
return errors.Trace(dbterror.ErrWrongKeyColumn.GenWithStackByArgs(col.Name))
}

// Length must be specified and non-zero for BLOB and TEXT column indexes.
if types.IsTypeBlob(col.FieldType.GetType()) {
if indexColumnLen == types.UnspecifiedLength {
Expand Down
3 changes: 3 additions & 0 deletions pkg/executor/aggfuncs/aggfuncs.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ var (
_ AggFunc = (*countOriginal4Time)(nil)
_ AggFunc = (*countOriginal4Duration)(nil)
_ AggFunc = (*countOriginal4JSON)(nil)
_ AggFunc = (*countOriginal4VectorFloat32)(nil)
_ AggFunc = (*countOriginal4String)(nil)
_ AggFunc = (*countOriginalWithDistinct4Int)(nil)
_ AggFunc = (*countOriginalWithDistinct4Real)(nil)
Expand Down Expand Up @@ -61,6 +62,7 @@ var (
_ AggFunc = (*firstRow4Float32)(nil)
_ AggFunc = (*firstRow4Float64)(nil)
_ AggFunc = (*firstRow4JSON)(nil)
_ AggFunc = (*firstRow4VectorFloat32)(nil)
_ AggFunc = (*firstRow4Enum)(nil)
_ AggFunc = (*firstRow4Set)(nil)

Expand All @@ -73,6 +75,7 @@ var (
_ AggFunc = (*maxMin4String)(nil)
_ AggFunc = (*maxMin4Duration)(nil)
_ AggFunc = (*maxMin4JSON)(nil)
_ AggFunc = (*maxMin4VectorFloat32)(nil)
_ AggFunc = (*maxMin4Enum)(nil)
_ AggFunc = (*maxMin4Set)(nil)

Expand Down
6 changes: 6 additions & 0 deletions pkg/executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@ func buildCount(ctx expression.EvalContext, aggFuncDesc *aggregation.AggFuncDesc
return &countOriginal4Duration{baseCount{base}}
case types.ETJson:
return &countOriginal4JSON{baseCount{base}}
case types.ETVectorFloat32:
return &countOriginal4VectorFloat32{baseCount{base}}
case types.ETString:
return &countOriginal4String{baseCount{base}}
}
Expand Down Expand Up @@ -378,6 +380,8 @@ func buildFirstRow(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
return &firstRow4String{base}
case types.ETJson:
return &firstRow4JSON{base}
case types.ETVectorFloat32:
return &firstRow4VectorFloat32{base}
}
}
return nil
Expand Down Expand Up @@ -431,6 +435,8 @@ func buildMaxMin(aggFuncDesc *aggregation.AggFuncDesc, ordinal int, isMax bool)
return &maxMin4Duration{base}
case types.ETJson:
return &maxMin4JSON{base}
case types.ETVectorFloat32:
return &maxMin4VectorFloat32{base}
}
}
return nil
Expand Down
49 changes: 49 additions & 0 deletions pkg/executor/aggfuncs/func_count.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,55 @@ func (e *countOriginal4JSON) Slide(sctx AggFuncUpdateContext, getRow func(uint64
return nil
}

type countOriginal4VectorFloat32 struct {
baseCount
}

func (e *countOriginal4VectorFloat32) UpdatePartialResult(sctx AggFuncUpdateContext, rowsInGroup []chunk.Row, pr PartialResult) (memDelta int64, err error) {
p := (*partialResult4Count)(pr)

for _, row := range rowsInGroup {
_, isNull, err := e.args[0].EvalVectorFloat32(sctx, row)
if err != nil {
return 0, err
}
if isNull {
continue
}

*p++
}

return 0, nil
}

var _ SlidingWindowAggFunc = &countOriginal4VectorFloat32{}

func (e *countOriginal4VectorFloat32) Slide(sctx AggFuncUpdateContext, getRow func(uint64) chunk.Row, lastStart, lastEnd uint64, shiftStart, shiftEnd uint64, pr PartialResult) error {
p := (*partialResult4Count)(pr)
for i := uint64(0); i < shiftStart; i++ {
_, isNull, err := e.args[0].EvalVectorFloat32(sctx, getRow(lastStart+i))
if err != nil {
return err
}
if isNull {
continue
}
*p--
}
for i := uint64(0); i < shiftEnd; i++ {
_, isNull, err := e.args[0].EvalVectorFloat32(sctx, getRow(lastEnd+i))
if err != nil {
return err
}
if isNull {
continue
}
*p++
}
return nil
}

type countOriginal4String struct {
baseCount
}
Expand Down
7 changes: 7 additions & 0 deletions pkg/executor/aggfuncs/func_count_distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,13 @@ func evalAndEncode(
break
}
encodedBytes = val.HashValue(encodedBytes)
case types.ETVectorFloat32:
var val types.VectorFloat32
val, isNull, err = arg.EvalVectorFloat32(sctx, row)
if err != nil || isNull {
break
}
encodedBytes = val.SerializeTo(encodedBytes)
case types.ETString:
var val string
val, isNull, err = arg.EvalString(sctx, row)
Expand Down
54 changes: 54 additions & 0 deletions pkg/executor/aggfuncs/func_first_row.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ const (
DefPartialResult4FirstRowDurationSize = int64(unsafe.Sizeof(partialResult4FirstRowDuration{}))
// DefPartialResult4FirstRowJSONSize is the size of partialResult4FirstRowJSON
DefPartialResult4FirstRowJSONSize = int64(unsafe.Sizeof(partialResult4FirstRowJSON{}))
// DefPartialResult4FirstRowVectorFloat32Size is the size of partialResult4FirstRowVectorFloat32
DefPartialResult4FirstRowVectorFloat32Size = int64(unsafe.Sizeof(partialResult4FirstRowVectorFloat32{}))
// DefPartialResult4FirstRowDecimalSize is the size of partialResult4FirstRowDecimal
DefPartialResult4FirstRowDecimalSize = int64(unsafe.Sizeof(partialResult4FirstRowDecimal{}))
// DefPartialResult4FirstRowEnumSize is the size of partialResult4FirstRowEnum
Expand Down Expand Up @@ -104,6 +106,12 @@ type partialResult4FirstRowJSON struct {
val types.BinaryJSON
}

type partialResult4FirstRowVectorFloat32 struct {
basePartialResult4FirstRow

val types.VectorFloat32
}

type partialResult4FirstRowEnum struct {
basePartialResult4FirstRow

Expand Down Expand Up @@ -579,6 +587,52 @@ func (e *firstRow4JSON) deserializeForSpill(helper *deserializeHelper) (PartialR
return pr, memDelta
}

type firstRow4VectorFloat32 struct {
baseAggFunc
}

func (*firstRow4VectorFloat32) AllocPartialResult() (pr PartialResult, memDelta int64) {
return PartialResult(new(partialResult4FirstRowVectorFloat32)), DefPartialResult4FirstRowVectorFloat32Size
}

func (*firstRow4VectorFloat32) ResetPartialResult(pr PartialResult) {
p := (*partialResult4FirstRowVectorFloat32)(pr)
p.isNull, p.gotFirstRow = false, false
}

func (e *firstRow4VectorFloat32) UpdatePartialResult(sctx AggFuncUpdateContext, rowsInGroup []chunk.Row, pr PartialResult) (memDelta int64, err error) {
p := (*partialResult4FirstRowVectorFloat32)(pr)
if p.gotFirstRow {
return memDelta, nil
}
if len(rowsInGroup) > 0 {
input, isNull, err := e.args[0].EvalVectorFloat32(sctx, rowsInGroup[0])
if err != nil {
return memDelta, err
}
p.gotFirstRow, p.isNull, p.val = true, isNull, input.Clone()
memDelta += int64(input.EstimatedMemUsage())
}
return memDelta, nil
}
func (*firstRow4VectorFloat32) MergePartialResult(_ AggFuncUpdateContext, src, dst PartialResult) (memDelta int64, err error) {
p1, p2 := (*partialResult4FirstRowVectorFloat32)(src), (*partialResult4FirstRowVectorFloat32)(dst)
if !p2.gotFirstRow {
*p2 = *p1
}
return memDelta, nil
}

func (e *firstRow4VectorFloat32) AppendFinalResult2Chunk(_ AggFuncUpdateContext, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4FirstRowVectorFloat32)(pr)
if p.isNull || !p.gotFirstRow {
chk.AppendNull(e.ordinal)
return nil
}
chk.AppendVectorFloat32(e.ordinal, p.val)
return nil
}

type firstRow4Decimal struct {
baseAggFunc
}
Expand Down
76 changes: 76 additions & 0 deletions pkg/executor/aggfuncs/func_max_min.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ const (
DefPartialResult4MaxMinStringSize = int64(unsafe.Sizeof(partialResult4MaxMinString{}))
// DefPartialResult4MaxMinJSONSize is the size of partialResult4MaxMinJSON
DefPartialResult4MaxMinJSONSize = int64(unsafe.Sizeof(partialResult4MaxMinJSON{}))
// DefPartialResult4MaxMinVectorFloat32Size is the size of partialResult4MaxMinVectorFloat32
DefPartialResult4MaxMinVectorFloat32Size = int64(unsafe.Sizeof(partialResult4MaxMinVectorFloat32{}))
// DefPartialResult4MaxMinEnumSize is the size of partialResult4MaxMinEnum
DefPartialResult4MaxMinEnumSize = int64(unsafe.Sizeof(partialResult4MaxMinEnum{}))
// DefPartialResult4MaxMinSetSize is the size of partialResult4MaxMinSet
Expand Down Expand Up @@ -221,6 +223,11 @@ type partialResult4MaxMinJSON struct {
isNull bool
}

type partialResult4MaxMinVectorFloat32 struct {
val types.VectorFloat32
isNull bool
}

type partialResult4MaxMinEnum struct {
val types.Enum
isNull bool
Expand Down Expand Up @@ -1632,6 +1639,75 @@ func (e *maxMin4JSON) deserializeForSpill(helper *deserializeHelper) (PartialRes
return pr, memDelta
}

type maxMin4VectorFloat32 struct {
baseMaxMinAggFunc
}

func (*maxMin4VectorFloat32) AllocPartialResult() (pr PartialResult, memDelta int64) {
p := new(partialResult4MaxMinVectorFloat32)
p.isNull = true
return PartialResult(p), DefPartialResult4MaxMinVectorFloat32Size
}

func (*maxMin4VectorFloat32) ResetPartialResult(pr PartialResult) {
p := (*partialResult4MaxMinVectorFloat32)(pr)
p.isNull = true
}

func (e *maxMin4VectorFloat32) AppendFinalResult2Chunk(_ AggFuncUpdateContext, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4MaxMinVectorFloat32)(pr)
if p.isNull {
chk.AppendNull(e.ordinal)
return nil
}
chk.AppendVectorFloat32(e.ordinal, p.val)
return nil
}

func (e *maxMin4VectorFloat32) UpdatePartialResult(sctx AggFuncUpdateContext, rowsInGroup []chunk.Row, pr PartialResult) (memDelta int64, err error) {
p := (*partialResult4MaxMinVectorFloat32)(pr)
for _, row := range rowsInGroup {
input, isNull, err := e.args[0].EvalVectorFloat32(sctx, row)
if err != nil {
return memDelta, err
}
if isNull {
continue
}
if p.isNull {
p.val = input.Clone()
memDelta += int64(input.EstimatedMemUsage())
p.isNull = false
continue
}
cmp := input.Compare(p.val)
if e.isMax && cmp > 0 || !e.isMax && cmp < 0 {
oldMem := p.val.EstimatedMemUsage()
newMem := input.EstimatedMemUsage()
memDelta += int64(newMem - oldMem)
p.val = input.Clone()
}
}
return memDelta, nil
}

func (e *maxMin4VectorFloat32) MergePartialResult(_ AggFuncUpdateContext, src, dst PartialResult) (memDelta int64, err error) {
p1, p2 := (*partialResult4MaxMinVectorFloat32)(src), (*partialResult4MaxMinVectorFloat32)(dst)
if p1.isNull {
return 0, nil
}
if p2.isNull {
*p2 = *p1
return 0, nil
}
cmp := p1.val.Compare(p2.val)
if e.isMax && cmp > 0 || !e.isMax && cmp < 0 {
p2.val = p1.val
p2.isNull = false
}
return 0, nil
}

type maxMin4Enum struct {
baseMaxMinAggFunc
}
Expand Down
27 changes: 27 additions & 0 deletions pkg/executor/aggfuncs/func_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package aggfuncs

import (
"fmt"
"unsafe"

"github.com/pingcap/tidb/pkg/expression"
Expand Down Expand Up @@ -47,6 +48,8 @@ const (
DefValue4StringSize = int64(unsafe.Sizeof(value4String{}))
// DefValue4JSONSize is the size of value4JSON
DefValue4JSONSize = int64(unsafe.Sizeof(value4JSON{}))
// DefValue4VectorFloat32Size is the size of value4VectorFloat32
DefValue4VectorFloat32Size = int64(unsafe.Sizeof(value4VectorFloat32{}))
)

// valueEvaluator is used to evaluate values for `first_value`, `last_value`, `nth_value`,
Expand Down Expand Up @@ -207,6 +210,26 @@ func (v *value4JSON) appendResult(chk *chunk.Chunk, colIdx int) {
}
}

type value4VectorFloat32 struct {
val types.VectorFloat32
isNull bool
}

func (v *value4VectorFloat32) evaluateRow(ctx expression.EvalContext, expr expression.Expression, row chunk.Row) (memDelta int64, err error) {
originalLength := v.val.EstimatedMemUsage()
v.val, v.isNull, err = expr.EvalVectorFloat32(ctx, row)
v.val = v.val.Clone() // deep copy to avoid content change.
return int64(v.val.EstimatedMemUsage() - originalLength), err
}

func (v *value4VectorFloat32) appendResult(chk *chunk.Chunk, colIdx int) {
if v.isNull {
chk.AppendNull(colIdx)
} else {
chk.AppendVectorFloat32(colIdx, v.val)
}
}

func buildValueEvaluator(tp *types.FieldType) (ve valueEvaluator, memDelta int64) {
evalType := tp.EvalType()
if tp.GetType() == mysql.TypeBit {
Expand All @@ -232,6 +255,10 @@ func buildValueEvaluator(tp *types.FieldType) (ve valueEvaluator, memDelta int64
return &value4String{}, DefValue4StringSize
case types.ETJson:
return &value4JSON{}, DefValue4JSONSize
case types.ETVectorFloat32:
return &value4VectorFloat32{}, DefValue4VectorFloat32Size
default:
panic(fmt.Sprintf("unsupported eval type %v", evalType))
}
return nil, 0
}
Expand Down
1 change: 1 addition & 0 deletions pkg/executor/internal/vecgroupchecker/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ go_library(
"//pkg/types",
"//pkg/util/chunk",
"//pkg/util/codec",
"@com_github_pingcap_errors//:errors",
],
)

Expand Down
Loading

0 comments on commit 5389de9

Please sign in to comment.