Skip to content

Commit

Permalink
Expose an option to track macro call replacements (#470)
Browse files Browse the repository at this point in the history
  • Loading branch information
TristonianJones authored Nov 9, 2021
1 parent 3d3d767 commit dfef54b
Show file tree
Hide file tree
Showing 7 changed files with 211 additions and 37 deletions.
50 changes: 48 additions & 2 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1046,20 +1046,66 @@ func TestEnvExtensionIsolation(t *testing.T) {
}
}

func TestParseError(t *testing.T) {
e, err := NewEnv()
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
_, iss := e.Parse("invalid & logical_and")
if iss.Err() == nil {
t.Fatal("e.Parse('invalid & logical_and') did not error")
}
}

func TestParseWithMacroTracking(t *testing.T) {
e, err := NewEnv(EnableMacroCallTracking())
if err != nil {
t.Fatalf("NewEnv(EnableMacroCallTracking()) failed: %v", err)
}
ast, iss := e.Parse("has(a.b) && a.b.exists(c, c < 10)")
if iss.Err() != nil {
t.Fatalf("e.Parse() failed: %v", iss.Err())
}
pe, err := AstToParsedExpr(ast)
if err != nil {
t.Fatalf("AstToParsedExpr(%v) failed: %v", ast, err)
}
macroCalls := pe.GetSourceInfo().GetMacroCalls()
if len(macroCalls) != 2 {
t.Errorf("got %d macro calls, wanted 2", len(macroCalls))
}
callsFound := map[string]bool{"has": false, "exists": false}
for _, expr := range macroCalls {
f := expr.GetCallExpr().GetFunction()
_, found := callsFound[f]
if !found {
t.Errorf("Unexpected macro call: %v", expr)
}
callsFound[f] = true
}
callsWanted := map[string]bool{"has": true, "exists": true}
if !reflect.DeepEqual(callsFound, callsWanted) {
t.Errorf("Tracked calls %v, but wanted %v", callsFound, callsWanted)
}
}

func TestParseAndCheckConcurrently(t *testing.T) {
e, _ := NewEnv(
e, err := NewEnv(
Container("google.api.expr.v1alpha1"),
Types(&exprpb.Expr{}),
Declarations(
decls.NewVar("expr",
decls.NewObjectType("google.api.expr.v1alpha1.Expr")),
),
)
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}

parseAndCheck := func(expr string) {
_, iss := e.Compile(expr)
if iss.Err() != nil {
t.Fatalf("failed to parse '%s': %v", expr, iss.Err())
t.Fatalf("e.Compile('%s') failed: %v", expr, iss.Err())
}
}

Expand Down
38 changes: 26 additions & 12 deletions cel/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,17 @@ type Env struct {
adapter ref.TypeAdapter
provider ref.TypeProvider
features map[int]bool
// program options tied to the environment.
progOpts []ProgramOption

// Internal parser representation
prsr *parser.Parser

// Internal checker representation
chk *checker.Env
chkErr error
once sync.Once
chk *checker.Env
chkErr error
chkOnce sync.Once

// Program options tied to the environment
progOpts []ProgramOption
}

// NewEnv creates a program environment configured with the standard library of CEL functions and
Expand Down Expand Up @@ -147,10 +151,10 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) {
pe, _ := AstToParsedExpr(ast)

// Construct the internal checker env, erroring if there is an issue adding the declarations.
e.once.Do(func() {
e.chkOnce.Do(func() {
ce := checker.NewEnv(e.Container, e.provider)
ce.EnableDynamicAggregateLiterals(true)
if e.HasFeature(FeatureDisableDynamicAggregateLiterals) {
if e.HasFeature(featureDisableDynamicAggregateLiterals) {
ce.EnableDynamicAggregateLiterals(false)
}
err := ce.Add(e.declarations...)
Expand Down Expand Up @@ -207,11 +211,10 @@ func (e *Env) CompileSource(src common.Source) (*Ast, *Issues) {
return nil, iss
}
checked, iss2 := e.Check(ast)
iss = iss.Append(iss2)
if iss.Err() != nil {
return nil, iss
if iss2.Err() != nil {
return nil, iss2
}
return checked, iss
return checked, iss2
}

// Extend the current environment with additional options to produce a new Env.
Expand Down Expand Up @@ -301,7 +304,7 @@ func (e *Env) Parse(txt string) (*Ast, *Issues) {
// It is possible to have both non-nil Ast and Issues values returned from this call; however,
// the mere presence of an Ast does not imply that it is valid for use.
func (e *Env) ParseSource(src common.Source) (*Ast, *Issues) {
res, errs := parser.ParseWithMacros(src, e.macros)
res, errs := e.prsr.Parse(src)
if len(errs.GetErrors()) > 0 {
return nil, &Issues{errs: errs}
}
Expand Down Expand Up @@ -413,6 +416,14 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) {
return nil, err
}
}
prsrOpts := []parser.Option{parser.Macros(e.macros...)}
if e.HasFeature(featureEnableMacroCallTracking) {
prsrOpts = append(prsrOpts, parser.PopulateMacroCalls(true))
}
e.prsr, err = parser.NewParser(prsrOpts...)
if err != nil {
return nil, err
}
return e, nil
}

Expand Down Expand Up @@ -454,6 +465,9 @@ func (i *Issues) Append(other *Issues) *Issues {
if i == nil {
return other
}
if other == nil {
return i
}
return NewIssues(i.errs.Append(other.errs.GetErrors()))
}

Expand Down
81 changes: 81 additions & 0 deletions cel/env_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cel

import (
"reflect"
"testing"

"github.com/google/cel-go/common"
)

func TestIssuesNil(t *testing.T) {
var iss *Issues
iss = iss.Append(iss)
if iss.Err() != nil {
t.Errorf("iss.Err() got %v, wanted nil given nil issue set", iss.Err())
}
if len(iss.Errors()) != 0 {
t.Errorf("iss.Errors() got %v, wanted empty value", iss.Errors())
}
if iss.String() != "" {
t.Errorf("iss.String() returned %v, wanted empty value", iss.String())
}
}

func TestIssuesEmpty(t *testing.T) {
iss := NewIssues(common.NewErrors(nil))
if iss.Err() != nil {
t.Errorf("iss.Err() got %v, wanted nil given nil issue set", iss.Err())
}
if len(iss.Errors()) != 0 {
t.Errorf("iss.Errors() got %v, wanted empty value", iss.Errors())
}
if iss.String() != "" {
t.Errorf("iss.String() returned %v, wanted empty value", iss.String())
}
var iss2 *Issues
iss3 := iss.Append(iss2)
iss4 := iss3.Append(nil)
if !reflect.DeepEqual(iss4, iss) {
t.Error("Append() with a nil value resulted in the creation of a new issue set")
}
}

func TestIssues(t *testing.T) {
e, err := NewEnv()
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
_, iss := e.Compile("-")
_, iss2 := e.Compile("b")
iss = iss.Append(iss2)
if len(iss.Errors()) != 3 {
t.Errorf("iss.Errors() got %v, wanted 3 errors", iss.Errors())
}

wantIss := `ERROR: <input>:1:1: undeclared reference to 'b' (in container '')
| -
| ^
ERROR: <input>:1:2: Syntax error: no viable alternative at input '-'
| -
| .^
ERROR: <input>:1:2: Syntax error: mismatched input '<EOF>' expecting {'[', '{', '(', '.', '-', '!', 'true', 'false', 'null', NUM_FLOAT, NUM_INT, NUM_UINT, STRING, BYTES, IDENTIFIER}
| -
| .^`
if iss.String() != wantIss {
t.Errorf("iss.String() returned %v, wanted %v", iss.String(), wantIss)
}
}
34 changes: 33 additions & 1 deletion cel/io_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (

"github.com/google/cel-go/checker/decls"
"google.golang.org/protobuf/proto"

exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)

func TestAstToProto(t *testing.T) {
Expand Down Expand Up @@ -90,7 +92,7 @@ func TestAstToString(t *testing.T) {
}
}

func TestCheckedExprToAst_ConstantExpr(t *testing.T) {
func TestCheckedExprToAstConstantExpr(t *testing.T) {
stdEnv, err := NewEnv()
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
Expand All @@ -109,3 +111,33 @@ func TestCheckedExprToAst_ConstantExpr(t *testing.T) {
t.Fatalf("got ast %v, wanted %v", ast2, ast)
}
}

func TestCheckedExprToAstMissingInfo(t *testing.T) {
stdEnv, err := NewEnv()
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
in := "10"
ast, iss := stdEnv.Parse(in)
if iss.Err() != nil {
t.Fatalf("stdEnv.Compile(%q) failed: %v", in, iss.Err())
}
if ast.ResultType() != decls.Dyn {
t.Fatalf("ast.ResultType() got %v, wanted 'dyn'", ast.ResultType())
}
expr, err := AstToParsedExpr(ast)
if err != nil {
t.Fatalf("AstToParsedExpr(ast) failed: %v", err)
}
checkedExpr := &exprpb.CheckedExpr{
TypeMap: map[int64]*exprpb.Type{expr.GetExpr().GetId(): decls.Int},
Expr: expr.GetExpr(),
}
ast2 := CheckedExprToAst(checkedExpr)
if !ast2.IsChecked() {
t.Fatal("CheckedExprToAst() did not produce a 'checked' ast")
}
if ast2.ResultType() != decls.Int {
t.Fatalf("ast2.ResultType() got %v, wanted 'int'", ast.ResultType())
}
}
33 changes: 22 additions & 11 deletions cel/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ const (
// provided as variables to the expression, as well as via conversion
// of well-known dynamic types, or with unchecked expressions.
// Affects checking. Provides a subset of standard behavior.
FeatureDisableDynamicAggregateLiterals
featureDisableDynamicAggregateLiterals

// Enable the tracking of function call expressions replaced by macros.
featureEnableMacroCallTracking
)

// EnvOption is a functional interface for configuring the environment.
Expand Down Expand Up @@ -113,7 +116,7 @@ func Features(flags ...int) EnvOption {
// expression, as well as via conversion of well-known dynamic types, or with unchecked
// expressions.
func HomogeneousAggregateLiterals() EnvOption {
return Features(FeatureDisableDynamicAggregateLiterals)
return Features(featureDisableDynamicAggregateLiterals)
}

// Macros option extends the macro set configured in the environment.
Expand Down Expand Up @@ -334,8 +337,7 @@ func Functions(funcs ...*functions.Overload) ProgramOption {
// The vars value may either be an `interpreter.Activation` instance or a `map[string]interface{}`.
func Globals(vars interface{}) ProgramOption {
return func(p *prog) (*prog, error) {
defaultVars, err :=
interpreter.NewActivation(vars)
defaultVars, err := interpreter.NewActivation(vars)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -411,19 +413,19 @@ func fieldToDecl(field protoreflect.FieldDescriptor) (*exprpb.Decl, error) {
return nil, err
}
return decls.NewVar(name, decls.NewMapType(keyType, valueType)), nil
} else if field.IsList() {
}
if field.IsList() {
elemType, err := fieldToCELType(field)
if err != nil {
return nil, err
}
return decls.NewVar(name, decls.NewListType(elemType)), nil
} else {
celType, err := fieldToCELType(field)
if err != nil {
return nil, err
}
return decls.NewVar(name, celType), nil
}
celType, err := fieldToCELType(field)
if err != nil {
return nil, err
}
return decls.NewVar(name, celType), nil
}

// DeclareContextProto returns an option to extend CEL environment with declarations from the given context proto.
Expand All @@ -449,3 +451,12 @@ func DeclareContextProto(descriptor protoreflect.MessageDescriptor) EnvOption {
return Types(dynamicpb.NewMessage(descriptor))(e)
}
}

// EnableMacroCallTracking ensures that call expressions which are replaced by macros
// are tracked in the `SourceInfo` of parsed and checked expressions.
func EnableMacroCallTracking() EnvOption {
return func(e *Env) (*Env, error) {
e.features[featureEnableMacroCallTracking] = true
return e, nil
}
}
3 changes: 0 additions & 3 deletions cel/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,6 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
// Configure the program via the ProgramOption values.
var err error
for _, opt := range opts {
if opt == nil {
return nil, fmt.Errorf("program options should be non-nil")
}
p, err = opt(p)
if err != nil {
return nil, err
Expand Down
9 changes: 1 addition & 8 deletions parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,7 @@ var reservedIds = map[string]struct{}{
//
// Deprecated: Use NewParser().Parse() instead.
func Parse(source common.Source) (*exprpb.ParsedExpr, *common.Errors) {
return ParseWithMacros(source, AllMacros)
}

// ParseWithMacros converts a source input and macros set to a parsed expression.
//
// Deprecated: Use NewParser().Parse() instead.
func ParseWithMacros(source common.Source, macros []Macro) (*exprpb.ParsedExpr, *common.Errors) {
return mustNewParser(Macros(macros...)).Parse(source)
return mustNewParser(Macros(AllMacros...)).Parse(source)
}

type recursionError struct {
Expand Down

0 comments on commit dfef54b

Please sign in to comment.