From 15d2f1d41cb192aeb5eaa2d5a97cbc4451e0a7d0 Mon Sep 17 00:00:00 2001 From: Cai Zhang Date: Mon, 28 Oct 2024 20:27:50 +0800 Subject: [PATCH] update expression Signed-off-by: Cai Zhang --- .../planparserv2/fill_expression_value.go | 56 ++--- .../fill_expression_value_test.go | 194 ++++++++---------- .../parser/planparserv2/parser_visitor.go | 2 + .../planparserv2/plan_parser_v2_test.go | 8 +- internal/parser/planparserv2/show_visitor.go | 3 + internal/parser/planparserv2/utils.go | 1 + 6 files changed, 118 insertions(+), 146 deletions(-) diff --git a/internal/parser/planparserv2/fill_expression_value.go b/internal/parser/planparserv2/fill_expression_value.go index 534bbc564772a..8840ac75b08bf 100644 --- a/internal/parser/planparserv2/fill_expression_value.go +++ b/internal/parser/planparserv2/fill_expression_value.go @@ -143,39 +143,43 @@ func FillBinaryArithOpEvalRangeExpressionValue(expr *planpb.BinaryArithOpEvalRan var err error var ok bool - operand := expr.GetRightOperand() - if operand == nil || expr.GetOperandTemplateVariableName() != "" { - operand, ok = templateValues[expr.GetOperandTemplateVariableName()] - if !ok { - return fmt.Errorf("the right operand value of expression template variable name {%s} is not found", expr.GetOperandTemplateVariableName()) + if expr.ArithOp == planpb.ArithOpType_ArrayLength { + dataType = schemapb.DataType_Int64 + } else { + operand := expr.GetRightOperand() + if operand == nil || expr.GetOperandTemplateVariableName() != "" { + operand, ok = templateValues[expr.GetOperandTemplateVariableName()] + if !ok { + return fmt.Errorf("the right operand value of expression template variable name {%s} is not found", expr.GetOperandTemplateVariableName()) + } } - } - operandExpr := toValueExpr(operand) - lDataType, rDataType := expr.GetColumnInfo().GetDataType(), operandExpr.dataType - if typeutil.IsArrayType(expr.GetColumnInfo().GetDataType()) { - lDataType = expr.GetColumnInfo().GetElementType() - } + operandExpr := toValueExpr(operand) + lDataType, rDataType := expr.GetColumnInfo().GetDataType(), operandExpr.dataType + if typeutil.IsArrayType(expr.GetColumnInfo().GetDataType()) { + lDataType = expr.GetColumnInfo().GetElementType() + } - if err = checkValidModArith(expr.GetArithOp(), expr.GetColumnInfo().GetDataType(), expr.GetColumnInfo().GetElementType(), - rDataType, schemapb.DataType_None); err != nil { - return err - } + if err = checkValidModArith(expr.GetArithOp(), expr.GetColumnInfo().GetDataType(), expr.GetColumnInfo().GetElementType(), + rDataType, schemapb.DataType_None); err != nil { + return err + } - if operand.GetArrayVal() != nil { - return fmt.Errorf("can not comparisons array directly") - } + if operand.GetArrayVal() != nil { + return fmt.Errorf("can not comparisons array directly") + } - dataType, err = getTargetType(lDataType, rDataType) - if err != nil { - return err - } + dataType, err = getTargetType(lDataType, rDataType) + if err != nil { + return err + } - castedOperand, err := castValue(dataType, operand) - if err != nil { - return err + castedOperand, err := castValue(dataType, operand) + if err != nil { + return err + } + expr.RightOperand = castedOperand } - expr.RightOperand = castedOperand value := expr.GetValue() if expr.GetValue() == nil || expr.GetValueTemplateVariableName() != "" { diff --git a/internal/parser/planparserv2/fill_expression_value_test.go b/internal/parser/planparserv2/fill_expression_value_test.go index b2adb65874035..c9aca25ca870b 100644 --- a/internal/parser/planparserv2/fill_expression_value_test.go +++ b/internal/parser/planparserv2/fill_expression_value_test.go @@ -1,14 +1,13 @@ package planparserv2 import ( - "encoding/json" - "fmt" "testing" + "github.com/milvus-io/milvus/pkg/util/typeutil" + "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/proto/planpb" ) type FillExpressionValueSuite struct { @@ -24,10 +23,17 @@ type testcase struct { values map[string]*schemapb.TemplateValue } -func (s *FillExpressionValueSuite) jsonMarshal(v interface{}) []byte { - r, err := json.Marshal(v) - s.NoError(err) - return r +func (s *FillExpressionValueSuite) assertValidExpr(helper *typeutil.SchemaHelper, exprStr string, templateValues map[string]*schemapb.TemplateValue) { + expr, err := ParseExpr(helper, exprStr, templateValues) + s.NoError(err, exprStr) + s.NotNil(expr, exprStr) + ShowExpr(expr) +} + +func (s *FillExpressionValueSuite) assertInvalidExpr(helper *typeutil.SchemaHelper, exprStr string, templateValues map[string]*schemapb.TemplateValue) { + expr, err := ParseExpr(helper, exprStr, templateValues) + s.Error(err, exprStr) + s.Nil(expr, exprStr) } func (s *FillExpressionValueSuite) TestTermExpr() { @@ -88,17 +94,8 @@ func (s *FillExpressionValueSuite) TestTermExpr() { }}, } schemaH := newTestSchemaHelper(s.T()) - for _, c := range testcases { - plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }, c.values) - - s.NoError(err) - s.NotNil(plan) + s.assertValidExpr(schemaH, c.expr, c.values) } }) @@ -130,17 +127,8 @@ func (s *FillExpressionValueSuite) TestTermExpr() { }}, } schemaH := newTestSchemaHelper(s.T()) - for _, c := range testcases { - plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }, c.values) - s.Error(err) - s.Nil(plan) - fmt.Println(plan) + s.assertInvalidExpr(schemaH, c.expr, c.values) } }) } @@ -172,19 +160,8 @@ func (s *FillExpressionValueSuite) TestUnaryRange() { }}, } schemaH := newTestSchemaHelper(s.T()) - for _, c := range testcases { - plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }, c.values) - - s.NoError(err) - s.NotNil(plan) - s.NotNil(plan.GetVectorAnns()) - s.NotNil(plan.GetVectorAnns().GetPredicates()) + s.assertValidExpr(schemaH, c.expr, c.values) } }) @@ -214,17 +191,8 @@ func (s *FillExpressionValueSuite) TestUnaryRange() { }}, } schemaH := newTestSchemaHelper(s.T()) - for _, c := range testcases { - plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }, c.values) - - s.Error(err) - s.Nil(plan) + s.assertInvalidExpr(schemaH, c.expr, c.values) } }) } @@ -264,19 +232,8 @@ func (s *FillExpressionValueSuite) TestBinaryRange() { } schemaH := newTestSchemaHelper(s.T()) - for _, c := range testcases { - plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }, c.values) - - s.NoError(err) - s.NotNil(plan) - s.NotNil(plan.GetVectorAnns()) - s.NotNil(plan.GetVectorAnns().GetPredicates()) + s.assertValidExpr(schemaH, c.expr, c.values) } }) @@ -311,17 +268,8 @@ func (s *FillExpressionValueSuite) TestBinaryRange() { } schemaH := newTestSchemaHelper(s.T()) - for _, c := range testcases { - plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }, c.values) - - s.Error(err) - s.Nil(plan) + s.assertInvalidExpr(schemaH, c.expr, c.values) } }) } @@ -343,22 +291,17 @@ func (s *FillExpressionValueSuite) TestBinaryArithOpEvalRange() { {`ArrayField[0] % {offset} < 11`, map[string]*schemapb.TemplateValue{ "offset": generateExpressionFieldData(schemapb.DataType_Int64, int64(3)), }}, + {`array_length(ArrayField) == {length}`, map[string]*schemapb.TemplateValue{ + "length": generateExpressionFieldData(schemapb.DataType_Int64, int64(3)), + }}, + {`array_length(ArrayField) > {length}`, map[string]*schemapb.TemplateValue{ + "length": generateExpressionFieldData(schemapb.DataType_Int64, int64(3)), + }}, } schemaH := newTestSchemaHelper(s.T()) - for _, c := range testcases { - plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }, c.values) - - s.NoError(err) - s.NotNil(plan) - s.NotNil(plan.GetVectorAnns()) - s.NotNil(plan.GetVectorAnns().GetPredicates()) + s.assertValidExpr(schemaH, c.expr, c.values) } }) @@ -401,20 +344,14 @@ func (s *FillExpressionValueSuite) TestBinaryArithOpEvalRange() { }), "target": generateExpressionFieldData(schemapb.DataType_Int64, int64(5)), }}, + {`array_length(ArrayField) == {length}`, map[string]*schemapb.TemplateValue{ + "length": generateExpressionFieldData(schemapb.DataType_String, "abc"), + }}, } schemaH := newTestSchemaHelper(s.T()) - for _, c := range testcases { - plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }, c.values) - - s.Error(err) - s.Nil(plan) + s.assertInvalidExpr(schemaH, c.expr, c.values) } }) } @@ -494,17 +431,7 @@ func (s *FillExpressionValueSuite) TestJSONContainsExpression() { schemaH := newTestSchemaHelper(s.T()) for _, c := range testcases { - plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }, c.values) - - s.NoError(err) - s.NotNil(plan) - s.NotNil(plan.GetVectorAnns()) - s.NotNil(plan.GetVectorAnns().GetPredicates()) + s.assertValidExpr(schemaH, c.expr, c.values) } }) @@ -554,15 +481,56 @@ func (s *FillExpressionValueSuite) TestJSONContainsExpression() { schemaH := newTestSchemaHelper(s.T()) for _, c := range testcases { - plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{ - Topk: 0, - MetricType: "", - SearchParams: "", - RoundDecimal: 0, - }, c.values) - - s.Error(err) - s.Nil(plan) + s.assertInvalidExpr(schemaH, c.expr, c.values) + } + }) +} + +func (s *FillExpressionValueSuite) TestBinaryExpression() { + s.Run("normal case", func() { + testcases := []testcase{ + {`Int64Field > {int} && StringField in {list}`, map[string]*schemapb.TemplateValue{ + "int": generateExpressionFieldData(schemapb.DataType_Int64, int64(10)), + "list": generateExpressionFieldData(schemapb.DataType_Array, []interface{}{ + generateExpressionFieldData(schemapb.DataType_VarChar, "abc"), + generateExpressionFieldData(schemapb.DataType_VarChar, "def"), + generateExpressionFieldData(schemapb.DataType_VarChar, "ghi"), + }), + }}, + {`{max} > FloatField >= {min} or BoolField == {bool}`, map[string]*schemapb.TemplateValue{ + "min": generateExpressionFieldData(schemapb.DataType_Int64, int64(10)), + "max": generateExpressionFieldData(schemapb.DataType_Float, 22.22), + "bool": generateExpressionFieldData(schemapb.DataType_Bool, true), + }}, + } + + schemaH := newTestSchemaHelper(s.T()) + + for _, c := range testcases { + s.assertValidExpr(schemaH, c.expr, c.values) + } + }) + + s.Run("failed case", func() { + testcases := []testcase{ + {`Int64Field > {int} && StringField in {list}`, map[string]*schemapb.TemplateValue{ + "int": generateExpressionFieldData(schemapb.DataType_String, "abc"), + "list": generateExpressionFieldData(schemapb.DataType_Array, []interface{}{ + generateExpressionFieldData(schemapb.DataType_VarChar, "abc"), + generateExpressionFieldData(schemapb.DataType_Int64, int64(10)), + generateExpressionFieldData(schemapb.DataType_VarChar, "ghi"), + }), + }}, + {`{max} > FloatField >= {min} or BoolField == {bool}`, map[string]*schemapb.TemplateValue{ + "min": generateExpressionFieldData(schemapb.DataType_Int64, int64(10)), + "bool": generateExpressionFieldData(schemapb.DataType_Bool, true), + }}, + } + + schemaH := newTestSchemaHelper(s.T()) + + for _, c := range testcases { + s.assertInvalidExpr(schemaH, c.expr, c.values) } }) } diff --git a/internal/parser/planparserv2/parser_visitor.go b/internal/parser/planparserv2/parser_visitor.go index 058f607271a31..24f437feb7d52 100644 --- a/internal/parser/planparserv2/parser_visitor.go +++ b/internal/parser/planparserv2/parser_visitor.go @@ -878,6 +878,7 @@ func (v *ParserVisitor) VisitLogicalOr(ctx *parser.LogicalOrContext) interface{} Op: planpb.BinaryExpr_LogicalOr, }, }, + IsTemplate: leftExpr.expr.GetIsTemplate() || rightExpr.expr.GetIsTemplate(), } return &ExprWithType{ @@ -926,6 +927,7 @@ func (v *ParserVisitor) VisitLogicalAnd(ctx *parser.LogicalAndContext) interface Op: planpb.BinaryExpr_LogicalAnd, }, }, + IsTemplate: leftExpr.expr.GetIsTemplate() || rightExpr.expr.GetIsTemplate(), } return &ExprWithType{ diff --git a/internal/parser/planparserv2/plan_parser_v2_test.go b/internal/parser/planparserv2/plan_parser_v2_test.go index be62c5f1bf927..d3adb5b36577c 100644 --- a/internal/parser/planparserv2/plan_parser_v2_test.go +++ b/internal/parser/planparserv2/plan_parser_v2_test.go @@ -283,6 +283,7 @@ func TestExpr_BinaryArith(t *testing.T) { `Int64Field % 10 == 9`, `Int64Field % 10 != 9`, `FloatField + 1.1 == 2.1`, + `Int64Field + 1.1 == 2.1`, `A % 10 != 2`, `Int8Field + 1 < 2`, `Int16Field - 3 <= 4`, @@ -300,13 +301,6 @@ func TestExpr_BinaryArith(t *testing.T) { assertValidExpr(t, helper, exprStr) } - invalidExprs := []string{ - `Int64Field + 1.1 == 2.1`, - } - for _, exprStr := range invalidExprs { - assertInvalidExpr(t, helper, exprStr) - } - // TODO: enable these after execution backend is ready. unsupported := []string{ `ArrayField + 15 == 16`, diff --git a/internal/parser/planparserv2/show_visitor.go b/internal/parser/planparserv2/show_visitor.go index 1a06d93d5e62d..64cb854a2892a 100644 --- a/internal/parser/planparserv2/show_visitor.go +++ b/internal/parser/planparserv2/show_visitor.go @@ -21,6 +21,9 @@ func extractColumnInfo(info *planpb.ColumnInfo) interface{} { } func extractGenericValue(value *planpb.GenericValue) interface{} { + if value == nil { + return nil + } switch realValue := value.Val.(type) { case *planpb.GenericValue_BoolVal: return realValue.BoolVal diff --git a/internal/parser/planparserv2/utils.go b/internal/parser/planparserv2/utils.go index aff0f54ea25d3..e61bbd237c9bc 100644 --- a/internal/parser/planparserv2/utils.go +++ b/internal/parser/planparserv2/utils.go @@ -269,6 +269,7 @@ func combineArrayLengthExpr(op planpb.OpType, arithOp planpb.ArithOpType, column ValueTemplateVariableName: valueExpr.GetTemplateVariableName(), }, }, + IsTemplate: isTemplateExpr(valueExpr), }, nil }