From 4552f219fd98bdeba0930bf8631ec75de47d37ef Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Wed, 23 Aug 2023 16:47:05 -0700 Subject: [PATCH 1/4] Optimizer API with Constant Folding implementatiton --- cel/BUILD.bazel | 3 + cel/env.go | 8 + cel/folding.go | 450 ++++++++++++++++++++++++++++++++++++++++ cel/folding_test.go | 195 +++++++++++++++++ cel/optimizer.go | 391 ++++++++++++++++++++++++++++++++++ common/ast/ast.go | 13 +- common/ast/expr.go | 12 ++ common/ast/navigable.go | 4 + 8 files changed, 1073 insertions(+), 3 deletions(-) create mode 100644 cel/folding.go create mode 100644 cel/folding_test.go create mode 100644 cel/optimizer.go diff --git a/cel/BUILD.bazel b/cel/BUILD.bazel index aa978e06..62b903c8 100644 --- a/cel/BUILD.bazel +++ b/cel/BUILD.bazel @@ -10,9 +10,11 @@ go_library( "cel.go", "decls.go", "env.go", + "folding.go", "io.go", "library.go", "macro.go", + "optimizer.go", "options.go", "program.go", "validator.go", @@ -56,6 +58,7 @@ go_test( "cel_test.go", "decls_test.go", "env_test.go", + "folding_test.go", "io_test.go", "validator_test.go", ], diff --git a/cel/env.go b/cel/env.go index 113b89b5..786a13c4 100644 --- a/cel/env.go +++ b/cel/env.go @@ -43,6 +43,9 @@ type Ast struct { } // Expr returns the proto serializable instance of the parsed/checked expression. +// +// Deprecated: prefer cel.AstToCheckedExpr() or cel.AstToParsedExpr() and call GetExpr() +// the result instead. func (ast *Ast) Expr() *exprpb.Expr { if ast == nil { return nil @@ -221,6 +224,11 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) { source: ast.Source(), impl: checked} + // Avoid creating a validator config if it's not needed. + if len(e.validators) == 0 { + return ast, nil + } + // Generate a validator configuration from the set of configured validators. vConfig := newValidatorConfig() for _, v := range e.validators { diff --git a/cel/folding.go b/cel/folding.go new file mode 100644 index 00000000..fce15da5 --- /dev/null +++ b/cel/folding.go @@ -0,0 +1,450 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "fmt" + + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/operators" + "github.com/google/cel-go/common/overloads" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" +) + +// NewConstantFoldingOptimizer creates an optimizer which inlines constant scalar an aggregate +// literal values within function calls and select statements with their evaluated result. +func NewConstantFoldingOptimizer() ASTOptimizer { + return &constantFoldingOptimizer{} +} + +type constantFoldingOptimizer struct{} + +// Optimize queries the expression graph for scalar and aggregate literal expressions within call and +// select statements and then evaluates them and replaces the call site with the literal result. +// +// Note: only values which can be represented as literals in CEL syntax are supported. +func (*constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST { + root := ast.NavigateAST(a) + + // Walk the list of foldable expression and continue to fold until there are no more folds left. + // All of the fold candidates returned by the constantExprMatcher should succeed unless there's + // a logic bug with the selection of expressions. + foldableExprs := ast.MatchDescendants(root, constantExprMatcher) + for len(foldableExprs) != 0 { + for _, fold := range foldableExprs { + // If the expression could be folded because it's a non-strict call, and the + // branches are pruned, continue to the next fold. + if fold.Kind() == ast.CallKind && maybePruneBranches(fold) { + continue + } + // Otherwise, assume all context is needed to evaluate the expression. + err := tryFold(ctx, a, fold) + if err != nil { + ctx.ReportErrorAtID(fold.ID(), "constant-folding evaluation failed: %v", err.Error()) + return a + } + } + foldableExprs = ast.MatchDescendants(root, constantExprMatcher) + } + // Once all of the constants have been folded, try to run through the remaining comprehensions + // one last time. In this case, there's no guarantee they'll run, so we only update the + // target comprehension node with the literal value if the evaluation succeeds. + for _, compre := range ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind)) { + tryFold(ctx, a, compre) + } + + // If the output is a list, map, or struct which contains optional entries, then prune it + // to make sure that the optionals, if resolved, do not surface in the output literal. + pruneOptionalElements(ctx, root) + + // Ensure that all intermediate values in the folded expression can be represented as valid + // CEL literals within the AST structure. + ast.PostOrderVisit(root, ast.NewExprVisitor(func(e ast.Expr) { + if e.Kind() != ast.LiteralKind { + return + } + val := e.AsLiteral() + adapted, err := adaptLiteral(ctx, val) + if err != nil { + ctx.ReportErrorAtID(root.ID(), "constant-folding evaluation failed: %v", err.Error()) + return + } + e.SetKindCase(adapted) + })) + + return a +} + +// tryFold attempts to evaluate a sub-expression to a literal. +// +// If the evaluation succeeds, the input expr value will be modified to become a literal, otherwise +// the method will return an error. +func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error { + // Assume all context is needed to evaluate the expression. + subAST := &Ast{ + impl: ast.NewCheckedAST(ast.NewAST(expr, a.SourceInfo()), a.TypeMap(), a.ReferenceMap()), + } + prg, err := ctx.Program(subAST) + if err != nil { + return err + } + out, _, err := prg.Eval(NoVars()) + if err != nil { + return err + } + // Clear any macro metadata associated with the fold. + a.SourceInfo().ClearMacroCall(expr.ID()) + // Update the fold expression to be a literal. + expr.SetKindCase(ctx.NewLiteral(out)) + return nil +} + +// maybePruneBranches inspects the non-strict call expression to determine whether +// a branch can be removed. Evaluation will naturally prune logical and / or calls, +// but conditional will not be pruned cleanly, so this is one small area where the +// constant folding step reimplements a portion of the evaluator. +func maybePruneBranches(expr ast.NavigableExpr) bool { + call := expr.AsCall() + switch call.FunctionName() { + case operators.Conditional: + args := call.Args() + cond := args[0] + truthy := args[1] + falsy := args[2] + if cond.AsLiteral() == types.True { + expr.SetKindCase(truthy) + } else { + expr.SetKindCase(falsy) + } + return true + } + return false +} + +// pruneOptionalElements works from the bottom up to resolve optional elements within +// aggregate literals. +// +// Note, may aggregate literals will be resolved as arguments to functions or select +// statements, so this method exists to handle the case where the literal could not be +// fully resolved or exists outside of a call, select, or comprehension context. +func pruneOptionalElements(ctx *OptimizerContext, root ast.NavigableExpr) { + aggregateLiterals := ast.MatchDescendants(root, aggregateLiteralMatcher) + for _, lit := range aggregateLiterals { + switch lit.Kind() { + case ast.ListKind: + pruneOptionalListElements(ctx, lit) + case ast.MapKind: + pruneOptionalMapEntries(ctx, lit) + case ast.StructKind: + pruneOptionalStructFields(ctx, lit) + } + } +} + +func pruneOptionalListElements(ctx *OptimizerContext, e ast.Expr) { + l := e.AsList() + elems := l.Elements() + optIndices := l.OptionalIndices() + if len(optIndices) == 0 { + return + } + updatedElems := []ast.Expr{} + updatedIndices := []int32{} + for i, e := range elems { + if !l.IsOptional(int32(i)) { + updatedElems = append(updatedElems, e) + continue + } + if e.Kind() != ast.LiteralKind { + updatedElems = append(updatedElems, e) + updatedIndices = append(updatedIndices, int32(i)) + continue + } + optElemVal, ok := e.AsLiteral().(*types.Optional) + if !ok { + updatedElems = append(updatedElems, e) + updatedIndices = append(updatedIndices, int32(i)) + continue + } + if !optElemVal.HasValue() { + continue + } + e.SetKindCase(ctx.NewLiteral(optElemVal.GetValue())) + updatedElems = append(updatedElems, e) + } + e.SetKindCase(ctx.NewList(updatedElems, updatedIndices)) +} + +func pruneOptionalMapEntries(ctx *OptimizerContext, e ast.Expr) { + m := e.AsMap() + entries := m.Entries() + updatedEntries := []ast.EntryExpr{} + modified := false + for _, e := range entries { + entry := e.AsMapEntry() + key := entry.Key() + val := entry.Value() + if !entry.IsOptional() || val.Kind() != ast.LiteralKind { + updatedEntries = append(updatedEntries, e) + continue + } + optElemVal, ok := val.AsLiteral().(*types.Optional) + if !ok { + updatedEntries = append(updatedEntries, e) + continue + } + if key.Kind() != ast.LiteralKind { + undoOptVal, err := adaptLiteral(ctx, optElemVal) + if err != nil { + ctx.ReportErrorAtID(val.ID(), "invalid map value literal %v: %v", optElemVal, err) + } + val.SetKindCase(undoOptVal) + updatedEntries = append(updatedEntries, e) + continue + } + modified = true + if !optElemVal.HasValue() { + continue + } + val.SetKindCase(ctx.NewLiteral(optElemVal.GetValue())) + updatedEntry := ctx.NewMapEntry(key, val, false) + updatedEntries = append(updatedEntries, updatedEntry) + } + if modified { + e.SetKindCase(ctx.NewMap(updatedEntries)) + } +} + +func pruneOptionalStructFields(ctx *OptimizerContext, e ast.Expr) { + s := e.AsStruct() + fields := s.Fields() + updatedFields := []ast.EntryExpr{} + modified := false + for _, f := range fields { + field := f.AsStructField() + val := field.Value() + if !field.IsOptional() || val.Kind() != ast.LiteralKind { + updatedFields = append(updatedFields, f) + continue + } + optElemVal, ok := val.AsLiteral().(*types.Optional) + if !ok { + updatedFields = append(updatedFields, f) + continue + } + modified = true + if !optElemVal.HasValue() { + continue + } + val.SetKindCase(ctx.NewLiteral(optElemVal.GetValue())) + updatedField := ctx.NewStructField(field.Name(), val, false) + updatedFields = append(updatedFields, updatedField) + } + if modified { + e.SetKindCase(ctx.NewStruct(s.TypeName(), updatedFields)) + } +} + +// adaptLiteral converts a runtime CEL value to its equivalent literal expression. +// +// For strongly typed values, the type-provider will be used to reconstruct the fields +// which are present in the literal and their equivalent initialization values. +func adaptLiteral(ctx *OptimizerContext, val ref.Val) (ast.Expr, error) { + switch t := val.Type().(type) { + case *types.Type: + switch t { + case types.BoolType, types.BytesType, types.DoubleType, types.IntType, + types.NullType, types.StringType, types.UintType: + return ctx.NewLiteral(val), nil + case types.DurationType: + return ctx.NewCall( + overloads.TypeConvertDuration, + ctx.NewLiteral(val.ConvertToType(types.StringType)), + ), nil + case types.TimestampType: + return ctx.NewCall( + overloads.TypeConvertTimestamp, + ctx.NewLiteral(val.ConvertToType(types.StringType)), + ), nil + case types.OptionalType: + opt := val.(*types.Optional) + if !opt.HasValue() { + return ctx.NewCall("optional.none"), nil + } + target, err := adaptLiteral(ctx, opt.GetValue()) + if err != nil { + return nil, err + } + return ctx.NewCall("optional.of", target), nil + case types.TypeType: + return ctx.NewIdent(val.(*types.Type).TypeName()), nil + case types.ListType: + l, ok := val.(traits.Lister) + if !ok { + return nil, fmt.Errorf("failed to adapt %v to literal", val) + } + elems := make([]ast.Expr, l.Size().(types.Int)) + idx := 0 + it := l.Iterator() + for it.HasNext() == types.True { + elemVal := it.Next() + elemExpr, err := adaptLiteral(ctx, elemVal) + if err != nil { + return nil, err + } + elems[idx] = elemExpr + idx++ + } + return ctx.NewList(elems, []int32{}), nil + case types.MapType: + m, ok := val.(traits.Mapper) + if !ok { + return nil, fmt.Errorf("failed to adapt %v to literal", val) + } + entries := make([]ast.EntryExpr, m.Size().(types.Int)) + idx := 0 + it := m.Iterator() + for it.HasNext() == types.True { + keyVal := it.Next() + keyExpr, err := adaptLiteral(ctx, keyVal) + if err != nil { + return nil, err + } + valVal := m.Get(keyVal) + valExpr, err := adaptLiteral(ctx, valVal) + if err != nil { + return nil, err + } + entries[idx] = ctx.NewMapEntry(keyExpr, valExpr, false) + idx++ + } + return ctx.NewMap(entries), nil + default: + provider := ctx.CELTypeProvider() + fields, found := provider.FindStructFieldNames(t.TypeName()) + if !found { + return nil, fmt.Errorf("failed to adapt %v to literal", val) + } + tester := val.(traits.FieldTester) + indexer := val.(traits.Indexer) + fieldInits := []ast.EntryExpr{} + for _, f := range fields { + field := types.String(f) + if tester.IsSet(field) != types.True { + continue + } + fieldVal := indexer.Get(field) + fieldExpr, err := adaptLiteral(ctx, fieldVal) + if err != nil { + return nil, err + } + fieldInits = append(fieldInits, ctx.NewStructField(f, fieldExpr, false)) + } + return ctx.NewStruct(t.TypeName(), fieldInits), nil + } + } + return nil, fmt.Errorf("failed to adapt %v to literal", val) +} + +// constantExprMatcher matches calls, select statements, and comprehensions whose arguments +// are all constant scalar or aggregate literal values. +// +// Only comprehensions which are not nested are included as possible constant folds, and only +// if all variables referenced in the comprehension stack exist are only iteration or +// accumulation variables. +func constantExprMatcher(e ast.NavigableExpr) bool { + switch e.Kind() { + case ast.CallKind: + return constantCallMatcher(e) + case ast.SelectKind: + sel := e.AsSelect() // guaranteed to be a navigable value + return constantMatcher(sel.Operand().(ast.NavigableExpr)) + case ast.ComprehensionKind: + if isNestedComprehension(e) { + return false + } + vars := map[string]bool{} + constantExprs := true + visitor := ast.NewExprVisitor(func(e ast.Expr) { + if e.Kind() == ast.ComprehensionKind { + nested := e.AsComprehension() + vars[nested.AccuVar()] = true + vars[nested.IterVar()] = true + } + if e.Kind() == ast.IdentKind && !vars[e.AsIdent()] { + constantExprs = false + } + }) + ast.PreOrderVisit(e, visitor) + return constantExprs + default: + return false + } +} + +// constantCallMatcher identifies strict and non-strict calls which can be folded. +func constantCallMatcher(e ast.NavigableExpr) bool { + call := e.AsCall() + children := e.Children() + fnName := call.FunctionName() + if fnName == operators.LogicalAnd { + for _, child := range children { + if child.Kind() == ast.LiteralKind && child.AsLiteral() == types.False { + return true + } + } + } + if fnName == operators.LogicalOr { + for _, child := range children { + if child.Kind() == ast.LiteralKind && child.AsLiteral() == types.True { + return true + } + } + } + if fnName == operators.Conditional { + cond := children[0] + if cond.Kind() == ast.LiteralKind && cond.AsLiteral().Type() == types.BoolType { + return true + } + } + // convert all other calls with constant arguments + for _, child := range children { + if !constantMatcher(child) { + return false + } + } + return true +} + +func isNestedComprehension(e ast.NavigableExpr) bool { + parent, found := e.Parent() + for found { + if parent.Kind() == ast.ComprehensionKind { + return true + } + parent, found = parent.Parent() + } + return false +} + +func aggregateLiteralMatcher(e ast.NavigableExpr) bool { + return e.Kind() == ast.ListKind || e.Kind() == ast.MapKind || e.Kind() == ast.StructKind +} + +var ( + constantMatcher = ast.ConstantValueMatcher() +) diff --git a/cel/folding_test.go b/cel/folding_test.go new file mode 100644 index 00000000..c871a5bd --- /dev/null +++ b/cel/folding_test.go @@ -0,0 +1,195 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel_test + +import ( + "testing" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/test/proto3pb" +) + +func TestConstantFoldingOptimizer(t *testing.T) { + tests := []struct { + expr string + folded string + }{ + { + expr: `[1, 1 + 2, 1 + (2 + 3)]`, + folded: `[1, 3, 6]`, + }, + { + expr: `6 in [1, 1 + 2, 1 + (2 + 3)]`, + folded: `true`, + }, + { + expr: `5 in [1, 1 + 2, 1 + (2 + 3)]`, + folded: `false`, + }, + { + expr: `x in [1, 1 + 2, 1 + (2 + 3)]`, + folded: `x in [1, 3, 6]`, + }, + { + expr: `1 in [1, x + 2, 1 + (2 + 3)]`, + folded: `1 in [1, x + 2, 6]`, + }, + { + expr: `{'hello': 'world'}.hello == x`, + folded: `"world" == x`, + }, + { + expr: `{'hello': 'world'}.?hello.orValue('default') == x`, + folded: `"world" == x`, + }, + { + expr: `{'hello': 'world'}['hello'] == x`, + folded: `"world" == x`, + }, + { + expr: `optional.of("hello")`, + folded: `optional.of("hello")`, + }, + { + expr: `optional.ofNonZeroValue("")`, + folded: `optional.none()`, + }, + { + expr: `{?'hello': optional.of('world')}['hello'] == x`, + folded: `"world" == x`, + }, + { + expr: `duration(string(7 * 24) + 'h')`, + folded: `duration("604800s")`, + }, + { + expr: `timestamp("1970-01-01T00:00:00Z")`, + folded: `timestamp("1970-01-01T00:00:00Z")`, + }, + { + expr: `[1, 1 + 1, 1 + 2, 2 + 3].exists(i, i < 10)`, + folded: `true`, + }, + { + expr: `[1, 1 + 1, 1 + 2, 2 + 3].exists(i, i < 1 % 2)`, + folded: `false`, + }, + { + expr: `[1, 2, 3].map(i, [1, 2, 3].map(j, i * j))`, + folded: `[[1, 2, 3], [2, 4, 6], [3, 6, 9]]`, + }, + { + expr: `[{}, {"a": 1}, {"b": 2}].filter(m, has(m.a))`, + folded: `[{"a": 1}]`, + }, + { + expr: `[{}, {"a": 1}, {"b": 2}].filter(m, has({'a': true}.a))`, + folded: `[{}, {"a": 1}, {"b": 2}]`, + }, + { + expr: `type(1)`, + folded: `int`, + }, + { + expr: `[google.expr.proto3.test.TestAllTypes{single_int32: 2 + 3}].map(i, i)[0]`, + folded: `google.expr.proto3.test.TestAllTypes{single_int32: 5}`, + }, + { + expr: `[1, ?optional.ofNonZeroValue(0)]`, + folded: `[1]`, + }, + { + expr: `[1, x, ?optional.ofNonZeroValue(3), ?x.?y]`, + folded: `[1, x, 3, ?x.?y]`, + }, + { + expr: `[1, x, ?optional.ofNonZeroValue(3), ?x.?y].size() > 3`, + folded: `[1, x, 3, ?x.?y].size() > 3`, + }, + { + expr: `{?'a': optional.of('hello'), ?x : optional.of(1), ?'b': optional.none()}`, + folded: `{"a": "hello", ?x: optional.of(1)}`, + }, + { + expr: `true ? x + 1 : x + 2`, + folded: `x + 1`, + }, + { + expr: `false ? x + 1 : x + 2`, + folded: `x + 2`, + }, + { + expr: `false ? x + 'world' : 'hello' + 'world'`, + folded: `"helloworld"`, + }, + { + expr: `null`, + folded: `null`, + }, + { + expr: `google.expr.proto3.test.TestAllTypes{?single_int32: optional.ofNonZeroValue(1)}`, + folded: `google.expr.proto3.test.TestAllTypes{single_int32: 1}`, + }, + { + expr: `google.expr.proto3.test.TestAllTypes{?single_int32: optional.ofNonZeroValue(0)}`, + folded: `google.expr.proto3.test.TestAllTypes{}`, + }, + { + expr: `google.expr.proto3.test.TestAllTypes{single_int32: x, repeated_int32: [1, 2, 3]}`, + folded: `google.expr.proto3.test.TestAllTypes{single_int32: x, repeated_int32: [1, 2, 3]}`, + }, + { + expr: `x + dyn([1, 2] + [3, 4])`, + folded: `x + [1, 2, 3, 4]`, + }, + { + expr: `dyn([1, 2]) + [3.0, 4.0]`, + folded: `[1, 2, 3.0, 4.0]`, + }, + { + expr: `{'a': dyn([1, 2]), 'b': x}`, + folded: `{"a": [1, 2], "b": x}`, + }, + } + e, err := cel.NewEnv( + cel.OptionalTypes(), + cel.EnableMacroCallTracking(), + cel.Types(&proto3pb.TestAllTypes{}), + cel.Variable("x", cel.DynType)) + if err != nil { + t.Fatalf("cel.NewEnv() failed: %v", err) + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + checked, iss := e.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("Compile() failed: %v", iss.Err()) + } + opt := cel.NewStaticOptimizer(cel.NewConstantFoldingOptimizer()) + optimized, iss := opt.Optimize(e, checked) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + folded, err := cel.AstToString(optimized) + if err != nil { + t.Fatalf("cel.AstToString() failed: %v", err) + } + if folded != tc.folded { + t.Errorf("got %q, wanted %q", folded, tc.folded) + } + }) + } +} diff --git a/cel/optimizer.go b/cel/optimizer.go new file mode 100644 index 00000000..4a62d5bf --- /dev/null +++ b/cel/optimizer.go @@ -0,0 +1,391 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" +) + +// StaticOptimizer contains a sequence of ASTOptimizer instances which will be applied in order. +// +// The static optimizer normalizes expression ids and type-checking run between optimization +// passes to ensure that the final optimized output is a valid expression with metadata consistent +// with what would have been generated from a parsed and checked expression. +// +// Note: source position information is best-effort and likely wrong, but optimized expressions +// should be suitable for calls to parser.Unparse. +type StaticOptimizer struct { + optimizers []ASTOptimizer +} + +// NewStaticOptimizer creates a StaticOptimizer with a sequence of ASTOptimizer's to be applied +// to a checked expression. +func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer { + return &StaticOptimizer{ + optimizers: optimizers, + } +} + +// Optimize applies a sequence of optimizations to an Ast within a given environment. +// +// If issues are encountered, the Issues.Err() return value will be non-nil. +func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) { + // Make a copy of the AST to be optimized. + optimized := ast.Copy(a.impl) + + // Create the optimizer context, could be pooled in the future. + issues := NewIssues(common.NewErrors(a.Source())) + ids := newMonotonicIDGen(ast.MaxID(a.impl)) + fac := &optimizerExprFactory{ + nextID: ids.nextID, + renumberID: ids.renumberID, + fac: ast.NewExprFactory(), + sourceInfo: optimized.SourceInfo(), + } + ctx := &OptimizerContext{ + optimizerExprFactory: fac, + Env: env, + Issues: issues, + } + + // Apply the optimizations sequentially. + for _, o := range opt.optimizers { + optimized = o.Optimize(ctx, optimized) + if issues.Err() != nil { + return nil, issues + } + // Normalize expression id metadata including coordination with macro call metadata. + normalizeIDs(env, optimized) + + // Recheck the updated expression for any possible type-agreement or validation errors. + parsed := &Ast{ + source: a.Source(), + impl: ast.NewAST(optimized.Expr(), optimized.SourceInfo())} + checked, iss := ctx.Check(parsed) + if iss.Err() != nil { + return nil, iss + } + optimized = checked.impl + } + + // Return the optimized result. + return &Ast{ + source: a.Source(), + impl: optimized, + }, nil +} + +// normalizeIDs ensures that the metadata present with an AST is reset in a manner such +// that the ids within the expression correspond to the ids within macros. This function +// ensures that +func normalizeIDs(e *Env, optimized *ast.AST) { + ids := newStableIDGen() + optimized.Expr().RenumberIDs(ids.renumberID) + allExprMap := make(map[int64]ast.Expr) + ast.PostOrderVisit(optimized.Expr(), ast.NewExprVisitor(func(e ast.Expr) { + allExprMap[e.ID()] = e + })) + info := optimized.SourceInfo() + + // First, update the macro call ids themselves. + for id, call := range info.MacroCalls() { + info.ClearMacroCall(id) + callID := ids.renumberID(id) + if e, found := allExprMap[callID]; found && e.Kind() == ast.LiteralKind { + continue + } + info.SetMacroCall(callID, call) + } + + // Second, update the macro call id references to ensure that macro pointers are' + // updated consistently across macros. + for id, call := range info.MacroCalls() { + call.RenumberIDs(ids.renumberID) + resetMacroCall(optimized, call, allExprMap) + info.SetMacroCall(id, call) + } +} + +func resetMacroCall(optimized *ast.AST, call ast.Expr, allExprMap map[int64]ast.Expr) { + modified := []ast.Expr{} + ast.PostOrderVisit(call, ast.NewExprVisitor(func(e ast.Expr) { + if _, found := allExprMap[e.ID()]; found { + modified = append(modified, e) + } + })) + for _, m := range modified { + updated := allExprMap[m.ID()] + m.SetKindCase(updated) + } +} + +// newMonotonicIDGen increments numbers from an initial seed value. +func newMonotonicIDGen(seed int64) *monotonicIDGenerator { + return &monotonicIDGenerator{seed: seed} +} + +type monotonicIDGenerator struct { + seed int64 +} + +func (gen *monotonicIDGenerator) nextID() int64 { + gen.seed++ + return gen.seed +} + +func (gen *monotonicIDGenerator) renumberID(int64) int64 { + return gen.nextID() +} + +// newStableIDGen ensures that new ids are only created the first time they are encountered. +func newStableIDGen() *stableIDGenerator { + return &stableIDGenerator{ + idMap: make(map[int64]int64), + } +} + +type stableIDGenerator struct { + idMap map[int64]int64 + nextID int64 +} + +func (gen *stableIDGenerator) renumberID(id int64) int64 { + if id == 0 { + return 0 + } + if newID, found := gen.idMap[id]; found { + return newID + } + gen.nextID++ + gen.idMap[id] = gen.nextID + return gen.nextID +} + +// OptimizerContext embeds Env and Issues instances to make it easy to type-check and evaluate +// subexpressions and report any errors encountered along the way. The context also embeds the +// optimizerExprFactory which can be used to generate new sub-expressions with expression ids +// consistent with the expectations of a parsed expression. +type OptimizerContext struct { + *Env + *optimizerExprFactory + *Issues +} + +// ASTOptimizer applies an optimization over an AST and returns the optimized result. +type ASTOptimizer interface { + // Optimize optimizes a type-checked AST within an Environment and accumulates any issues. + Optimize(*OptimizerContext, *ast.AST) *ast.AST +} + +type optimizerExprFactory struct { + nextID func() int64 + renumberID ast.IDGenerator + fac ast.ExprFactory + sourceInfo *ast.SourceInfo +} + +// CopyExpr copies the structure of the input ast.Expr and renumbers the identifiers in a manner +// consistent with the CEL parser / checker. +func (opt *optimizerExprFactory) CopyExpr(e ast.Expr) ast.Expr { + copy := opt.fac.CopyExpr(e) + copy.RenumberIDs(opt.renumberID) + return copy +} + +// NewBindMacro creates a cel.bind() call with a variable name, initialization expression, and remaining expression. +// +// Note: the macroID indicates the insertion point, the call id that matched the macro signature, which will be used +// for coordinating macro metadata with the bind call. This piece of data is what makes it possible to unparse +// optimized expressions which use the bind() call. +// +// Example: +// +// cel.bind(myVar, a && b || c, !myVar || (myVar && d)) +// - varName: myVar +// - varInit: a && b || c +// - remaining: !myVar || (myVar && d) +func (opt *optimizerExprFactory) NewBindMacro(macroID int64, varName string, varInit, remaining ast.Expr) ast.Expr { + bindID := opt.nextID() + varID := opt.nextID() + + varInit = opt.CopyExpr(varInit) + varInit.RenumberIDs(opt.renumberID) + + remaining = opt.fac.CopyExpr(remaining) + remaining.RenumberIDs(opt.renumberID) + + // Place the expanded macro form in the macro calls list so that the inlined + // call can be unparsed. + opt.sourceInfo.SetMacroCall(macroID, + opt.fac.NewMemberCall(0, "bind", + opt.fac.NewIdent(opt.nextID(), "cel"), + opt.fac.NewIdent(varID, varName), + varInit, + remaining)) + + // Replace the parent node with the intercepted inlining using cel.bind()-like + // generated comprehension AST. + return opt.fac.NewComprehension(bindID, + opt.fac.NewList(opt.nextID(), []ast.Expr{}, []int32{}), + "#unused", + varName, + opt.fac.CopyExpr(varInit), + opt.fac.NewLiteral(opt.nextID(), types.False), + opt.fac.NewIdent(varID, varName), + opt.fac.CopyExpr(remaining)) +} + +// NewCall creates a global function call invocation expression. +// +// Example: +// +// countByField(list, fieldName) +// - function: countByField +// - args: [list, fieldName] +func (opt *optimizerExprFactory) NewCall(function string, args ...ast.Expr) ast.Expr { + return opt.fac.NewCall(opt.nextID(), function, args...) +} + +// NewMemberCall creates a member function call invocation expression where 'target' is the receiver of the call. +// +// Example: +// +// list.countByField(fieldName) +// - function: countByField +// - target: list +// - args: [fieldName] +func (opt *optimizerExprFactory) NewMemberCall(function string, target ast.Expr, args ...ast.Expr) ast.Expr { + return opt.fac.NewMemberCall(opt.nextID(), function, target, args...) +} + +// NewIdent creates a new identifier expression. +// +// Examples: +// +// - simple_var_name +// - qualified.subpackage.var_name +func (opt *optimizerExprFactory) NewIdent(name string) ast.Expr { + return opt.fac.NewIdent(opt.nextID(), name) +} + +// NewLiteral creates a new literal expression value. +// +// The range of valid values for a literal generated during optimization is different than for expressions +// generated via parsing / type-checking, as the ref.Val may be _any_ CEL value so long as the value can +// be converted back to a literal-like form. +func (opt *optimizerExprFactory) NewLiteral(value ref.Val) ast.Expr { + return opt.fac.NewLiteral(opt.nextID(), value) +} + +// NewList creates a list expression with a set of optional indices. +// +// Examples: +// +// [a, b] +// - elems: [a, b] +// - optIndices: [] +// +// [a, ?b, ?c] +// - elems: [a, b, c] +// - optIndices: [1, 2] +func (opt *optimizerExprFactory) NewList(elems []ast.Expr, optIndices []int32) ast.Expr { + return opt.fac.NewList(opt.nextID(), elems, optIndices) +} + +// NewMap creates a map from a set of entry expressions which contain a key and value expression. +func (opt *optimizerExprFactory) NewMap(entries []ast.EntryExpr) ast.Expr { + return opt.fac.NewMap(opt.nextID(), entries) +} + +// NewMapEntry creates a map entry with a key and value expression and a flag to indicate whether the +// entry is optional. +// +// Examples: +// +// {a: b} +// - key: a +// - value: b +// - optional: false +// +// {?a: ?b} +// - key: a +// - value: b +// - optional: true +func (opt *optimizerExprFactory) NewMapEntry(key, value ast.Expr, isOptional bool) ast.EntryExpr { + return opt.fac.NewMapEntry(opt.nextID(), key, value, isOptional) +} + +// NewPresenceTest creates a new presence test macro call. +// +// Example: +// +// has(msg.field_name) +// - operand: msg +// - field: field_name +func (opt *optimizerExprFactory) NewPresenceTest(macroID int64, operand ast.Expr, field string) ast.Expr { + // Copy the input operand and renumber it. + operand = opt.CopyExpr(operand) + operand.RenumberIDs(opt.renumberID) + + // Place the expanded macro form in the macro calls list so that the inlined call can be unparsed. + opt.sourceInfo.SetMacroCall(macroID, + opt.fac.NewCall(0, "has", + opt.fac.NewSelect(opt.nextID(), operand, field))) + + // Generate a new presence test macro. + return opt.fac.NewPresenceTest(opt.nextID(), operand, field) +} + +// NewSelect creates a select expression where a field value is selected from an operand. +// +// Example: +// +// msg.field_name +// - operand: msg +// - field: field_name +func (opt *optimizerExprFactory) NewSelect(operand ast.Expr, field string) ast.Expr { + return opt.fac.NewSelect(opt.nextID(), operand, field) +} + +// NewStruct creates a new typed struct value with an set of field initializations. +// +// Example: +// +// pkg.TypeName{field: value} +// - typeName: pkg.TypeName +// - fields: [{field: value}] +func (opt *optimizerExprFactory) NewStruct(typeName string, fields []ast.EntryExpr) ast.Expr { + return opt.fac.NewStruct(opt.nextID(), typeName, fields) +} + +// NewStructField creates a struct field initialization. +// +// Examples: +// +// {count: 3u} +// - field: count +// - value: 3u +// - optional: false +// +// {?count: x} +// - field: count +// - value: x +// - optional: true +func (opt *optimizerExprFactory) NewStructField(field string, value ast.Expr, isOptional bool) ast.EntryExpr { + return opt.fac.NewStructField(opt.nextID(), field, value, isOptional) +} diff --git a/common/ast/ast.go b/common/ast/ast.go index 7610b467..c3620eb9 100644 --- a/common/ast/ast.go +++ b/common/ast/ast.go @@ -249,10 +249,16 @@ func (s *SourceInfo) GetMacroCall(id int64) (Expr, bool) { // SetMacroCall records a macro call at a specific location. func (s *SourceInfo) SetMacroCall(id int64, e Expr) { - if s == nil { - return + if s != nil { + s.macroCalls[id] = e + } +} + +// ClearMacroCall removes the macro call at the given expression id. +func (s *SourceInfo) ClearMacroCall(id int64) { + if s != nil { + delete(s.macroCalls, id) } - s.macroCalls[id] = e } // OffsetRanges returns a map of expression id to OffsetRange values where the range indicates either: @@ -407,6 +413,7 @@ func (r *ReferenceInfo) Equals(other *ReferenceInfo) bool { type maxIDVisitor struct { maxID int64 + *baseVisitor } // VisitExpr updates the max identifier if the incoming expression id is greater than previously observed. diff --git a/common/ast/expr.go b/common/ast/expr.go index 5811e395..aac3bf3d 100644 --- a/common/ast/expr.go +++ b/common/ast/expr.go @@ -184,6 +184,9 @@ type ListExpr interface { // OptionalIndicies returns the list of optional indices in the list literal. OptionalIndices() []int32 + // IsOptional indicates whether the given element index is optional. + IsOptional(int32) bool + // Size returns the number of elements in the list. Size() int @@ -606,6 +609,15 @@ func (e *baseListExpr) Elements() []Expr { return e.elements } +func (e *baseListExpr) IsOptional(index int32) bool { + for _, optIndex := range e.OptionalIndices() { + if optIndex == index { + return true + } + } + return false +} + func (e *baseListExpr) OptionalIndices() []int32 { if e == nil { return []int32{} diff --git a/common/ast/navigable.go b/common/ast/navigable.go index 2836b565..f5ddf6aa 100644 --- a/common/ast/navigable.go +++ b/common/ast/navigable.go @@ -423,6 +423,10 @@ func (l navigableListImpl) Elements() []Expr { return elems } +func (l navigableListImpl) IsOptional(index int32) bool { + return l.Expr.AsList().IsOptional(index) +} + func (l navigableListImpl) OptionalIndices() []int32 { return l.Expr.AsList().OptionalIndices() } From 1db6903b16ff5ee978f6c1e52a50e92a0227ca0a Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Mon, 28 Aug 2023 23:45:22 -0700 Subject: [PATCH 2/4] Better logical folds and additional tests --- cel/folding.go | 80 +++++++- cel/folding_test.go | 416 ++++++++++++++++++++++++++++++++++++++++-- cel/optimizer.go | 76 +------- common/ast/expr.go | 22 ++- common/ast/factory.go | 11 +- 5 files changed, 500 insertions(+), 105 deletions(-) diff --git a/cel/folding.go b/cel/folding.go index fce15da5..d1612bc6 100644 --- a/cel/folding.go +++ b/cel/folding.go @@ -48,7 +48,7 @@ func (*constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *as for _, fold := range foldableExprs { // If the expression could be folded because it's a non-strict call, and the // branches are pruned, continue to the next fold. - if fold.Kind() == ast.CallKind && maybePruneBranches(fold) { + if fold.Kind() == ast.CallKind && maybePruneBranches(ctx, fold) { continue } // Otherwise, assume all context is needed to evaluate the expression. @@ -72,7 +72,8 @@ func (*constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *as pruneOptionalElements(ctx, root) // Ensure that all intermediate values in the folded expression can be represented as valid - // CEL literals within the AST structure. + // CEL literals within the AST structure. Use `PostOrderVisit` rather than `MatchDescendents` + // to avoid extra allocations during this final pass through the AST. ast.PostOrderVisit(root, ast.NewExprVisitor(func(e ast.Expr) { if e.Kind() != ast.LiteralKind { return @@ -117,24 +118,75 @@ func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error { // a branch can be removed. Evaluation will naturally prune logical and / or calls, // but conditional will not be pruned cleanly, so this is one small area where the // constant folding step reimplements a portion of the evaluator. -func maybePruneBranches(expr ast.NavigableExpr) bool { +func maybePruneBranches(ctx *OptimizerContext, expr ast.NavigableExpr) bool { call := expr.AsCall() + args := call.Args() switch call.FunctionName() { + case operators.LogicalAnd, operators.LogicalOr: + return maybeShortcircuitLogic(ctx, call.FunctionName(), args, expr) case operators.Conditional: - args := call.Args() cond := args[0] truthy := args[1] falsy := args[2] + if cond.Kind() != ast.LiteralKind { + return false + } if cond.AsLiteral() == types.True { expr.SetKindCase(truthy) } else { expr.SetKindCase(falsy) } return true + case operators.In: + haystack := args[1] + if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 { + expr.SetKindCase(ctx.NewLiteral(types.False)) + return true + } + needle := args[0] + if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind { + needleValue := needle.AsLiteral() + list := haystack.AsList() + for _, e := range list.Elements() { + if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True { + expr.SetKindCase(ctx.NewLiteral(types.True)) + return true + } + } + } } return false } +func maybeShortcircuitLogic(ctx *OptimizerContext, function string, args []ast.Expr, expr ast.NavigableExpr) bool { + shortcircuit := types.False + skip := types.True + if function == operators.LogicalOr { + shortcircuit = types.True + skip = types.False + } + newArgs := []ast.Expr{} + for _, arg := range args { + if arg.Kind() != ast.LiteralKind { + newArgs = append(newArgs, arg) + continue + } + if arg.AsLiteral() == skip { + continue + } + if arg.AsLiteral() == shortcircuit { + expr.SetKindCase(arg) + return true + } + } + if len(newArgs) == 1 { + expr.SetKindCase(newArgs[0]) + return true + } + expr.SetKindCase(ctx.NewCall(function, newArgs...)) + return true +} + // pruneOptionalElements works from the bottom up to resolve optional elements within // aggregate literals. // @@ -403,14 +455,14 @@ func constantCallMatcher(e ast.NavigableExpr) bool { fnName := call.FunctionName() if fnName == operators.LogicalAnd { for _, child := range children { - if child.Kind() == ast.LiteralKind && child.AsLiteral() == types.False { + if child.Kind() == ast.LiteralKind { return true } } } if fnName == operators.LogicalOr { for _, child := range children { - if child.Kind() == ast.LiteralKind && child.AsLiteral() == types.True { + if child.Kind() == ast.LiteralKind { return true } } @@ -421,6 +473,22 @@ func constantCallMatcher(e ast.NavigableExpr) bool { return true } } + if fnName == operators.In { + haystack := children[1] + if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 { + return true + } + needle := children[0] + if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind { + needleValue := needle.AsLiteral() + list := haystack.AsList() + for _, e := range list.Elements() { + if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True { + return true + } + } + } + } // convert all other calls with constant arguments for _, child := range children { if !constantMatcher(child) { diff --git a/cel/folding_test.go b/cel/folding_test.go index c871a5bd..4e40be3a 100644 --- a/cel/folding_test.go +++ b/cel/folding_test.go @@ -12,13 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -package cel_test +package cel import ( + "reflect" + "sort" "testing" - "github.com/google/cel-go/cel" - "github.com/google/cel-go/test/proto3pb" + "google.golang.org/protobuf/encoding/prototext" + "google.golang.org/protobuf/proto" + + "github.com/google/cel-go/common/ast" + + proto3pb "github.com/google/cel-go/test/proto3pb" + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) func TestConstantFoldingOptimizer(t *testing.T) { @@ -44,7 +51,15 @@ func TestConstantFoldingOptimizer(t *testing.T) { }, { expr: `1 in [1, x + 2, 1 + (2 + 3)]`, - folded: `1 in [1, x + 2, 6]`, + folded: `true`, + }, + { + expr: `1 in [x, x + 2, 1 + (2 + 3)]`, + folded: `1 in [x, x + 2, 6]`, + }, + { + expr: `x in []`, + folded: `false`, }, { expr: `{'hello': 'world'}.hello == x`, @@ -90,6 +105,14 @@ func TestConstantFoldingOptimizer(t *testing.T) { expr: `[1, 2, 3].map(i, [1, 2, 3].map(j, i * j))`, folded: `[[1, 2, 3], [2, 4, 6], [3, 6, 9]]`, }, + { + expr: `[1, 2, 3].map(i, [1, 2, 3].map(j, i * j).filter(k, k % 2 == 0))`, + folded: `[[2], [2, 4, 6], [6]]`, + }, + { + expr: `[1, 2, 3].map(i, [1, 2, 3].map(j, i * j).filter(k, k % 2 == x))`, + folded: `[1, 2, 3].map(i, [1, 2, 3].map(j, i * j).filter(k, k % 2 == x))`, + }, { expr: `[{}, {"a": 1}, {"b": 2}].filter(m, has(m.a))`, folded: `[{"a": 1}]`, @@ -134,6 +157,46 @@ func TestConstantFoldingOptimizer(t *testing.T) { expr: `false ? x + 'world' : 'hello' + 'world'`, folded: `"helloworld"`, }, + { + expr: `true && x`, + folded: `x`, + }, + { + expr: `x && true`, + folded: `x`, + }, + { + expr: `false && x`, + folded: `false`, + }, + { + expr: `x && false`, + folded: `false`, + }, + { + expr: `true || x`, + folded: `true`, + }, + { + expr: `x || true`, + folded: `true`, + }, + { + expr: `false || x`, + folded: `x`, + }, + { + expr: `x || false`, + folded: `x`, + }, + { + expr: `true && x && true && x`, + folded: `x && x`, + }, + { + expr: `false || x || false || x`, + folded: `x || x`, + }, { expr: `null`, folded: `null`, @@ -162,14 +225,26 @@ func TestConstantFoldingOptimizer(t *testing.T) { expr: `{'a': dyn([1, 2]), 'b': x}`, folded: `{"a": [1, 2], "b": x}`, }, + { + expr: `1 + x + 2 == 2 + x + 1`, + folded: `1 + x + 2 == 2 + x + 1`, + }, + { + // The order of operations makes it such that the appearance of x in the first means that + // none of the values provided into the addition call will be folded with the current + // implementation. Ideally, the result would be 3 + x == x + 3 (which could be trivially true + // and more easily observed as a result of common subexpression eliminiation) + expr: `1 + 2 + x == x + 2 + 1`, + folded: `3 + x == x + 2 + 1`, + }, } - e, err := cel.NewEnv( - cel.OptionalTypes(), - cel.EnableMacroCallTracking(), - cel.Types(&proto3pb.TestAllTypes{}), - cel.Variable("x", cel.DynType)) + e, err := NewEnv( + OptionalTypes(), + EnableMacroCallTracking(), + Types(&proto3pb.TestAllTypes{}), + Variable("x", DynType)) if err != nil { - t.Fatalf("cel.NewEnv() failed: %v", err) + t.Fatalf("NewEnv() failed: %v", err) } for _, tst := range tests { tc := tst @@ -178,14 +253,14 @@ func TestConstantFoldingOptimizer(t *testing.T) { if iss.Err() != nil { t.Fatalf("Compile() failed: %v", iss.Err()) } - opt := cel.NewStaticOptimizer(cel.NewConstantFoldingOptimizer()) + opt := NewStaticOptimizer(NewConstantFoldingOptimizer()) optimized, iss := opt.Optimize(e, checked) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) } - folded, err := cel.AstToString(optimized) + folded, err := AstToString(optimized) if err != nil { - t.Fatalf("cel.AstToString() failed: %v", err) + t.Fatalf("AstToString() failed: %v", err) } if folded != tc.folded { t.Errorf("got %q, wanted %q", folded, tc.folded) @@ -193,3 +268,318 @@ func TestConstantFoldingOptimizer(t *testing.T) { }) } } + +func TestConstantFoldingNormalizeIDs(t *testing.T) { + tests := []struct { + expr string + ids []int64 + macros map[int64]string + normalizedIDs []int64 + normalizedMacros map[int64]string + }{ + { + expr: `[1, 2, 3]`, + ids: []int64{1, 2, 3, 4}, + normalizedIDs: []int64{1, 2, 3, 4}, + }, + { + expr: `google.expr.proto3.test.TestAllTypes{single_int32: 0}`, + ids: []int64{1, 2, 3}, + normalizedIDs: []int64{1, 2, 3}, + }, + { + expr: `has({x: 'value'}.single_int32)`, + ids: []int64{2, 3, 4, 5, 7}, + macros: map[int64]string{7: ` + call_expr: { + function: "has" + args: { + id: 6 + select_expr: { + operand: { + id: 2 + struct_expr: { + entries: { + id: 3 + map_key: { + id: 4 + ident_expr: { + name: "x" + } + } + value: { + id: 5 + const_expr: { + string_value: "value" + } + } + } + } + } + field: "single_int32" + } + } + }`}, + normalizedIDs: []int64{1, 2, 3, 4, 5}, + normalizedMacros: map[int64]string{1: ` + call_expr: { + function: "has" + args: { + id: 6 + select_expr: { + operand: { + id: 2 + struct_expr: { + entries: { + id: 3 + map_key: { + id: 4 + ident_expr: { + name: "x" + } + } + value: { + id: 5 + const_expr: { + string_value: "value" + } + } + } + } + } + field: "single_int32" + } + } + }`, + }, + }, + { + expr: `has(google.expr.proto3.test.TestAllTypes{}.single_int32)`, + ids: []int64{2, 4}, + macros: map[int64]string{ + 4: `call_expr: { + function: "has" + args: { + id: 3 + select_expr: { + operand: { + id: 2 + struct_expr: { + message_name: "google.expr.proto3.test.TestAllTypes" + } + } + field: "single_int32" + } + } + }`, + }, + normalizedIDs: []int64{1}, + }, + { + expr: `[true].exists(i, i)`, + ids: []int64{1, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13}, + macros: map[int64]string{ + 13: `call_expr: { + target: { + id: 1 + list_expr: { + elements: { + id: 2 + const_expr: { + bool_value: true + } + } + } + } + function: "exists" + args: { + id: 4 + ident_expr: { + name: "i" + } + } + args: { + id: 5 + ident_expr: { + name: "i" + } + } + }`, + }, + normalizedIDs: []int64{1}, + }, + { + expr: `[x].exists(i, i)`, + ids: []int64{1, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13}, + macros: map[int64]string{ + 13: `call_expr: { + target: { + id: 1 + list_expr: { + elements: { + id: 2 + ident_expr: { + name: "x" + } + } + } + } + function: "exists" + args: { + id: 4 + ident_expr: { + name: "i" + } + } + args: { + id: 5 + ident_expr: { + name: "i" + } + } + }`, + }, + normalizedIDs: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + normalizedMacros: map[int64]string{ + 1: `call_expr: { + target: { + id: 2 + list_expr: { + elements: { + id: 3 + ident_expr: { + name: "x" + } + } + } + } + function: "exists" + args: { + id: 12 + ident_expr: { + name: "i" + } + } + args: { + id: 10 + ident_expr: { + name: "i" + } + } + }`, + }, + }, + } + e, err := NewEnv( + EnableMacroCallTracking(), + Types(&proto3pb.TestAllTypes{}), + Variable("x", DynType)) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + checked, iss := e.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("Compile() failed: %v", iss.Err()) + } + preOpt := newIDCollector() + ast.PostOrderVisit(checked.impl.Expr(), preOpt) + if !reflect.DeepEqual(preOpt.IDs(), tc.ids) { + t.Errorf("Compile() got ids %v, expected %v", preOpt.IDs(), tc.ids) + } + for id, call := range checked.impl.SourceInfo().MacroCalls() { + macroText, found := tc.macros[id] + if !found { + t.Fatalf("Compile() did not find macro %d", id) + } + pbCall, err := ast.ExprToProto(call) + if err != nil { + t.Fatalf("ast.ExprToProto() failed: %v", err) + } + pbMacro := &exprpb.Expr{} + err = prototext.Unmarshal([]byte(macroText), pbMacro) + if err != nil { + t.Fatalf("prototext.Unmarshal() failed: %v", err) + } + if !proto.Equal(pbCall, pbMacro) { + t.Errorf("Compile() for macro %d got %s, expected %s", id, prototext.Format(pbCall), macroText) + } + } + opt := NewStaticOptimizer(NewConstantFoldingOptimizer()) + optimized, iss := opt.Optimize(e, checked) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + postOpt := newIDCollector() + ast.PostOrderVisit(optimized.impl.Expr(), postOpt) + if !reflect.DeepEqual(postOpt.IDs(), tc.normalizedIDs) { + t.Errorf("Optimize() got ids %v, expected %v", postOpt.IDs(), tc.normalizedIDs) + } + for id, call := range optimized.impl.SourceInfo().MacroCalls() { + macroText, found := tc.normalizedMacros[id] + if !found { + t.Fatalf("Optimize() did not find macro %d", id) + } + pbCall, err := ast.ExprToProto(call) + if err != nil { + t.Fatalf("ast.ExprToProto() failed: %v", err) + } + pbMacro := &exprpb.Expr{} + err = prototext.Unmarshal([]byte(macroText), pbMacro) + if err != nil { + t.Fatalf("prototext.Unmarshal() failed: %v", err) + } + if !proto.Equal(pbCall, pbMacro) { + t.Errorf("Optimize() for macro %d got %s, expected %s", id, prototext.Format(pbCall), macroText) + } + } + }) + } +} + +func newIDCollector() *idCollector { + return &idCollector{ + ids: int64Slice{}, + } +} + +type idCollector struct { + ids int64Slice +} + +func (c *idCollector) VisitExpr(e ast.Expr) { + if e.ID() == 0 { + return + } + c.ids = append(c.ids, e.ID()) +} + +// VisitEntryExpr updates the max identifier if the incoming entry id is greater than previously observed. +func (c *idCollector) VisitEntryExpr(e ast.EntryExpr) { + if e.ID() == 0 { + return + } + c.ids = append(c.ids, e.ID()) +} + +func (c *idCollector) IDs() []int64 { + sort.Sort(c.ids) + return c.ids +} + +// int64Slice is an implementation of the sort.Interface +type int64Slice []int64 + +// Len returns the number of elements in the slice. +func (x int64Slice) Len() int { return len(x) } + +// Less indicates whether the value at index i is less than the value at index j. +func (x int64Slice) Less(i, j int) bool { return x[i] < x[j] } + +// Swap swaps the values at indices i and j in place. +func (x int64Slice) Swap(i, j int) { x[i], x[j] = x[j], x[i] } + +// Sort is a convenience method: x.Sort() calls Sort(x). +func (x int64Slice) Sort() { sort.Sort(x) } diff --git a/cel/optimizer.go b/cel/optimizer.go index 4a62d5bf..a2023ed7 100644 --- a/cel/optimizer.go +++ b/cel/optimizer.go @@ -17,7 +17,6 @@ package cel import ( "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" - "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" ) @@ -91,8 +90,7 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) { } // normalizeIDs ensures that the metadata present with an AST is reset in a manner such -// that the ids within the expression correspond to the ids within macros. This function -// ensures that +// that the ids within the expression correspond to the ids within macros. func normalizeIDs(e *Env, optimized *ast.AST) { ids := newStableIDGen() optimized.Expr().RenumberIDs(ids.renumberID) @@ -199,57 +197,6 @@ type optimizerExprFactory struct { sourceInfo *ast.SourceInfo } -// CopyExpr copies the structure of the input ast.Expr and renumbers the identifiers in a manner -// consistent with the CEL parser / checker. -func (opt *optimizerExprFactory) CopyExpr(e ast.Expr) ast.Expr { - copy := opt.fac.CopyExpr(e) - copy.RenumberIDs(opt.renumberID) - return copy -} - -// NewBindMacro creates a cel.bind() call with a variable name, initialization expression, and remaining expression. -// -// Note: the macroID indicates the insertion point, the call id that matched the macro signature, which will be used -// for coordinating macro metadata with the bind call. This piece of data is what makes it possible to unparse -// optimized expressions which use the bind() call. -// -// Example: -// -// cel.bind(myVar, a && b || c, !myVar || (myVar && d)) -// - varName: myVar -// - varInit: a && b || c -// - remaining: !myVar || (myVar && d) -func (opt *optimizerExprFactory) NewBindMacro(macroID int64, varName string, varInit, remaining ast.Expr) ast.Expr { - bindID := opt.nextID() - varID := opt.nextID() - - varInit = opt.CopyExpr(varInit) - varInit.RenumberIDs(opt.renumberID) - - remaining = opt.fac.CopyExpr(remaining) - remaining.RenumberIDs(opt.renumberID) - - // Place the expanded macro form in the macro calls list so that the inlined - // call can be unparsed. - opt.sourceInfo.SetMacroCall(macroID, - opt.fac.NewMemberCall(0, "bind", - opt.fac.NewIdent(opt.nextID(), "cel"), - opt.fac.NewIdent(varID, varName), - varInit, - remaining)) - - // Replace the parent node with the intercepted inlining using cel.bind()-like - // generated comprehension AST. - return opt.fac.NewComprehension(bindID, - opt.fac.NewList(opt.nextID(), []ast.Expr{}, []int32{}), - "#unused", - varName, - opt.fac.CopyExpr(varInit), - opt.fac.NewLiteral(opt.nextID(), types.False), - opt.fac.NewIdent(varID, varName), - opt.fac.CopyExpr(remaining)) -} - // NewCall creates a global function call invocation expression. // // Example: @@ -330,27 +277,6 @@ func (opt *optimizerExprFactory) NewMapEntry(key, value ast.Expr, isOptional boo return opt.fac.NewMapEntry(opt.nextID(), key, value, isOptional) } -// NewPresenceTest creates a new presence test macro call. -// -// Example: -// -// has(msg.field_name) -// - operand: msg -// - field: field_name -func (opt *optimizerExprFactory) NewPresenceTest(macroID int64, operand ast.Expr, field string) ast.Expr { - // Copy the input operand and renumber it. - operand = opt.CopyExpr(operand) - operand.RenumberIDs(opt.renumberID) - - // Place the expanded macro form in the macro calls list so that the inlined call can be unparsed. - opt.sourceInfo.SetMacroCall(macroID, - opt.fac.NewCall(0, "has", - opt.fac.NewSelect(opt.nextID(), operand, field))) - - // Generate a new presence test macro. - return opt.fac.NewPresenceTest(opt.nextID(), operand, field) -} - // NewSelect creates a select expression where a field value is selected from an operand. // // Example: diff --git a/common/ast/expr.go b/common/ast/expr.go index aac3bf3d..c9d88bba 100644 --- a/common/ast/expr.go +++ b/common/ast/expr.go @@ -407,9 +407,14 @@ func (e *expr) SetKindCase(other Expr) { e.exprKindCase = baseIdentExpr(other.AsIdent()) case ListKind: l := other.AsList() + optIndexMap := make(map[int32]struct{}, len(l.OptionalIndices())) + for _, idx := range l.OptionalIndices() { + optIndexMap[idx] = struct{}{} + } e.exprKindCase = &baseListExpr{ - elements: l.Elements(), - optIndices: l.OptionalIndices(), + elements: l.Elements(), + optIndices: l.OptionalIndices(), + optIndexMap: optIndexMap, } case LiteralKind: e.exprKindCase = &baseLiteral{Val: other.AsLiteral()} @@ -594,8 +599,9 @@ func (*baseLiteral) isExpr() {} var _ ListExpr = &baseListExpr{} type baseListExpr struct { - elements []Expr - optIndices []int32 + elements []Expr + optIndices []int32 + optIndexMap map[int32]struct{} } func (*baseListExpr) Kind() ExprKind { @@ -610,12 +616,8 @@ func (e *baseListExpr) Elements() []Expr { } func (e *baseListExpr) IsOptional(index int32) bool { - for _, optIndex := range e.OptionalIndices() { - if optIndex == index { - return true - } - } - return false + _, found := e.optIndexMap[index] + return found } func (e *baseListExpr) OptionalIndices() []int32 { diff --git a/common/ast/factory.go b/common/ast/factory.go index 0111c289..b7f36e72 100644 --- a/common/ast/factory.go +++ b/common/ast/factory.go @@ -137,7 +137,16 @@ func (fac *baseExprFactory) NewLiteral(id int64, value ref.Val) Expr { } func (fac *baseExprFactory) NewList(id int64, elems []Expr, optIndices []int32) Expr { - return fac.newExpr(id, &baseListExpr{elements: elems, optIndices: optIndices}) + optIndexMap := make(map[int32]struct{}, len(optIndices)) + for _, idx := range optIndices { + optIndexMap[idx] = struct{}{} + } + return fac.newExpr(id, + &baseListExpr{ + elements: elems, + optIndices: optIndices, + optIndexMap: optIndexMap, + }) } func (fac *baseExprFactory) NewMap(id int64, entries []EntryExpr) Expr { From 0d705646d72fe8729a158c17472b4d92c3d81287 Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Tue, 29 Aug 2023 13:29:15 -0700 Subject: [PATCH 3/4] Add a configurable limit to constant folding --- cel/folding.go | 41 ++++++++++++++++++++++---- cel/folding_test.go | 71 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 105 insertions(+), 7 deletions(-) diff --git a/cel/folding.go b/cel/folding.go index d1612bc6..84bccf55 100644 --- a/cel/folding.go +++ b/cel/folding.go @@ -25,26 +25,52 @@ import ( "github.com/google/cel-go/common/types/traits" ) +// ConstantFoldingOption defines a functional option for configuring constant folding. +type ConstantFoldingOption func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error) + +// MaxConstantFoldIterations limits the number of times literals may be folding during optimization. +// +// Defaults to 100 if not set. +func MaxConstantFoldIterations(limit int) ConstantFoldingOption { + return func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error) { + opt.maxFoldIterations = limit + return opt, nil + } +} + // NewConstantFoldingOptimizer creates an optimizer which inlines constant scalar an aggregate // literal values within function calls and select statements with their evaluated result. -func NewConstantFoldingOptimizer() ASTOptimizer { - return &constantFoldingOptimizer{} +func NewConstantFoldingOptimizer(opts ...ConstantFoldingOption) (ASTOptimizer, error) { + folder := &constantFoldingOptimizer{ + maxFoldIterations: defaultMaxConstantFoldIterations, + } + var err error + for _, o := range opts { + folder, err = o(folder) + if err != nil { + return nil, err + } + } + return folder, nil } -type constantFoldingOptimizer struct{} +type constantFoldingOptimizer struct { + maxFoldIterations int +} // Optimize queries the expression graph for scalar and aggregate literal expressions within call and // select statements and then evaluates them and replaces the call site with the literal result. // // Note: only values which can be represented as literals in CEL syntax are supported. -func (*constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST { +func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST { root := ast.NavigateAST(a) // Walk the list of foldable expression and continue to fold until there are no more folds left. // All of the fold candidates returned by the constantExprMatcher should succeed unless there's // a logic bug with the selection of expressions. foldableExprs := ast.MatchDescendants(root, constantExprMatcher) - for len(foldableExprs) != 0 { + foldCount := 0 + for len(foldableExprs) != 0 && foldCount < opt.maxFoldIterations { for _, fold := range foldableExprs { // If the expression could be folded because it's a non-strict call, and the // branches are pruned, continue to the next fold. @@ -58,6 +84,7 @@ func (*constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *as return a } } + foldCount++ foldableExprs = ast.MatchDescendants(root, constantExprMatcher) } // Once all of the constants have been folded, try to run through the remaining comprehensions @@ -516,3 +543,7 @@ func aggregateLiteralMatcher(e ast.NavigableExpr) bool { var ( constantMatcher = ast.ConstantValueMatcher() ) + +const ( + defaultMaxConstantFoldIterations = 100 +) diff --git a/cel/folding_test.go b/cel/folding_test.go index 4e40be3a..18fe4d58 100644 --- a/cel/folding_test.go +++ b/cel/folding_test.go @@ -253,7 +253,70 @@ func TestConstantFoldingOptimizer(t *testing.T) { if iss.Err() != nil { t.Fatalf("Compile() failed: %v", iss.Err()) } - opt := NewStaticOptimizer(NewConstantFoldingOptimizer()) + folder, err := NewConstantFoldingOptimizer() + if err != nil { + t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) + } + opt := NewStaticOptimizer(folder) + optimized, iss := opt.Optimize(e, checked) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + folded, err := AstToString(optimized) + if err != nil { + t.Fatalf("AstToString() failed: %v", err) + } + if folded != tc.folded { + t.Errorf("got %q, wanted %q", folded, tc.folded) + } + }) + } +} + +func TestConstantFoldingOptimizerWithLimit(t *testing.T) { + tests := []struct { + expr string + limit int + folded string + }{ + { + expr: `[1, 1 + 2, 1 + (2 + 3)]`, + limit: 1, + folded: `[1, 3, 1 + 5]`, + }, + { + expr: `5 in [1, 1 + 2, 1 + (2 + 3)]`, + limit: 2, + folded: `5 in [1, 3, 6]`, + }, + { + // though more complex, the final tryFold() at the end of the optimization pass + // results in this computed output. + expr: `[1, 2, 3].map(i, [1, 2, 3].map(j, i * j))`, + limit: 1, + folded: `[[1, 2, 3], [2, 4, 6], [3, 6, 9]]`, + }, + } + e, err := NewEnv( + OptionalTypes(), + EnableMacroCallTracking(), + Types(&proto3pb.TestAllTypes{}), + Variable("x", DynType)) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + checked, iss := e.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("Compile() failed: %v", iss.Err()) + } + folder, err := NewConstantFoldingOptimizer(MaxConstantFoldIterations(tc.limit)) + if err != nil { + t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) + } + opt := NewStaticOptimizer(folder) optimized, iss := opt.Optimize(e, checked) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) @@ -507,7 +570,11 @@ func TestConstantFoldingNormalizeIDs(t *testing.T) { t.Errorf("Compile() for macro %d got %s, expected %s", id, prototext.Format(pbCall), macroText) } } - opt := NewStaticOptimizer(NewConstantFoldingOptimizer()) + folder, err := NewConstantFoldingOptimizer() + if err != nil { + t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) + } + opt := NewStaticOptimizer(folder) optimized, iss := opt.Optimize(e, checked) if iss.Err() != nil { t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) From c15a14cdbc4c5bdce9f6469c5a170e97694c9cec Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Wed, 30 Aug 2023 15:46:04 -0700 Subject: [PATCH 4/4] Additional comments --- cel/folding.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cel/folding.go b/cel/folding.go index 84bccf55..5903d732 100644 --- a/cel/folding.go +++ b/cel/folding.go @@ -217,7 +217,7 @@ func maybeShortcircuitLogic(ctx *OptimizerContext, function string, args []ast.E // pruneOptionalElements works from the bottom up to resolve optional elements within // aggregate literals. // -// Note, may aggregate literals will be resolved as arguments to functions or select +// Note, many aggregate literals will be resolved as arguments to functions or select // statements, so this method exists to handle the case where the literal could not be // fully resolved or exists outside of a call, select, or comprehension context. func pruneOptionalElements(ctx *OptimizerContext, root ast.NavigableExpr) { @@ -277,6 +277,8 @@ func pruneOptionalMapEntries(ctx *OptimizerContext, e ast.Expr) { entry := e.AsMapEntry() key := entry.Key() val := entry.Value() + // If the entry is not optional, or the value-side of the optional hasn't + // been resolved to a literal, then preserve the entry as-is. if !entry.IsOptional() || val.Kind() != ast.LiteralKind { updatedEntries = append(updatedEntries, e) continue @@ -286,6 +288,8 @@ func pruneOptionalMapEntries(ctx *OptimizerContext, e ast.Expr) { updatedEntries = append(updatedEntries, e) continue } + // When the key is not a literal, but the value is, then it needs to be + // restored to an optional value. if key.Kind() != ast.LiteralKind { undoOptVal, err := adaptLiteral(ctx, optElemVal) if err != nil {