Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sets cost estimation and tracking options #850

Merged
merged 8 commits into from
Oct 27, 2023
10 changes: 5 additions & 5 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1475,7 +1475,7 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) {
name string
expr string
decls []EnvOption
hints map[string]int64
hints map[string]uint64
want checker.CostEstimate
in any
}{
Expand All @@ -1499,7 +1499,7 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) {
Variable("str1", StringType),
Variable("str2", StringType),
},
hints: map[string]int64{"str1": 10, "str2": 10},
hints: map[string]uint64{"str1": 10, "str2": 10},
want: checker.CostEstimate{Min: 2, Max: 6},
in: map[string]any{"str1": "val1111111", "str2": "val2222222"},
},
Expand All @@ -1510,7 +1510,7 @@ func TestEstimateCostAndRuntimeCost(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if tc.hints == nil {
tc.hints = map[string]int64{}
tc.hints = map[string]uint64{}
}
env := testEnv(t, tc.decls...)
ast, iss := env.Compile(tc.expr)
Expand Down Expand Up @@ -2768,12 +2768,12 @@ func BenchmarkDynamicDispatch(b *testing.B) {

// TODO: ideally testCostEstimator and testRuntimeCostEstimator would be shared in a test fixtures package
type testCostEstimator struct {
hints map[string]int64
hints map[string]uint64
}

func (tc testCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate {
if l, ok := tc.hints[strings.Join(element.Path(), ".")]; ok {
return &checker.SizeEstimate{Min: 0, Max: uint64(l)}
return &checker.SizeEstimate{Min: 0, Max: l}
}
return nil
}
Expand Down
10 changes: 9 additions & 1 deletion cel/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ type Env struct {
appliedFeatures map[int]bool
libraries map[string]bool
validators []ASTValidator
costOptions []checker.CostOption

// Internal parser representation
prsr *parser.Parser
Expand Down Expand Up @@ -191,6 +192,7 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) {
libraries: map[string]bool{},
validators: []ASTValidator{},
progOpts: []ProgramOption{},
costOptions: []checker.CostOption{},
}).configure(opts)
}

Expand Down Expand Up @@ -365,6 +367,8 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
}
validatorsCopy := make([]ASTValidator, len(e.validators))
copy(validatorsCopy, e.validators)
costOptsCopy := make([]checker.CostOption, len(e.costOptions))
copy(costOptsCopy, e.costOptions)

ext := &Env{
Container: e.Container,
Expand All @@ -380,6 +384,7 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
provider: provider,
chkOpts: chkOptsCopy,
prsrOpts: prsrOptsCopy,
costOptions: costOptsCopy,
}
return ext.configure(opts)
}
Expand Down Expand Up @@ -556,7 +561,10 @@ func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) {
// EstimateCost estimates the cost of a type checked CEL expression using the length estimates of input data and
// extension functions provided by estimator.
func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...checker.CostOption) (checker.CostEstimate, error) {
return checker.Cost(ast.impl, estimator, opts...)
extendedOpts := make([]checker.CostOption, 0, len(e.costOptions))
extendedOpts = append(extendedOpts, opts...)
extendedOpts = append(extendedOpts, e.costOptions...)
return checker.Cost(ast.impl, estimator, extendedOpts...)
}

// configure applies a series of EnvOptions to the current environment.
Expand Down
19 changes: 19 additions & 0 deletions cel/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"google.golang.org/protobuf/reflect/protoregistry"
"google.golang.org/protobuf/types/dynamicpb"

"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/containers"
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/types"
Expand Down Expand Up @@ -471,6 +472,24 @@ func InterruptCheckFrequency(checkFrequency uint) ProgramOption {
}
}

// CostEstimatorOptions configure type-check time options for estimating expression cost.
func CostEstimatorOptions(costOpts ...checker.CostOption) EnvOption {
return func(e *Env) (*Env, error) {
e.costOptions = append(e.costOptions, costOpts...)
return e, nil
}
}

// CostTrackerOptions configures a set of options for cost-tracking.
//
// Note, CostTrackerOptions is a no-op unless CostTracking is also enabled.
func CostTrackerOptions(costOpts ...interpreter.CostTrackerOption) ProgramOption {
return func(p *prog) (*prog, error) {
p.costOptions = append(p.costOptions, costOpts...)
return p, nil
}
}

// CostTracking enables cost tracking and registers a ActualCostEstimator that can optionally provide a runtime cost estimate for any function calls.
func CostTracking(costEstimator interpreter.ActualCostEstimator) ProgramOption {
return func(p *prog) (*prog, error) {
Expand Down
30 changes: 24 additions & 6 deletions cel/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func (ed *EvalDetails) State() interpreter.EvalState {
// ActualCost returns the tracked cost through the course of execution when `CostTracking` is enabled.
// Otherwise, returns nil if the cost was not enabled.
func (ed *EvalDetails) ActualCost() *uint64 {
if ed.costTracker == nil {
if ed == nil || ed.costTracker == nil {
return nil
}
cost := ed.costTracker.ActualCost()
Expand All @@ -129,10 +129,14 @@ type prog struct {
// Interpretable configured from an Ast and aggregate decorator set based on program options.
interpretable interpreter.Interpretable
callCostEstimator interpreter.ActualCostEstimator
costOptions []interpreter.CostTrackerOption
costLimit *uint64
}

func (p *prog) clone() *prog {
costOptsCopy := make([]interpreter.CostTrackerOption, len(p.costOptions))
copy(costOptsCopy, p.costOptions)

return &prog{
Env: p.Env,
evalOpts: p.evalOpts,
Expand All @@ -154,9 +158,10 @@ func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) {
// Ensure the default attribute factory is set after the adapter and provider are
// configured.
p := &prog{
Env: e,
decorators: []interpreter.InterpretableDecorator{},
dispatcher: disp,
Env: e,
decorators: []interpreter.InterpretableDecorator{},
dispatcher: disp,
costOptions: []interpreter.CostTrackerOption{},
}

// Configure the program via the ProgramOption values.
Expand Down Expand Up @@ -213,6 +218,12 @@ func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) {
factory := func(state interpreter.EvalState, costTracker *interpreter.CostTracker) (Program, error) {
costTracker.Estimator = p.callCostEstimator
costTracker.Limit = p.costLimit
for _, costOpt := range p.costOptions {
err := costOpt(costTracker)
if err != nil {
return nil, err
}
}
// Limit capacity to guarantee a reallocation when calling 'append(decs, ...)' below. This
// prevents the underlying memory from being shared between factory function calls causing
// undesired mutations.
Expand Down Expand Up @@ -325,7 +336,11 @@ type progGen struct {
// the test is successful.
func newProgGen(factory progFactory) (Program, error) {
// Test the factory to make sure that configuration errors are spotted at config
_, err := factory(interpreter.NewEvalState(), &interpreter.CostTracker{})
tracker, err := interpreter.NewCostTracker(nil)
if err != nil {
return nil, err
}
_, err = factory(interpreter.NewEvalState(), tracker)
if err != nil {
return nil, err
}
Expand All @@ -338,7 +353,10 @@ func (gen *progGen) Eval(input any) (ref.Val, *EvalDetails, error) {
// new EvalState instance for each call to ensure that unique evaluations yield unique stateful
// results.
state := interpreter.NewEvalState()
costTracker := &interpreter.CostTracker{}
costTracker, err := interpreter.NewCostTracker(nil)
if err != nil {
return nil, nil, err
}
det := &EvalDetails{state: state, costTracker: costTracker}

// Generate a new instance of the interpretable using the factory configured during the call to
Expand Down
47 changes: 37 additions & 10 deletions checker/cost.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package checker

import (
"fmt"
"math"

"github.com/google/cel-go/common"
Expand Down Expand Up @@ -256,9 +257,10 @@ type coster struct {
// iterRanges tracks the iterRange of each iterVar.
iterRanges iterRangeScopes
// computedSizes tracks the computed sizes of call results.
computedSizes map[int64]SizeEstimate
checkedAST *ast.AST
estimator CostEstimator
computedSizes map[int64]SizeEstimate
checkedAST *ast.AST
estimator CostEstimator
functionEstimators map[string]FunctionEstimator
// presenceTestCost will either be a zero or one based on whether has() macros count against cost computations.
presenceTestCost CostEstimate
}
Expand Down Expand Up @@ -287,6 +289,7 @@ func (vs iterRangeScopes) peek(varName string) (int64, bool) {
type CostOption func(*coster) error

// PresenceTestHasCost determines whether presence testing has a cost of one or zero.
//
// Defaults to presence test has a cost of one.
func PresenceTestHasCost(hasCost bool) CostOption {
return func(c *coster) error {
Expand All @@ -299,15 +302,31 @@ func PresenceTestHasCost(hasCost bool) CostOption {
}
}

// FunctionEstimator provides a CallEstimate given the target and arguments for a specific function, overload pair.
type FunctionEstimator func(estimator CostEstimator, target *AstNode, args []AstNode) *CallEstimate

// FunctionCostEstimate binds a FunctionCoster to a specific function, overload pair.
//
// When a FunctionCostEstimate is provided, it will override the cost calculation of the CostEstimator provided to
// the Cost() call.
func FunctionCostEstimate(function, overloadID string, functionCoster FunctionEstimator) CostOption {
return func(c *coster) error {
functionKey := fmt.Sprintf("%s|%s", function, overloadID)
c.functionEstimators[functionKey] = functionCoster
return nil
}
}

// Cost estimates the cost of the parsed and type checked CEL expression.
func Cost(checked *ast.AST, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) {
c := &coster{
checkedAST: checked,
estimator: estimator,
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
presenceTestCost: CostEstimate{Min: 1, Max: 1},
checkedAST: checked,
estimator: estimator,
functionEstimators: map[string]FunctionEstimator{},
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
presenceTestCost: CostEstimate{Min: 1, Max: 1},
}
for _, opt := range opts {
err := opt(c)
Expand Down Expand Up @@ -518,7 +537,15 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args
}
return sum
}

if len(c.functionEstimators) != 0 {
functionKey := fmt.Sprintf("%s|%s", function, overloadID)
if estimator, found := c.functionEstimators[functionKey]; found {
if est := estimator(c.estimator, target, args); est != nil {
callEst := *est
return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize}
}
}
}
if est := c.estimator.EstimateCallCost(function, overloadID, target, args); est != nil {
callEst := *est
return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize}
Expand Down
Loading