Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stabilize macro id generation #962

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions cel/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package cel

import (
"sort"

"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/types"
Expand Down Expand Up @@ -98,14 +100,21 @@ func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) {
// that the ids within the expression correspond to the ids within macros.
func normalizeIDs(idGen ast.IDGenerator, optimized ast.Expr, info *ast.SourceInfo) {
optimized.RenumberIDs(idGen)

if len(info.MacroCalls()) == 0 {
return
}

// Sort the macro ids to make sure that the renumbering of macro-specific variables
// is stable across normalization calls.
sortedMacroIDs := []int64{}
for id := range info.MacroCalls() {
sortedMacroIDs = append(sortedMacroIDs, id)
}
sort.Slice(sortedMacroIDs, func(i, j int) bool { return sortedMacroIDs[i] < sortedMacroIDs[j] })

// First, update the macro call ids themselves.
callIDMap := map[int64]int64{}
for id := range info.MacroCalls() {
for _, id := range sortedMacroIDs {
callIDMap[id] = idGen(id)
}
// Then update the macro call definitions which refer to these ids, but
Expand All @@ -116,7 +125,8 @@ func normalizeIDs(idGen ast.IDGenerator, optimized ast.Expr, info *ast.SourceInf
call ast.Expr
}
macroUpdates := []macroUpdate{}
for oldID, newID := range callIDMap {
for _, oldID := range sortedMacroIDs {
newID := callIDMap[oldID]
call, found := info.GetMacroCall(oldID)
if !found {
continue
Expand All @@ -134,6 +144,7 @@ func cleanupMacroRefs(expr ast.Expr, info *ast.SourceInfo) {
if len(info.MacroCalls()) == 0 {
return
}

// Sanitize the macro call references once the optimized expression has been computed
// and the ids normalized between the expression and the macros.
exprRefMap := make(map[int64]struct{})
Expand Down Expand Up @@ -253,6 +264,11 @@ func (opt *optimizerExprFactory) SetMacroCall(id int64, expr ast.Expr) {
opt.sourceInfo.SetMacroCall(id, expr)
}

// MacroCalls returns the map of macro calls currently in the context.
func (opt *optimizerExprFactory) MacroCalls() map[int64]ast.Expr {
return opt.sourceInfo.MacroCalls()
}

// NewBindMacro creates an AST expression representing the expanded bind() macro, and a macro expression
// representing the unexpanded call signature to be inserted into the source info macro call metadata.
func (opt *optimizerExprFactory) NewBindMacro(macroID int64, varName string, varInit, remaining ast.Expr) (astExpr, macroExpr ast.Expr) {
Expand Down
142 changes: 122 additions & 20 deletions cel/optimizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,19 @@
package cel_test

import (
"reflect"
"sort"
"testing"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/ext"

"google.golang.org/protobuf/encoding/prototext"
"google.golang.org/protobuf/proto"

proto3pb "github.com/google/cel-go/test/proto3pb"

exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)

func TestStaticOptimizerUpdateExpr(t *testing.T) {
Expand All @@ -49,6 +53,118 @@ func TestStaticOptimizerUpdateExpr(t *testing.T) {
if err != nil {
t.Fatalf("cel.AstToString() failed: %v", err)
}
sourceInfo := optAST.NativeRep().SourceInfo()
sourceInfoPB, err := ast.SourceInfoToProto(sourceInfo)
if err != nil {
t.Fatalf("cel.AstToCheckedExpr() failed: %v", err)
}
wantTextPB := `
location: "<input>"
line_offsets: 9
positions: {
key: 2
value: 4
}
positions: {
key: 3
value: 5
}
positions: {
key: 4
value: 3
}
macro_calls: {
key: 1
value: {
call_expr: {
function: "has"
args: {
id: 21
select_expr: {
operand: {
id: 2
call_expr: {
function: "_[_]"
args: {
id: 3
}
args: {
id: 20
const_expr: {
int64_value: 0
}
}
}
}
field: "z"
}
}
}
}
}
macro_calls: {
key: 3
value: {
call_expr: {
target: {
id: 4
list_expr: {
elements: {
id: 5
ident_expr: {
name: "x"
}
}
elements: {
id: 6
ident_expr: {
name: "y"
}
}
}
}
function: "filter"
args: {
id: 17
ident_expr: {
name: "i"
}
}
args: {
id: 10
call_expr: {
function: "_>_"
args: {
id: 11
call_expr: {
target: {
id: 12
ident_expr: {
name: "i"
}
}
function: "size"
}
}
args: {
id: 13
const_expr: {
int64_value: 0
}
}
}
}
}
}
}
`
var wantSourceInfoPB exprpb.SourceInfo
if err := prototext.Unmarshal([]byte(wantTextPB), &wantSourceInfoPB); err != nil {
t.Fatalf("prototext.Unmarshal() failed: %v", err)
}
if !proto.Equal(&wantSourceInfoPB, sourceInfoPB) {
t.Errorf("got source info: %s, wanted %s", prototext.Format(sourceInfoPB), wantTextPB)
}
expected := `has([x, y].filter(i, i.size() > 0)[0].z)`
if expected != optString {
t.Errorf("inlined got %q, wanted %q", optString, expected)
Expand Down Expand Up @@ -107,30 +223,16 @@ type testOptimizer struct {

func (opt *testOptimizer) Optimize(ctx *cel.OptimizerContext, a *ast.AST) *ast.AST {
opt.t.Helper()
copy, info := ctx.CopyAST(opt.inlineExpr)
infoMacroKeys := getMacroKeys(info.MacroCalls())
for id, call := range info.MacroCalls() {
a.SourceInfo().SetMacroCall(id, call)
}
copy := ctx.CopyASTAndMetadata(opt.inlineExpr)
origID := a.Expr().ID()
exprID := origID + 100
presenceTest, hasMacro := ctx.NewHasMacro(exprID, copy)
macroKeys := getMacroKeys(a.SourceInfo().MacroCalls())
presenceTest, hasMacro := ctx.NewHasMacro(origID, copy)
macroKeys := getMacroKeys(ctx.MacroCalls())
if len(macroKeys) != 2 {
opt.t.Errorf("Got %v macro calls, wanted 2", macroKeys)
}
ctx.UpdateExpr(a.Expr(), presenceTest)
macroKeys = getMacroKeys(a.SourceInfo().MacroCalls())
if _, found := a.SourceInfo().GetMacroCall(origID); found {
opt.t.Errorf("Got %v macro calls, wanted 1", macroKeys)
}

a.SourceInfo().SetMacroCall(exprID, hasMacro)
macroKeys = getMacroKeys(a.SourceInfo().MacroCalls())
if !reflect.DeepEqual(macroKeys, append(infoMacroKeys, int(exprID))) {
opt.t.Errorf("Got %v macro calls, wanted 2", macroKeys)
}
return a
ctx.SetMacroCall(origID, hasMacro)
return ctx.NewAST(a.Expr())
}

func getMacroKeys(macroCalls map[int64]ast.Expr) []int {
Expand Down