Skip to content

Commit

Permalink
expression: Add vector functions (#55021)
Browse files Browse the repository at this point in the history
ref #54245
  • Loading branch information
EricZequan committed Aug 12, 2024
1 parent 7fff125 commit 2ba29a2
Show file tree
Hide file tree
Showing 16 changed files with 1,073 additions and 14 deletions.
19 changes: 19 additions & 0 deletions pkg/expression/aggregation/aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ func CheckAggPushDown(ctx expression.EvalContext, aggFunc *AggFuncDesc, storeTyp
if aggFunc.Name == ast.AggFuncApproxPercentile {
return false
}
if !checkVectorAggPushDown(ctx, aggFunc) {
return false
}
ret := true
switch storeType {
case kv.TiFlash:
Expand All @@ -253,6 +256,22 @@ func CheckAggPushDown(ctx expression.EvalContext, aggFunc *AggFuncDesc, storeTyp
return ret
}

// checkVectorAggPushDown reports whether aggFunc may be pushed down as far as
// Vector arguments are concerned.
//   - Count, Min, Max and FirstRow are always accepted (returns true).
//   - Any other function is rejected (returns false) when its first argument
//     has type mysql.TypeTiDBVectorFloat32.
//   - In every remaining case, including a descriptor that carries no
//     arguments, it returns true.
//
// NOTE(review): only Args[0] is inspected; a Vector value passed in a later
// argument position is not detected here — confirm this is intended.
func checkVectorAggPushDown(ctx expression.EvalContext, aggFunc *AggFuncDesc) bool {
	switch aggFunc.Name {
	case ast.AggFuncCount, ast.AggFuncMin, ast.AggFuncMax, ast.AggFuncFirstRow:
		return true
	}
	// Without this guard, indexing Args[0] on an argument-less descriptor
	// would panic with an index-out-of-range error.
	if len(aggFunc.Args) == 0 {
		return true
	}
	return aggFunc.Args[0].GetType(ctx).GetType() != mysql.TypeTiDBVectorFloat32
}

// CheckAggPushFlash checks whether an agg function can be pushed to flash storage.
func CheckAggPushFlash(ctx expression.EvalContext, aggFunc *AggFuncDesc) bool {
for _, arg := range aggFunc.Args {
Expand Down
11 changes: 8 additions & 3 deletions pkg/expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -929,9 +929,14 @@ var funcs = map[string]functionClass{
ast.JSONLength: &jsonLengthFunctionClass{baseFunctionClass{ast.JSONLength, 1, 2}},

// vector functions (TiDB extension)
ast.VecDims: &vecDimsFunctionClass{baseFunctionClass{ast.VecDims, 1, 1}},
ast.VecFromText: &vecFromTextFunctionClass{baseFunctionClass{ast.VecFromText, 1, 1}},
ast.VecAsText: &vecAsTextFunctionClass{baseFunctionClass{ast.VecAsText, 1, 1}},
ast.VecDims: &vecDimsFunctionClass{baseFunctionClass{ast.VecDims, 1, 1}},
ast.VecL1Distance: &vecL1DistanceFunctionClass{baseFunctionClass{ast.VecL1Distance, 2, 2}},
ast.VecL2Distance: &vecL2DistanceFunctionClass{baseFunctionClass{ast.VecL2Distance, 2, 2}},
ast.VecNegativeInnerProduct: &vecNegativeInnerProductFunctionClass{baseFunctionClass{ast.VecNegativeInnerProduct, 2, 2}},
ast.VecCosineDistance: &vecCosineDistanceFunctionClass{baseFunctionClass{ast.VecCosineDistance, 2, 2}},
ast.VecL2Norm: &vecL2NormFunctionClass{baseFunctionClass{ast.VecL2Norm, 1, 1}},
ast.VecFromText: &vecFromTextFunctionClass{baseFunctionClass{ast.VecFromText, 1, 1}},
ast.VecAsText: &vecAsTextFunctionClass{baseFunctionClass{ast.VecAsText, 1, 1}},

// TiDB internal function.
ast.TiDBDecodeKey: &tidbDecodeKeyFunctionClass{baseFunctionClass{ast.TiDBDecodeKey, 1, 1}},
Expand Down
118 changes: 118 additions & 0 deletions pkg/expression/builtin_arithmetic.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ var (
_ builtinFunc = &builtinArithmeticModIntSignedSignedSig{}
_ builtinFunc = &builtinArithmeticModRealSig{}
_ builtinFunc = &builtinArithmeticModDecimalSig{}

_ builtinFunc = &builtinArithmeticPlusVectorFloat32Sig{}
_ builtinFunc = &builtinArithmeticMinusVectorFloat32Sig{}
_ builtinFunc = &builtinArithmeticMultiplyVectorFloat32Sig{}
)

// isConstantBinaryLiteral return true if expr is constant binary literal
Expand Down Expand Up @@ -167,6 +171,15 @@ func (c *arithmeticPlusFunctionClass) getFunction(ctx BuildContext, args []Expre
if err := c.verifyArgs(args); err != nil {
return nil, err
}
if args[0].GetType(ctx.GetEvalCtx()).EvalType().IsVectorKind() || args[1].GetType(ctx.GetEvalCtx()).EvalType().IsVectorKind() {
bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETVectorFloat32, types.ETVectorFloat32, types.ETVectorFloat32)
if err != nil {
return nil, err
}
sig := &builtinArithmeticPlusVectorFloat32Sig{bf}
// sig.setPbCode(tipb.ScalarFuncSig_PlusVectorFloat32)
return sig, nil
}
lhsEvalTp, rhsEvalTp := numericContextResultType(ctx.GetEvalCtx(), args[0]), numericContextResultType(ctx.GetEvalCtx(), args[1])
if lhsEvalTp == types.ETReal || rhsEvalTp == types.ETReal {
bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, types.ETReal, types.ETReal)
Expand Down Expand Up @@ -317,6 +330,15 @@ func (c *arithmeticMinusFunctionClass) getFunction(ctx BuildContext, args []Expr
if err := c.verifyArgs(args); err != nil {
return nil, err
}
if args[0].GetType(ctx.GetEvalCtx()).EvalType().IsVectorKind() || args[1].GetType(ctx.GetEvalCtx()).EvalType().IsVectorKind() {
bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETVectorFloat32, types.ETVectorFloat32, types.ETVectorFloat32)
if err != nil {
return nil, err
}
sig := &builtinArithmeticMinusVectorFloat32Sig{bf}
// sig.setPbCode(tipb.ScalarFuncSig_PlusVectorFloat32)
return sig, nil
}
lhsEvalTp, rhsEvalTp := numericContextResultType(ctx.GetEvalCtx(), args[0]), numericContextResultType(ctx.GetEvalCtx(), args[1])
if lhsEvalTp == types.ETReal || rhsEvalTp == types.ETReal {
bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, types.ETReal, types.ETReal)
Expand Down Expand Up @@ -500,6 +522,15 @@ func (c *arithmeticMultiplyFunctionClass) getFunction(ctx BuildContext, args []E
if err := c.verifyArgs(args); err != nil {
return nil, err
}
if args[0].GetType(ctx.GetEvalCtx()).EvalType().IsVectorKind() || args[1].GetType(ctx.GetEvalCtx()).EvalType().IsVectorKind() {
bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETVectorFloat32, types.ETVectorFloat32, types.ETVectorFloat32)
if err != nil {
return nil, err
}
sig := &builtinArithmeticMultiplyVectorFloat32Sig{bf}
// sig.setPbCode(tipb.ScalarFuncSig_PlusVectorFloat32)
return sig, nil
}
lhsTp, rhsTp := args[0].GetType(ctx.GetEvalCtx()), args[1].GetType(ctx.GetEvalCtx())
lhsEvalTp, rhsEvalTp := numericContextResultType(ctx.GetEvalCtx(), args[0]), numericContextResultType(ctx.GetEvalCtx(), args[1])
if lhsEvalTp == types.ETReal || rhsEvalTp == types.ETReal {
Expand Down Expand Up @@ -1157,3 +1188,90 @@ func (s *builtinArithmeticModIntSignedSignedSig) evalInt(ctx EvalContext, row ch

return a % b, false, nil
}

// builtinArithmeticPlusVectorFloat32Sig is the `+` signature whose operands
// and result are all evaluated as VectorFloat32.
type builtinArithmeticPlusVectorFloat32Sig struct {
	baseBuiltinFunc
}

// Clone returns a fresh signature whose embedded base state is copied from s
// via cloneFrom.
func (s *builtinArithmeticPlusVectorFloat32Sig) Clone() builtinFunc {
	cloned := new(builtinArithmeticPlusVectorFloat32Sig)
	cloned.cloneFrom(&s.baseBuiltinFunc)
	return cloned
}

// evalVectorFloat32 evaluates both operands as VectorFloat32 and returns the
// result of lhs.Add(rhs).
//
// Both operands are always evaluated, left first. An evaluation error is
// returned together with that operand's null flag. If either operand is NULL
// the result is NULL. An error reported by Add yields a NULL result plus the
// error.
func (s *builtinArithmeticPlusVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) {
	lhs, lhsNull, err := s.args[0].EvalVectorFloat32(ctx, row)
	if err != nil {
		return types.ZeroVectorFloat32, lhsNull, err
	}
	rhs, rhsNull, err := s.args[1].EvalVectorFloat32(ctx, row)
	if err != nil {
		return types.ZeroVectorFloat32, rhsNull, err
	}
	if lhsNull || rhsNull {
		return types.ZeroVectorFloat32, true, nil
	}

	sum, err := lhs.Add(rhs)
	if err == nil {
		return sum, false, nil
	}
	return types.ZeroVectorFloat32, true, err
}

// builtinArithmeticMinusVectorFloat32Sig is the `-` signature whose operands
// and result are all evaluated as VectorFloat32.
type builtinArithmeticMinusVectorFloat32Sig struct {
	baseBuiltinFunc
}

// Clone returns a fresh signature whose embedded base state is copied from s
// via cloneFrom.
func (s *builtinArithmeticMinusVectorFloat32Sig) Clone() builtinFunc {
	cloned := new(builtinArithmeticMinusVectorFloat32Sig)
	cloned.cloneFrom(&s.baseBuiltinFunc)
	return cloned
}

// evalVectorFloat32 evaluates both operands as VectorFloat32 and returns the
// result of lhs.Sub(rhs).
//
// Both operands are always evaluated, left first. An evaluation error is
// returned together with that operand's null flag. If either operand is NULL
// the result is NULL. An error reported by Sub yields a NULL result plus the
// error.
func (s *builtinArithmeticMinusVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) {
	lhs, lhsNull, err := s.args[0].EvalVectorFloat32(ctx, row)
	if err != nil {
		return types.ZeroVectorFloat32, lhsNull, err
	}
	rhs, rhsNull, err := s.args[1].EvalVectorFloat32(ctx, row)
	if err != nil {
		return types.ZeroVectorFloat32, rhsNull, err
	}
	if lhsNull || rhsNull {
		return types.ZeroVectorFloat32, true, nil
	}

	diff, err := lhs.Sub(rhs)
	if err == nil {
		return diff, false, nil
	}
	return types.ZeroVectorFloat32, true, err
}

// builtinArithmeticMultiplyVectorFloat32Sig is the `*` signature whose
// operands and result are all evaluated as VectorFloat32.
type builtinArithmeticMultiplyVectorFloat32Sig struct {
	baseBuiltinFunc
}

// Clone returns a fresh signature whose embedded base state is copied from s
// via cloneFrom.
func (s *builtinArithmeticMultiplyVectorFloat32Sig) Clone() builtinFunc {
	cloned := new(builtinArithmeticMultiplyVectorFloat32Sig)
	cloned.cloneFrom(&s.baseBuiltinFunc)
	return cloned
}

// evalVectorFloat32 evaluates both operands as VectorFloat32 and returns the
// result of lhs.Mul(rhs).
//
// Both operands are always evaluated, left first. An evaluation error is
// returned together with that operand's null flag. If either operand is NULL
// the result is NULL. An error reported by Mul yields a NULL result plus the
// error.
func (s *builtinArithmeticMultiplyVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (types.VectorFloat32, bool, error) {
	lhs, lhsNull, err := s.args[0].EvalVectorFloat32(ctx, row)
	if err != nil {
		return types.ZeroVectorFloat32, lhsNull, err
	}
	rhs, rhsNull, err := s.args[1].EvalVectorFloat32(ctx, row)
	if err != nil {
		return types.ZeroVectorFloat32, rhsNull, err
	}
	if lhsNull || rhsNull {
		return types.ZeroVectorFloat32, true, nil
	}

	product, err := lhs.Mul(rhs)
	if err == nil {
		return product, false, nil
	}
	return types.ZeroVectorFloat32, true, err
}
2 changes: 1 addition & 1 deletion pkg/expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ func (c *castAsVectorFloat32FunctionClass) getFunction(ctx BuildContext, args []
sig.setPbCode(tipb.ScalarFuncSig_CastVectorFloat32AsVectorFloat32)
case types.ETString:
sig = &builtinCastStringAsVectorFloat32Sig{bf}
sig.setPbCode(tipb.ScalarFuncSig_CastStringAsVectorFloat32)
// sig.setPbCode(tipb.ScalarFuncSig_CastStringAsVectorFloat32)
default:
return nil, errors.Errorf("cannot cast from %s to %s", argTp, "VectorFloat32")
}
Expand Down
80 changes: 80 additions & 0 deletions pkg/expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ var (
_ builtinFunc = &builtinCoalesceStringSig{}
_ builtinFunc = &builtinCoalesceTimeSig{}
_ builtinFunc = &builtinCoalesceDurationSig{}
_ builtinFunc = &builtinCoalesceVectorFloat32Sig{}

_ builtinFunc = &builtinGreatestIntSig{}
_ builtinFunc = &builtinGreatestRealSig{}
Expand All @@ -54,13 +55,15 @@ var (
_ builtinFunc = &builtinGreatestDurationSig{}
_ builtinFunc = &builtinGreatestTimeSig{}
_ builtinFunc = &builtinGreatestCmpStringAsTimeSig{}
_ builtinFunc = &builtinGreatestVectorFloat32Sig{}
_ builtinFunc = &builtinLeastIntSig{}
_ builtinFunc = &builtinLeastRealSig{}
_ builtinFunc = &builtinLeastDecimalSig{}
_ builtinFunc = &builtinLeastStringSig{}
_ builtinFunc = &builtinLeastTimeSig{}
_ builtinFunc = &builtinLeastDurationSig{}
_ builtinFunc = &builtinLeastCmpStringAsTimeSig{}
_ builtinFunc = &builtinLeastVectorFloat32Sig{}
_ builtinFunc = &builtinIntervalIntSig{}
_ builtinFunc = &builtinIntervalRealSig{}

Expand Down Expand Up @@ -167,6 +170,9 @@ func (c *coalesceFunctionClass) getFunction(ctx BuildContext, args []Expression)
case types.ETJson:
sig = &builtinCoalesceJSONSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_CoalesceJson)
case types.ETVectorFloat32:
sig = &builtinCoalesceVectorFloat32Sig{bf}
// sig.setPbCode(tipb.ScalarFuncSig_CoalesceVectorFloat32)
default:
return nil, errors.Errorf("%s is not supported for COALESCE()", retEvalTp)
}
Expand Down Expand Up @@ -331,6 +337,28 @@ func (b *builtinCoalesceJSONSig) evalJSON(ctx EvalContext, row chunk.Row) (res t
return res, isNull, err
}

// builtinCoalesceVectorFloat32Sig is the COALESCE signature that returns a
// VectorFloat32 value.
// See http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_coalesce
type builtinCoalesceVectorFloat32Sig struct {
	baseBuiltinFunc
}

// Clone returns a fresh signature whose embedded base state is copied from b
// via cloneFrom.
func (b *builtinCoalesceVectorFloat32Sig) Clone() builtinFunc {
	cloned := new(builtinCoalesceVectorFloat32Sig)
	cloned.cloneFrom(&b.baseBuiltinFunc)
	return cloned
}

// evalVectorFloat32 evaluates the arguments in order and returns the first
// one that is non-NULL or that fails to evaluate. When every argument is NULL
// the outcome of the last evaluation is returned; with no arguments at all the
// zero values of the named results are returned.
func (b *builtinCoalesceVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (res types.VectorFloat32, isNull bool, err error) {
	for _, arg := range b.getArgs() {
		res, isNull, err = arg.EvalVectorFloat32(ctx, row)
		if err == nil && isNull {
			// NULL without error: keep looking at the next argument.
			continue
		}
		return res, isNull, err
	}
	return res, isNull, err
}

func aggregateType(ctx EvalContext, args []Expression) *types.FieldType {
fieldTypes := make([]*types.FieldType, len(args))
for i := range fieldTypes {
Expand Down Expand Up @@ -499,6 +527,9 @@ func (c *greatestFunctionClass) getFunction(ctx BuildContext, args []Expression)
sig = &builtinGreatestTimeSig{bf, false}
sig.setPbCode(tipb.ScalarFuncSig_GreatestTime)
}
case types.ETVectorFloat32:
sig = &builtinGreatestVectorFloat32Sig{bf}
// sig.setPbCode(tipb.ScalarFuncSig_GreatestVectorFloat32)
default:
return nil, errors.Errorf("unsupported type %s during evaluation", argTp)
}
Expand Down Expand Up @@ -754,6 +785,29 @@ func (b *builtinGreatestDurationSig) evalDuration(ctx EvalContext, row chunk.Row
return res, false, nil
}

// builtinGreatestVectorFloat32Sig is the GREATEST signature used when the
// arguments are evaluated as VectorFloat32.
type builtinGreatestVectorFloat32Sig struct {
	baseBuiltinFunc
}

// Clone returns a fresh signature whose embedded base state is copied from b
// via cloneFrom.
func (b *builtinGreatestVectorFloat32Sig) Clone() builtinFunc {
	cloned := new(builtinGreatestVectorFloat32Sig)
	cloned.cloneFrom(&b.baseBuiltinFunc)
	return cloned
}

// evalVectorFloat32 evaluates every argument as VectorFloat32 and returns the
// one that VectorFloat32.Compare ranks highest; on ties the earlier argument
// is kept.
//
// Evaluation stops at the first argument that is NULL or fails to evaluate;
// in that case an empty VectorFloat32 is returned with isNull set to true,
// together with the evaluation error (nil for a plain NULL).
func (b *builtinGreatestVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (res types.VectorFloat32, isNull bool, err error) {
	for i, arg := range b.args {
		// Distinct local names avoid shadowing the named results isNull/err.
		v, argIsNull, evalErr := arg.EvalVectorFloat32(ctx, row)
		if argIsNull || evalErr != nil {
			return types.VectorFloat32{}, true, evalErr
		}
		if i == 0 || v.Compare(res) > 0 {
			res = v
		}
	}
	return res, false, nil
}

type leastFunctionClass struct {
baseFunctionClass
}
Expand Down Expand Up @@ -814,6 +868,9 @@ func (c *leastFunctionClass) getFunction(ctx BuildContext, args []Expression) (s
sig = &builtinLeastTimeSig{bf, false}
sig.setPbCode(tipb.ScalarFuncSig_LeastTime)
}
case types.ETVectorFloat32:
sig = &builtinLeastVectorFloat32Sig{bf}
// sig.setPbCode(tipb.ScalarFuncSig_LeastVectorFloat32)
default:
return nil, errors.Errorf("unsupported type %s during evaluation", argTp)
}
Expand Down Expand Up @@ -1039,6 +1096,29 @@ func (b *builtinLeastDurationSig) evalDuration(ctx EvalContext, row chunk.Row) (
return res, false, nil
}

// builtinLeastVectorFloat32Sig is the LEAST signature used when the arguments
// are evaluated as VectorFloat32.
type builtinLeastVectorFloat32Sig struct {
	baseBuiltinFunc
}

// Clone returns a fresh signature whose embedded base state is copied from b
// via cloneFrom.
func (b *builtinLeastVectorFloat32Sig) Clone() builtinFunc {
	cloned := new(builtinLeastVectorFloat32Sig)
	cloned.cloneFrom(&b.baseBuiltinFunc)
	return cloned
}

// evalVectorFloat32 evaluates every argument as VectorFloat32 and returns the
// one that VectorFloat32.Compare ranks lowest; on ties the earlier argument is
// kept.
//
// Evaluation stops at the first argument that is NULL or fails to evaluate;
// in that case an empty VectorFloat32 is returned with isNull set to true,
// together with the evaluation error (nil for a plain NULL).
func (b *builtinLeastVectorFloat32Sig) evalVectorFloat32(ctx EvalContext, row chunk.Row) (res types.VectorFloat32, isNull bool, err error) {
	for i, arg := range b.args {
		// Distinct local names avoid shadowing the named results isNull/err.
		v, argIsNull, evalErr := arg.EvalVectorFloat32(ctx, row)
		if argIsNull || evalErr != nil {
			return types.VectorFloat32{}, true, evalErr
		}
		if i == 0 || v.Compare(res) < 0 {
			res = v
		}
	}
	return res, false, nil
}

type intervalFunctionClass struct {
baseFunctionClass
}
Expand Down
Loading

0 comments on commit 2ba29a2

Please sign in to comment.