Skip to content

Commit

Permalink
Stabilize macro id generation (#962)
Browse files Browse the repository at this point in the history
  • Loading branch information
TristonianJones authored Jun 12, 2024
1 parent e765664 commit b27008a
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 23 deletions.
22 changes: 19 additions & 3 deletions cel/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package cel

import (
"sort"

"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
Expand Down Expand Up @@ -98,14 +100,21 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) {
// that the ids within the expression correspond to the ids within macros.
func normalizeIDs(idGen ast.IDGenerator, optimized ast.Expr, info *ast.SourceInfo) {
optimized.RenumberIDs(idGen)

if len(info.MacroCalls()) == 0 {
return
}

// Sort the macro ids to make sure that the renumbering of macro-specific variables
// is stable across normalization calls.
sortedMacroIDs := []int64{}
for id := range info.MacroCalls() {
sortedMacroIDs = append(sortedMacroIDs, id)
}
sort.Slice(sortedMacroIDs, func(i, j int) bool { return sortedMacroIDs[i] < sortedMacroIDs[j] })

// First, update the macro call ids themselves.
callIDMap := map[int64]int64{}
for id := range info.MacroCalls() {
for _, id := range sortedMacroIDs {
callIDMap[id] = idGen(id)
}
// Then update the macro call definitions which refer to these ids, but
Expand All @@ -116,7 +125,8 @@ func normalizeIDs(idGen ast.IDGenerator, optimized ast.Expr, info *ast.SourceInf
call ast.Expr
}
macroUpdates := []macroUpdate{}
for oldID, newID := range callIDMap {
for _, oldID := range sortedMacroIDs {
newID := callIDMap[oldID]
call, found := info.GetMacroCall(oldID)
if !found {
continue
Expand All @@ -134,6 +144,7 @@ func cleanupMacroRefs(expr ast.Expr, info *ast.SourceInfo) {
if len(info.MacroCalls()) == 0 {
return
}

// Sanitize the macro call references once the optimized expression has been computed
// and the ids normalized between the expression and the macros.
exprRefMap := make(map[int64]struct{})
Expand Down Expand Up @@ -253,6 +264,11 @@ func (opt *optimizerExprFactory) SetMacroCall(id int64, expr ast.Expr) {
opt.sourceInfo.SetMacroCall(id, expr)
}

// MacroCalls returns the map of macro calls currently in the context.
func (opt *optimizerExprFactory) MacroCalls() map[int64]ast.Expr {
return opt.sourceInfo.MacroCalls()
}

// NewBindMacro creates an AST expression representing the expanded bind() macro, and a macro expression
// representing the unexpanded call signature to be inserted into the source info macro call metadata.
func (opt *optimizerExprFactory) NewBindMacro(macroID int64, varName string, varInit, remaining ast.Expr) (astExpr, macroExpr ast.Expr) {
Expand Down
142 changes: 122 additions & 20 deletions cel/optimizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,19 @@
package cel_test

import (
"reflect"
"sort"
"testing"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/ext"

"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/proto"

proto3pb "github.com/google/cel-go/test/proto3pb"

exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)

func TestStaticOptimizerUpdateExpr(t *testing.T) {
Expand All @@ -49,6 +53,118 @@ func TestStaticOptimizerUpdateExpr(t *testing.T) {
if err != nil {
t.Fatalf("cel.AstToString() failed: %v", err)
}
sourceInfo := optAST.NativeRep().SourceInfo()
sourceInfoPB, err := ast.SourceInfoToProto(sourceInfo)
if err != nil {
t.Fatalf("cel.AstToCheckedExpr() failed: %v", err)
}
wantTextPB := `
location: "<input>"
line_offsets: 9
positions: {
key: 2
value: 4
}
positions: {
key: 3
value: 5
}
positions: {
key: 4
value: 3
}
macro_calls: {
key: 1
value: {
call_expr: {
function: "has"
args: {
id: 21
select_expr: {
operand: {
id: 2
call_expr: {
function: "_[_]"
args: {
id: 3
}
args: {
id: 20
const_expr: {
int64_value: 0
}
}
}
}
field: "z"
}
}
}
}
}
macro_calls: {
key: 3
value: {
call_expr: {
target: {
id: 4
list_expr: {
elements: {
id: 5
ident_expr: {
name: "x"
}
}
elements: {
id: 6
ident_expr: {
name: "y"
}
}
}
}
function: "filter"
args: {
id: 17
ident_expr: {
name: "i"
}
}
args: {
id: 10
call_expr: {
function: "_>_"
args: {
id: 11
call_expr: {
target: {
id: 12
ident_expr: {
name: "i"
}
}
function: "size"
}
}
args: {
id: 13
const_expr: {
int64_value: 0
}
}
}
}
}
}
}
`
var wantSourceInfoPB exprpb.SourceInfo
if err := prototext.Unmarshal([]byte(wantTextPB), &wantSourceInfoPB); err != nil {
t.Fatalf("prototext.Unmarshal() failed: %v", err)
}
if !proto.Equal(&wantSourceInfoPB, sourceInfoPB) {
t.Errorf("got source info: %s, wanted %s", prototext.Format(sourceInfoPB), wantTextPB)
}
expected := `has([x, y].filter(i, i.size() > 0)[0].z)`
if expected != optString {
t.Errorf("inlined got %q, wanted %q", optString, expected)
Expand Down Expand Up @@ -107,30 +223,16 @@ type testOptimizer struct {

func (opt *testOptimizer) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *ast.AST {
opt.t.Helper()
copy, info := ctx.CopyAST(opt.inlineExpr)
infoMacroKeys := getMacroKeys(info.MacroCalls())
for id, call := range info.MacroCalls() {
a.SourceInfo().SetMacroCall(id, call)
}
copy := ctx.CopyASTAndMetadata(opt.inlineExpr)
origID := a.Expr().ID()
exprID := origID + 100
presenceTest, hasMacro := ctx.NewHasMacro(exprID, copy)
macroKeys := getMacroKeys(a.SourceInfo().MacroCalls())
presenceTest, hasMacro := ctx.NewHasMacro(origID, copy)
macroKeys := getMacroKeys(ctx.MacroCalls())
if len(macroKeys) != 2 {
opt.t.Errorf("Got %v macro calls, wanted 2", macroKeys)
}
ctx.UpdateExpr(a.Expr(), presenceTest)
macroKeys = getMacroKeys(a.SourceInfo().MacroCalls())
if _, found := a.SourceInfo().GetMacroCall(origID); found {
opt.t.Errorf("Got %v macro calls, wanted 1", macroKeys)
}

a.SourceInfo().SetMacroCall(exprID, hasMacro)
macroKeys = getMacroKeys(a.SourceInfo().MacroCalls())
if !reflect.DeepEqual(macroKeys, append(infoMacroKeys, int(exprID))) {
opt.t.Errorf("Got %v macro calls, wanted 2", macroKeys)
}
return a
ctx.SetMacroCall(origID, hasMacro)
return ctx.NewAST(a.Expr())
}

func getMacroKeys(macroCalls map[int64]ast.Expr) []int {
Expand Down

0 comments on commit b27008a

Please sign in to comment.