From 8fc2ea8e6b0db2cc49b6f4505bc26f0cff573ac2 Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Mon, 10 Jun 2024 12:43:28 -0700 Subject: [PATCH] Stabilize macro id generation --- cel/optimizer.go | 22 ++++++- cel/optimizer_test.go | 142 ++++++++++++++++++++++++++++++++++++------ 2 files changed, 141 insertions(+), 23 deletions(-) diff --git a/cel/optimizer.go b/cel/optimizer.go index f26df462..ec02773a 100644 --- a/cel/optimizer.go +++ b/cel/optimizer.go @@ -15,6 +15,8 @@ package cel import ( + "sort" + "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/types" @@ -98,14 +100,21 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) { // that the ids within the expression correspond to the ids within macros. func normalizeIDs(idGen ast.IDGenerator, optimized ast.Expr, info *ast.SourceInfo) { optimized.RenumberIDs(idGen) - if len(info.MacroCalls()) == 0 { return } + // Sort the macro ids to make sure that the renumbering of macro-specific variables + // is stable across normalization calls. + sortedMacroIDs := []int64{} + for id := range info.MacroCalls() { + sortedMacroIDs = append(sortedMacroIDs, id) + } + sort.Slice(sortedMacroIDs, func(i, j int) bool { return sortedMacroIDs[i] < sortedMacroIDs[j] }) + // First, update the macro call ids themselves. callIDMap := map[int64]int64{} - for id := range info.MacroCalls() { + for _, id := range sortedMacroIDs { callIDMap[id] = idGen(id) } // Then update the macro call definitions which refer to these ids, but @@ -116,7 +125,8 @@ func normalizeIDs(idGen ast.IDGenerator, optimized ast.Expr, info *ast.SourceInf call ast.Expr } macroUpdates := []macroUpdate{} - for oldID, newID := range callIDMap { + for _, oldID := range sortedMacroIDs { + newID := callIDMap[oldID] call, found := info.GetMacroCall(oldID) if !found { continue @@ -134,6 +144,7 @@ func cleanupMacroRefs(expr ast.Expr, info *ast.SourceInfo) { if len(info.MacroCalls()) == 0 { return } + // Sanitize the macro call references once the optimized expression has been computed // and the ids normalized between the expression and the macros. exprRefMap := make(map[int64]struct{}) @@ -253,6 +264,11 @@ func (opt *optimizerExprFactory) SetMacroCall(id int64, expr ast.Expr) { opt.sourceInfo.SetMacroCall(id, expr) } +// MacroCalls returns the map of macro calls currently in the context. +func (opt *optimizerExprFactory) MacroCalls() map[int64]ast.Expr { + return opt.sourceInfo.MacroCalls() +} + // NewBindMacro creates an AST expression representing the expanded bind() macro, and a macro expression // representing the unexpanded call signature to be inserted into the source info macro call metadata. func (opt *optimizerExprFactory) NewBindMacro(macroID int64, varName string, varInit, remaining ast.Expr) (astExpr, macroExpr ast.Expr) { diff --git a/cel/optimizer_test.go b/cel/optimizer_test.go index 8ecd8216..2faac17e 100644 --- a/cel/optimizer_test.go +++ b/cel/optimizer_test.go @@ -15,7 +15,6 @@ package cel_test import ( - "reflect" "sort" "testing" @@ -23,7 +22,12 @@ import ( "github.com/google/cel-go/common/ast" "github.com/google/cel-go/ext" + "google.golang.org/protobuf/encoding/prototext" + "google.golang.org/protobuf/proto" + proto3pb "github.com/google/cel-go/test/proto3pb" + + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) func TestStaticOptimizerUpdateExpr(t *testing.T) { @@ -49,6 +53,118 @@ func TestStaticOptimizerUpdateExpr(t *testing.T) { if err != nil { t.Fatalf("cel.AstToString() failed: %v", err) } + sourceInfo := optAST.NativeRep().SourceInfo() + sourceInfoPB, err := ast.SourceInfoToProto(sourceInfo) + if err != nil { + t.Fatalf("cel.AstToCheckedExpr() failed: %v", err) + } + wantTextPB := ` + location: "" + line_offsets: 9 + positions: { + key: 2 + value: 4 + } + positions: { + key: 3 + value: 5 + } + positions: { + key: 4 + value: 3 + } + macro_calls: { + key: 1 + value: { + call_expr: { + function: "has" + args: { + id: 21 + select_expr: { + operand: { + id: 2 + call_expr: { + function: "_[_]" + args: { + id: 3 + } + args: { + id: 20 + const_expr: { + int64_value: 0 + } + } + } + } + field: "z" + } + } + } + } + } + macro_calls: { + key: 3 + value: { + call_expr: { + target: { + id: 4 + list_expr: { + elements: { + id: 5 + ident_expr: { + name: "x" + } + } + elements: { + id: 6 + ident_expr: { + name: "y" + } + } + } + } + function: "filter" + args: { + id: 17 + ident_expr: { + name: "i" + } + } + args: { + id: 10 + call_expr: { + function: "_>_" + args: { + id: 11 + call_expr: { + target: { + id: 12 + ident_expr: { + name: "i" + } + } + function: "size" + } + } + args: { + id: 13 + const_expr: { + int64_value: 0 + } + } + } + } + } + } + } + ` + var wantSourceInfoPB exprpb.SourceInfo + if err := prototext.Unmarshal([]byte(wantTextPB), &wantSourceInfoPB); err != nil { + t.Fatalf("prototext.Unmarshal() failed: %v", err) + } + if !proto.Equal(&wantSourceInfoPB, sourceInfoPB) { + t.Errorf("got source info: %s, wanted %s", prototext.Format(sourceInfoPB), wantTextPB) + } expected := `has([x, y].filter(i, i.size() > 0)[0].z)` if expected != optString { t.Errorf("inlined got %q, wanted %q", optString, expected) @@ -107,30 +223,16 @@ type testOptimizer struct { func (opt *testOptimizer) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *ast.AST { opt.t.Helper() - copy, info := ctx.CopyAST(opt.inlineExpr) - infoMacroKeys := getMacroKeys(info.MacroCalls()) - for id, call := range info.MacroCalls() { - a.SourceInfo().SetMacroCall(id, call) - } + copy := ctx.CopyASTAndMetadata(opt.inlineExpr) origID := a.Expr().ID() - exprID := origID + 100 - presenceTest, hasMacro := ctx.NewHasMacro(exprID, copy) - macroKeys := getMacroKeys(a.SourceInfo().MacroCalls()) + presenceTest, hasMacro := ctx.NewHasMacro(origID, copy) + macroKeys := getMacroKeys(ctx.MacroCalls()) if len(macroKeys) != 2 { opt.t.Errorf("Got %v macro calls, wanted 2", macroKeys) } ctx.UpdateExpr(a.Expr(), presenceTest) - macroKeys = getMacroKeys(a.SourceInfo().MacroCalls()) - if _, found := a.SourceInfo().GetMacroCall(origID); found { - opt.t.Errorf("Got %v macro calls, wanted 1", macroKeys) - } - - a.SourceInfo().SetMacroCall(exprID, hasMacro) - macroKeys = getMacroKeys(a.SourceInfo().MacroCalls()) - if !reflect.DeepEqual(macroKeys, append(infoMacroKeys, int(exprID))) { - opt.t.Errorf("Got %v macro calls, wanted 2", macroKeys) - } - return a + ctx.SetMacroCall(origID, hasMacro) + return ctx.NewAST(a.Expr()) } func getMacroKeys(macroCalls map[int64]ast.Expr) []int {