Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: support fixed dimension vector #54956

Closed
wants to merge 23 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ require (
github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/pierrec/lz4/v4 v4.1.15 // indirect
github.com/pingcap/tidb/parser v0.0.0-20231013125129-93a834a6bf8d // indirect
github.com/qri-io/jsonpointer v0.1.1 // indirect
github.com/zeebo/xxh3 v1.0.2 // indirect
)
Expand Down Expand Up @@ -239,7 +240,7 @@ require (
github.com/jfcg/sixb v1.3.8 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/jonboulle/clockwork v0.4.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/json-iterator/go v1.1.12
github.com/klauspost/cpuid v1.3.1 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/kr/text v0.2.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,8 @@ github.com/pingcap/log v1.1.1-0.20240314023424-862ccc32f18d h1:y3EueKVfVykdpTyfU
github.com/pingcap/log v1.1.1-0.20240314023424-862ccc32f18d/go.mod h1:ORfBOFp1eteu2odzsyaxI+b8TzJwgjwyQcGhI+9SfEA=
github.com/pingcap/sysutil v1.0.1-0.20240311050922-ae81ee01f3a5 h1:T4pXRhBflzDeAhmOQHNPRRogMYxP13V7BkYw3ZsoSfE=
github.com/pingcap/sysutil v1.0.1-0.20240311050922-ae81ee01f3a5/go.mod h1:rlimy0GcTvjiJqvD5mXTRr8O2eNZPBrcUgiWVYp9530=
github.com/pingcap/tidb/parser v0.0.0-20231013125129-93a834a6bf8d h1:EHXDxa7eq8vWc2T8cwstlr3A48dx4TvMsCh5Y7z2VZ8=
github.com/pingcap/tidb/parser v0.0.0-20231013125129-93a834a6bf8d/go.mod h1:cwq4bKUlftpWuznB+rqNwbN0xy6/i5SL/nYvEKeJn4s=
github.com/pingcap/tipb v0.0.0-20240318032315-55a7867ddd50 h1:fVNBE06Rjec+EIHaYAKAHa/bIt5lnu3Zh9O6kV7ZAdg=
github.com/pingcap/tipb v0.0.0-20240318032315-55a7867ddd50/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
Expand Down
2 changes: 2 additions & 0 deletions pkg/ddl/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,8 @@ func needChangeColumnData(oldCol, newCol *model.ColumnInfo) bool {
if types.IsBinaryStr(&oldCol.FieldType) {
return newCol.GetFlen() != oldCol.GetFlen()
}
case mysql.TypeTiDBVectorFloat32:
return newCol.GetFlen() != types.UnspecifiedLength && oldCol.GetFlen() != newCol.GetFlen()
}

return needTruncationOrToggleSign()
Expand Down
5 changes: 3 additions & 2 deletions pkg/ddl/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1636,7 +1636,7 @@ func getDefaultValue(ctx exprctx.BuildContext, col *table.Column, option *ast.Co
}

if v.Kind() == types.KindBinaryLiteral || v.Kind() == types.KindMysqlBit {
if types.IsTypeBlob(tp) || tp == mysql.TypeJSON {
if types.IsTypeBlob(tp) || tp == mysql.TypeJSON || tp == mysql.TypeTiDBVectorFloat32 {
// BLOB/TEXT/JSON column cannot have a default value.
// Skip the unnecessary decode procedure.
return v.GetString(), false, err
Expand Down Expand Up @@ -3483,7 +3483,8 @@ func checkPartitionByList(ctx sessionctx.Context, tbInfo *model.TableInfo) error

func isValidKeyPartitionColType(fieldType types.FieldType) bool {
switch fieldType.GetType() {
case mysql.TypeBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeJSON, mysql.TypeGeometry, mysql.TypeTiDBVectorFloat32:
case mysql.TypeBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeJSON, mysql.TypeGeometry,
mysql.TypeTiDBVectorFloat32:
return false
default:
return true
Expand Down
8 changes: 8 additions & 0 deletions pkg/ddl/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,14 @@ func checkIndexColumn(ctx sessionctx.Context, col *model.ColumnInfo, indexColumn
return errors.Trace(dbterror.ErrJSONUsedAsKey.GenWithStackByArgs(col.Name.O))
}

// Vector column cannot index, for now.
if col.FieldType.GetType() == mysql.TypeTiDBVectorFloat32 {
if col.Hidden {
return dbterror.ErrFunctionalIndexOnJSONOrGeometryFunction
}
return errors.Trace(dbterror.ErrWrongKeyColumn.GenWithStackByArgs(col.Name))
}

// Length must be specified and non-zero for BLOB and TEXT column indexes.
if types.IsTypeBlob(col.FieldType.GetType()) {
if indexColumnLen == types.UnspecifiedLength {
Expand Down
3 changes: 3 additions & 0 deletions pkg/executor/aggfuncs/aggfuncs.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ var (
_ AggFunc = (*countOriginal4Time)(nil)
_ AggFunc = (*countOriginal4Duration)(nil)
_ AggFunc = (*countOriginal4JSON)(nil)
_ AggFunc = (*countOriginal4VectorFloat32)(nil)
_ AggFunc = (*countOriginal4String)(nil)
_ AggFunc = (*countOriginalWithDistinct4Int)(nil)
_ AggFunc = (*countOriginalWithDistinct4Real)(nil)
Expand Down Expand Up @@ -61,6 +62,7 @@ var (
_ AggFunc = (*firstRow4Float32)(nil)
_ AggFunc = (*firstRow4Float64)(nil)
_ AggFunc = (*firstRow4JSON)(nil)
_ AggFunc = (*firstRow4VectorFloat32)(nil)
_ AggFunc = (*firstRow4Enum)(nil)
_ AggFunc = (*firstRow4Set)(nil)

Expand All @@ -73,6 +75,7 @@ var (
_ AggFunc = (*maxMin4String)(nil)
_ AggFunc = (*maxMin4Duration)(nil)
_ AggFunc = (*maxMin4JSON)(nil)
_ AggFunc = (*maxMin4VectorFloat32)(nil)
_ AggFunc = (*maxMin4Enum)(nil)
_ AggFunc = (*maxMin4Set)(nil)

Expand Down
6 changes: 6 additions & 0 deletions pkg/executor/aggfuncs/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@ func buildCount(ctx expression.EvalContext, aggFuncDesc *aggregation.AggFuncDesc
return &countOriginal4Duration{baseCount{base}}
case types.ETJson:
return &countOriginal4JSON{baseCount{base}}
case types.ETVectorFloat32:
return &countOriginal4VectorFloat32{baseCount{base}}
case types.ETString:
return &countOriginal4String{baseCount{base}}
}
Expand Down Expand Up @@ -378,6 +380,8 @@ func buildFirstRow(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
return &firstRow4String{base}
case types.ETJson:
return &firstRow4JSON{base}
case types.ETVectorFloat32:
return &firstRow4VectorFloat32{base}
}
}
return nil
Expand Down Expand Up @@ -431,6 +435,8 @@ func buildMaxMin(aggFuncDesc *aggregation.AggFuncDesc, ordinal int, isMax bool)
return &maxMin4Duration{base}
case types.ETJson:
return &maxMin4JSON{base}
case types.ETVectorFloat32:
return &maxMin4VectorFloat32{base}
}
}
return nil
Expand Down
49 changes: 49 additions & 0 deletions pkg/executor/aggfuncs/func_count.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,55 @@ func (e *countOriginal4JSON) Slide(sctx AggFuncUpdateContext, getRow func(uint64
return nil
}

type countOriginal4VectorFloat32 struct {
baseCount
}

func (e *countOriginal4VectorFloat32) UpdatePartialResult(sctx AggFuncUpdateContext, rowsInGroup []chunk.Row, pr PartialResult) (memDelta int64, err error) {
p := (*partialResult4Count)(pr)

for _, row := range rowsInGroup {
_, isNull, err := e.args[0].EvalVectorFloat32(sctx, row)
if err != nil {
return 0, err
}
if isNull {
continue
}

*p++
}

return 0, nil
}

var _ SlidingWindowAggFunc = &countOriginal4VectorFloat32{}

func (e *countOriginal4VectorFloat32) Slide(sctx AggFuncUpdateContext, getRow func(uint64) chunk.Row, lastStart, lastEnd uint64, shiftStart, shiftEnd uint64, pr PartialResult) error {
p := (*partialResult4Count)(pr)
for i := uint64(0); i < shiftStart; i++ {
_, isNull, err := e.args[0].EvalVectorFloat32(sctx, getRow(lastStart+i))
if err != nil {
return err
}
if isNull {
continue
}
*p--
}
for i := uint64(0); i < shiftEnd; i++ {
_, isNull, err := e.args[0].EvalVectorFloat32(sctx, getRow(lastEnd+i))
if err != nil {
return err
}
if isNull {
continue
}
*p++
}
return nil
}

type countOriginal4String struct {
baseCount
}
Expand Down
7 changes: 7 additions & 0 deletions pkg/executor/aggfuncs/func_count_distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,13 @@ func evalAndEncode(
break
}
encodedBytes = val.HashValue(encodedBytes)
case types.ETVectorFloat32:
var val types.VectorFloat32
val, isNull, err = arg.EvalVectorFloat32(sctx, row)
if err != nil || isNull {
break
}
encodedBytes = val.SerializeTo(encodedBytes)
case types.ETString:
var val string
val, isNull, err = arg.EvalString(sctx, row)
Expand Down
54 changes: 54 additions & 0 deletions pkg/executor/aggfuncs/func_first_row.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ const (
DefPartialResult4FirstRowDurationSize = int64(unsafe.Sizeof(partialResult4FirstRowDuration{}))
// DefPartialResult4FirstRowJSONSize is the size of partialResult4FirstRowJSON
DefPartialResult4FirstRowJSONSize = int64(unsafe.Sizeof(partialResult4FirstRowJSON{}))
// DefPartialResult4FirstRowVectorFloat32Size is the size of partialResult4FirstRowVectorFloat32
DefPartialResult4FirstRowVectorFloat32Size = int64(unsafe.Sizeof(partialResult4FirstRowVectorFloat32{}))
// DefPartialResult4FirstRowDecimalSize is the size of partialResult4FirstRowDecimal
DefPartialResult4FirstRowDecimalSize = int64(unsafe.Sizeof(partialResult4FirstRowDecimal{}))
// DefPartialResult4FirstRowEnumSize is the size of partialResult4FirstRowEnum
Expand Down Expand Up @@ -104,6 +106,12 @@ type partialResult4FirstRowJSON struct {
val types.BinaryJSON
}

type partialResult4FirstRowVectorFloat32 struct {
basePartialResult4FirstRow

val types.VectorFloat32
}

type partialResult4FirstRowEnum struct {
basePartialResult4FirstRow

Expand Down Expand Up @@ -579,6 +587,52 @@ func (e *firstRow4JSON) deserializeForSpill(helper *deserializeHelper) (PartialR
return pr, memDelta
}

type firstRow4VectorFloat32 struct {
baseAggFunc
}

func (*firstRow4VectorFloat32) AllocPartialResult() (pr PartialResult, memDelta int64) {
return PartialResult(new(partialResult4FirstRowVectorFloat32)), DefPartialResult4FirstRowVectorFloat32Size
}

func (*firstRow4VectorFloat32) ResetPartialResult(pr PartialResult) {
p := (*partialResult4FirstRowVectorFloat32)(pr)
p.isNull, p.gotFirstRow = false, false
}

func (e *firstRow4VectorFloat32) UpdatePartialResult(sctx AggFuncUpdateContext, rowsInGroup []chunk.Row, pr PartialResult) (memDelta int64, err error) {
p := (*partialResult4FirstRowVectorFloat32)(pr)
if p.gotFirstRow {
return memDelta, nil
}
if len(rowsInGroup) > 0 {
input, isNull, err := e.args[0].EvalVectorFloat32(sctx, rowsInGroup[0])
if err != nil {
return memDelta, err
}
p.gotFirstRow, p.isNull, p.val = true, isNull, input.Clone()
memDelta += int64(input.EstimatedMemUsage())
}
return memDelta, nil
}
func (*firstRow4VectorFloat32) MergePartialResult(_ AggFuncUpdateContext, src, dst PartialResult) (memDelta int64, err error) {
p1, p2 := (*partialResult4FirstRowVectorFloat32)(src), (*partialResult4FirstRowVectorFloat32)(dst)
if !p2.gotFirstRow {
*p2 = *p1
}
return memDelta, nil
}

func (e *firstRow4VectorFloat32) AppendFinalResult2Chunk(_ AggFuncUpdateContext, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4FirstRowVectorFloat32)(pr)
if p.isNull || !p.gotFirstRow {
chk.AppendNull(e.ordinal)
return nil
}
chk.AppendVectorFloat32(e.ordinal, p.val)
return nil
}

type firstRow4Decimal struct {
baseAggFunc
}
Expand Down
76 changes: 76 additions & 0 deletions pkg/executor/aggfuncs/func_max_min.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ const (
DefPartialResult4MaxMinStringSize = int64(unsafe.Sizeof(partialResult4MaxMinString{}))
// DefPartialResult4MaxMinJSONSize is the size of partialResult4MaxMinJSON
DefPartialResult4MaxMinJSONSize = int64(unsafe.Sizeof(partialResult4MaxMinJSON{}))
// DefPartialResult4MaxMinVectorFloat32Size is the size of partialResult4MaxMinVectorFloat32
DefPartialResult4MaxMinVectorFloat32Size = int64(unsafe.Sizeof(partialResult4MaxMinVectorFloat32{}))
// DefPartialResult4MaxMinEnumSize is the size of partialResult4MaxMinEnum
DefPartialResult4MaxMinEnumSize = int64(unsafe.Sizeof(partialResult4MaxMinEnum{}))
// DefPartialResult4MaxMinSetSize is the size of partialResult4MaxMinSet
Expand Down Expand Up @@ -221,6 +223,11 @@ type partialResult4MaxMinJSON struct {
isNull bool
}

type partialResult4MaxMinVectorFloat32 struct {
val types.VectorFloat32
isNull bool
}

type partialResult4MaxMinEnum struct {
val types.Enum
isNull bool
Expand Down Expand Up @@ -1632,6 +1639,75 @@ func (e *maxMin4JSON) deserializeForSpill(helper *deserializeHelper) (PartialRes
return pr, memDelta
}

type maxMin4VectorFloat32 struct {
baseMaxMinAggFunc
}

func (*maxMin4VectorFloat32) AllocPartialResult() (pr PartialResult, memDelta int64) {
p := new(partialResult4MaxMinVectorFloat32)
p.isNull = true
return PartialResult(p), DefPartialResult4MaxMinVectorFloat32Size
}

func (*maxMin4VectorFloat32) ResetPartialResult(pr PartialResult) {
p := (*partialResult4MaxMinVectorFloat32)(pr)
p.isNull = true
}

func (e *maxMin4VectorFloat32) AppendFinalResult2Chunk(_ AggFuncUpdateContext, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4MaxMinVectorFloat32)(pr)
if p.isNull {
chk.AppendNull(e.ordinal)
return nil
}
chk.AppendVectorFloat32(e.ordinal, p.val)
return nil
}

func (e *maxMin4VectorFloat32) UpdatePartialResult(sctx AggFuncUpdateContext, rowsInGroup []chunk.Row, pr PartialResult) (memDelta int64, err error) {
p := (*partialResult4MaxMinVectorFloat32)(pr)
for _, row := range rowsInGroup {
input, isNull, err := e.args[0].EvalVectorFloat32(sctx, row)
if err != nil {
return memDelta, err
}
if isNull {
continue
}
if p.isNull {
p.val = input.Clone()
memDelta += int64(input.EstimatedMemUsage())
p.isNull = false
continue
}
cmp := input.Compare(p.val)
if e.isMax && cmp > 0 || !e.isMax && cmp < 0 {
oldMem := p.val.EstimatedMemUsage()
newMem := input.EstimatedMemUsage()
memDelta += int64(newMem - oldMem)
p.val = input.Clone()
}
}
return memDelta, nil
}

func (e *maxMin4VectorFloat32) MergePartialResult(_ AggFuncUpdateContext, src, dst PartialResult) (memDelta int64, err error) {
p1, p2 := (*partialResult4MaxMinVectorFloat32)(src), (*partialResult4MaxMinVectorFloat32)(dst)
if p1.isNull {
return 0, nil
}
if p2.isNull {
*p2 = *p1
return 0, nil
}
cmp := p1.val.Compare(p2.val)
if e.isMax && cmp > 0 || !e.isMax && cmp < 0 {
p2.val = p1.val
p2.isNull = false
}
return 0, nil
}

type maxMin4Enum struct {
baseMaxMinAggFunc
}
Expand Down
Loading