diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index 046c208f1..858a2b340 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -112,11 +112,12 @@ func TestCompile(t *testing.T) { { `-1`, vm.Program{ - Constants: []interface{}{-1}, + Constants: []interface{}{1}, Bytecode: []vm.Opcode{ vm.OpPush, + vm.OpNegate, }, - Arguments: []int{0}, + Arguments: []int{0, 0}, }, }, { @@ -232,7 +233,7 @@ func TestCompile(t *testing.T) { } for _, test := range tests { - program, err := expr.Compile(test.input, expr.Env(Env{})) + program, err := expr.Compile(test.input, expr.Env(Env{}), expr.Optimize(false)) require.NoError(t, err, test.input) assert.Equal(t, test.program.Disassemble(), program.Disassemble(), test.input) diff --git a/optimizer/fold.go b/optimizer/fold.go index d6706ee03..b62b2d7ed 100644 --- a/optimizer/fold.go +++ b/optimizer/fold.go @@ -42,6 +42,10 @@ func (fold *fold) Visit(node *Node) { if i, ok := n.Node.(*FloatNode); ok { patchWithType(&FloatNode{Value: i.Value}, n.Node.Type()) } + case "!", "not": + if a := toBool(n.Node); a != nil { + patch(&BoolNode{Value: !a.Value}) + } } case *BinaryNode: @@ -211,6 +215,50 @@ func (fold *fold) Visit(node *Node) { patchWithType(&FloatNode{Value: math.Pow(a.Value, b.Value)}, a.Type()) } } + case "and", "&&": + a := toBool(n.Left) + b := toBool(n.Right) + + if a != nil && a.Value { // true and x + patch(n.Right) + } else if b != nil && b.Value { // x and true + patch(n.Left) + } else if (a != nil && !a.Value) || (b != nil && !b.Value) { // "x and false" or "false and x" + patch(&BoolNode{Value: false}) + } + case "or", "||": + a := toBool(n.Left) + b := toBool(n.Right) + + if a != nil && !a.Value { // false or x + patch(n.Right) + } else if b != nil && !b.Value { // x or false + patch(n.Left) + } else if (a != nil && a.Value) || (b != nil && b.Value) { // "x or true" or "true or x" + patch(&BoolNode{Value: true}) + } + case "==": + { + a := toInteger(n.Left) + b := toInteger(n.Right) + if a != nil && b != nil { + patch(&BoolNode{Value: a.Value == b.Value}) + } + } + { + a := toString(n.Left) + b := toString(n.Right) + if a != nil && b != nil { + patch(&BoolNode{Value: a.Value == b.Value}) + } + } + { + a := toBool(n.Left) + b := toBool(n.Right) + if a != nil && b != nil { + patch(&BoolNode{Value: a.Value == b.Value}) + } + } } case *ArrayNode: @@ -285,3 +333,11 @@ func toFloat(n Node) *FloatNode { } return nil } + +func toBool(n Node) *BoolNode { + switch a := n.(type) { + case *BoolNode: + return a + } + return nil +} diff --git a/optimizer/optimizer_test.go b/optimizer/optimizer_test.go index eac183438..ba31af56a 100644 --- a/optimizer/optimizer_test.go +++ b/optimizer/optimizer_test.go @@ -40,6 +40,18 @@ func TestOptimize_constant_folding_with_floats(t *testing.T) { assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) } +func TestOptimize_constant_folding_with_bools(t *testing.T) { + tree, err := parser.Parse(`(true and false) or (true or false) or (false and false) or (true and (true == false))`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + expected := &ast.BoolNode{Value: true} + + assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) +} + func TestOptimize_in_array(t *testing.T) { config := conf.New(map[string]int{"v": 0})