Skip to content

Commit

Permalink
update expression
Browse files Browse the repository at this point in the history
Signed-off-by: Cai Zhang <[email protected]>
  • Loading branch information
xiaocai2333 committed Oct 29, 2024
1 parent 4ac379a commit 15d2f1d
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 146 deletions.
56 changes: 30 additions & 26 deletions internal/parser/planparserv2/fill_expression_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,39 +143,43 @@ func FillBinaryArithOpEvalRangeExpressionValue(expr *planpb.BinaryArithOpEvalRan
var err error
var ok bool

operand := expr.GetRightOperand()
if operand == nil || expr.GetOperandTemplateVariableName() != "" {
operand, ok = templateValues[expr.GetOperandTemplateVariableName()]
if !ok {
return fmt.Errorf("the right operand value of expression template variable name {%s} is not found", expr.GetOperandTemplateVariableName())
if expr.ArithOp == planpb.ArithOpType_ArrayLength {
dataType = schemapb.DataType_Int64
} else {
operand := expr.GetRightOperand()
if operand == nil || expr.GetOperandTemplateVariableName() != "" {
operand, ok = templateValues[expr.GetOperandTemplateVariableName()]
if !ok {
return fmt.Errorf("the right operand value of expression template variable name {%s} is not found", expr.GetOperandTemplateVariableName())
}
}
}

operandExpr := toValueExpr(operand)
lDataType, rDataType := expr.GetColumnInfo().GetDataType(), operandExpr.dataType
if typeutil.IsArrayType(expr.GetColumnInfo().GetDataType()) {
lDataType = expr.GetColumnInfo().GetElementType()
}
operandExpr := toValueExpr(operand)
lDataType, rDataType := expr.GetColumnInfo().GetDataType(), operandExpr.dataType
if typeutil.IsArrayType(expr.GetColumnInfo().GetDataType()) {
lDataType = expr.GetColumnInfo().GetElementType()
}

if err = checkValidModArith(expr.GetArithOp(), expr.GetColumnInfo().GetDataType(), expr.GetColumnInfo().GetElementType(),
rDataType, schemapb.DataType_None); err != nil {
return err
}
if err = checkValidModArith(expr.GetArithOp(), expr.GetColumnInfo().GetDataType(), expr.GetColumnInfo().GetElementType(),
rDataType, schemapb.DataType_None); err != nil {
return err
}

if operand.GetArrayVal() != nil {
return fmt.Errorf("can not comparisons array directly")
}
if operand.GetArrayVal() != nil {
return fmt.Errorf("can not comparisons array directly")
}

dataType, err = getTargetType(lDataType, rDataType)
if err != nil {
return err
}
dataType, err = getTargetType(lDataType, rDataType)
if err != nil {
return err
}

castedOperand, err := castValue(dataType, operand)
if err != nil {
return err
castedOperand, err := castValue(dataType, operand)
if err != nil {
return err
}

Check warning on line 180 in internal/parser/planparserv2/fill_expression_value.go

View check run for this annotation

Codecov / codecov/patch

internal/parser/planparserv2/fill_expression_value.go#L179-L180

Added lines #L179 - L180 were not covered by tests
expr.RightOperand = castedOperand
}
expr.RightOperand = castedOperand

value := expr.GetValue()
if expr.GetValue() == nil || expr.GetValueTemplateVariableName() != "" {
Expand Down
194 changes: 81 additions & 113 deletions internal/parser/planparserv2/fill_expression_value_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
package planparserv2

import (
"encoding/json"
"fmt"
"testing"

"github.com/milvus-io/milvus/pkg/util/typeutil"

"github.com/stretchr/testify/suite"

"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/planpb"
)

type FillExpressionValueSuite struct {
Expand All @@ -24,10 +23,17 @@ type testcase struct {
values map[string]*schemapb.TemplateValue
}

func (s *FillExpressionValueSuite) jsonMarshal(v interface{}) []byte {
r, err := json.Marshal(v)
s.NoError(err)
return r
func (s *FillExpressionValueSuite) assertValidExpr(helper *typeutil.SchemaHelper, exprStr string, templateValues map[string]*schemapb.TemplateValue) {
expr, err := ParseExpr(helper, exprStr, templateValues)
s.NoError(err, exprStr)
s.NotNil(expr, exprStr)
ShowExpr(expr)
}

func (s *FillExpressionValueSuite) assertInvalidExpr(helper *typeutil.SchemaHelper, exprStr string, templateValues map[string]*schemapb.TemplateValue) {
expr, err := ParseExpr(helper, exprStr, templateValues)
s.Error(err, exprStr)
s.Nil(expr, exprStr)
}

func (s *FillExpressionValueSuite) TestTermExpr() {
Expand Down Expand Up @@ -88,17 +94,8 @@ func (s *FillExpressionValueSuite) TestTermExpr() {
}},
}
schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
}, c.values)

s.NoError(err)
s.NotNil(plan)
s.assertValidExpr(schemaH, c.expr, c.values)
}
})

Expand Down Expand Up @@ -130,17 +127,8 @@ func (s *FillExpressionValueSuite) TestTermExpr() {
}},
}
schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
}, c.values)
s.Error(err)
s.Nil(plan)
fmt.Println(plan)
s.assertInvalidExpr(schemaH, c.expr, c.values)
}
})
}
Expand Down Expand Up @@ -172,19 +160,8 @@ func (s *FillExpressionValueSuite) TestUnaryRange() {
}},
}
schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
}, c.values)

s.NoError(err)
s.NotNil(plan)
s.NotNil(plan.GetVectorAnns())
s.NotNil(plan.GetVectorAnns().GetPredicates())
s.assertValidExpr(schemaH, c.expr, c.values)
}
})

Expand Down Expand Up @@ -214,17 +191,8 @@ func (s *FillExpressionValueSuite) TestUnaryRange() {
}},
}
schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
}, c.values)

s.Error(err)
s.Nil(plan)
s.assertInvalidExpr(schemaH, c.expr, c.values)
}
})
}
Expand Down Expand Up @@ -264,19 +232,8 @@ func (s *FillExpressionValueSuite) TestBinaryRange() {
}

schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
}, c.values)

s.NoError(err)
s.NotNil(plan)
s.NotNil(plan.GetVectorAnns())
s.NotNil(plan.GetVectorAnns().GetPredicates())
s.assertValidExpr(schemaH, c.expr, c.values)
}
})

Expand Down Expand Up @@ -311,17 +268,8 @@ func (s *FillExpressionValueSuite) TestBinaryRange() {
}

schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
}, c.values)

s.Error(err)
s.Nil(plan)
s.assertInvalidExpr(schemaH, c.expr, c.values)
}
})
}
Expand All @@ -343,22 +291,17 @@ func (s *FillExpressionValueSuite) TestBinaryArithOpEvalRange() {
{`ArrayField[0] % {offset} < 11`, map[string]*schemapb.TemplateValue{
"offset": generateExpressionFieldData(schemapb.DataType_Int64, int64(3)),
}},
{`array_length(ArrayField) == {length}`, map[string]*schemapb.TemplateValue{
"length": generateExpressionFieldData(schemapb.DataType_Int64, int64(3)),
}},
{`array_length(ArrayField) > {length}`, map[string]*schemapb.TemplateValue{
"length": generateExpressionFieldData(schemapb.DataType_Int64, int64(3)),
}},
}

schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
}, c.values)

s.NoError(err)
s.NotNil(plan)
s.NotNil(plan.GetVectorAnns())
s.NotNil(plan.GetVectorAnns().GetPredicates())
s.assertValidExpr(schemaH, c.expr, c.values)
}
})

Expand Down Expand Up @@ -401,20 +344,14 @@ func (s *FillExpressionValueSuite) TestBinaryArithOpEvalRange() {
}),
"target": generateExpressionFieldData(schemapb.DataType_Int64, int64(5)),
}},
{`array_length(ArrayField) == {length}`, map[string]*schemapb.TemplateValue{
"length": generateExpressionFieldData(schemapb.DataType_String, "abc"),
}},
}

schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
}, c.values)

s.Error(err)
s.Nil(plan)
s.assertInvalidExpr(schemaH, c.expr, c.values)
}
})
}
Expand Down Expand Up @@ -494,17 +431,7 @@ func (s *FillExpressionValueSuite) TestJSONContainsExpression() {
schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
}, c.values)

s.NoError(err)
s.NotNil(plan)
s.NotNil(plan.GetVectorAnns())
s.NotNil(plan.GetVectorAnns().GetPredicates())
s.assertValidExpr(schemaH, c.expr, c.values)
}
})

Expand Down Expand Up @@ -554,15 +481,56 @@ func (s *FillExpressionValueSuite) TestJSONContainsExpression() {
schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
plan, err := CreateSearchPlan(schemaH, c.expr, "FloatVectorField", &planpb.QueryInfo{
Topk: 0,
MetricType: "",
SearchParams: "",
RoundDecimal: 0,
}, c.values)

s.Error(err)
s.Nil(plan)
s.assertInvalidExpr(schemaH, c.expr, c.values)
}
})
}

func (s *FillExpressionValueSuite) TestBinaryExpression() {
s.Run("normal case", func() {
testcases := []testcase{
{`Int64Field > {int} && StringField in {list}`, map[string]*schemapb.TemplateValue{
"int": generateExpressionFieldData(schemapb.DataType_Int64, int64(10)),
"list": generateExpressionFieldData(schemapb.DataType_Array, []interface{}{
generateExpressionFieldData(schemapb.DataType_VarChar, "abc"),
generateExpressionFieldData(schemapb.DataType_VarChar, "def"),
generateExpressionFieldData(schemapb.DataType_VarChar, "ghi"),
}),
}},
{`{max} > FloatField >= {min} or BoolField == {bool}`, map[string]*schemapb.TemplateValue{
"min": generateExpressionFieldData(schemapb.DataType_Int64, int64(10)),
"max": generateExpressionFieldData(schemapb.DataType_Float, 22.22),
"bool": generateExpressionFieldData(schemapb.DataType_Bool, true),
}},
}

schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
s.assertValidExpr(schemaH, c.expr, c.values)
}
})

s.Run("failed case", func() {
testcases := []testcase{
{`Int64Field > {int} && StringField in {list}`, map[string]*schemapb.TemplateValue{
"int": generateExpressionFieldData(schemapb.DataType_String, "abc"),
"list": generateExpressionFieldData(schemapb.DataType_Array, []interface{}{
generateExpressionFieldData(schemapb.DataType_VarChar, "abc"),
generateExpressionFieldData(schemapb.DataType_Int64, int64(10)),
generateExpressionFieldData(schemapb.DataType_VarChar, "ghi"),
}),
}},
{`{max} > FloatField >= {min} or BoolField == {bool}`, map[string]*schemapb.TemplateValue{
"min": generateExpressionFieldData(schemapb.DataType_Int64, int64(10)),
"bool": generateExpressionFieldData(schemapb.DataType_Bool, true),
}},
}

schemaH := newTestSchemaHelper(s.T())

for _, c := range testcases {
s.assertInvalidExpr(schemaH, c.expr, c.values)
}
})
}
2 changes: 2 additions & 0 deletions internal/parser/planparserv2/parser_visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,7 @@ func (v *ParserVisitor) VisitLogicalOr(ctx *parser.LogicalOrContext) interface{}
Op: planpb.BinaryExpr_LogicalOr,
},
},
IsTemplate: leftExpr.expr.GetIsTemplate() || rightExpr.expr.GetIsTemplate(),
}

return &ExprWithType{
Expand Down Expand Up @@ -926,6 +927,7 @@ func (v *ParserVisitor) VisitLogicalAnd(ctx *parser.LogicalAndContext) interface
Op: planpb.BinaryExpr_LogicalAnd,
},
},
IsTemplate: leftExpr.expr.GetIsTemplate() || rightExpr.expr.GetIsTemplate(),
}

return &ExprWithType{
Expand Down
Loading

0 comments on commit 15d2f1d

Please sign in to comment.