Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Estimate costs for list and string concatination and conditionals #487

Merged
merged 1 commit into from
Feb 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"fmt"
"io/ioutil"
"log"
"math"
"reflect"
"strings"
"sync"
Expand Down Expand Up @@ -1254,10 +1253,10 @@ func (tc testCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeE
return nil
}

func (tc testCostEstimator) EstimateCallCost(overloadId string, target *checker.AstNode, args []checker.AstNode) *checker.CostEstimate {
func (tc testCostEstimator) EstimateCallCost(overloadId string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
switch overloadId {
case overloads.TimestampToYear:
return &checker.CostEstimate{Min: 7, Max: 7}
return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 7, Max: 7}}
}
return nil
}
Expand Down Expand Up @@ -1518,9 +1517,8 @@ func TestEstimateCost(t *testing.T) {
decls.NewVar("input1", allList),
decls.NewVar("input2", allList),
},
hints: map[string]int64{"input1": 1, "input2": 1},
// TODO: we don't track cost in the specific case
wanted: checker.CostEstimate{Min: 6, Max: math.MaxUint64},
hints: map[string]int64{"input1": 1, "input2": 1},
wanted: checker.CostEstimate{Min: 6, Max: 10},
},
{
name: "comprehension over map",
Expand Down Expand Up @@ -1558,6 +1556,35 @@ func TestEstimateCost(t *testing.T) {
hints: map[string]int64{"input": 2, "input.@values": 2, "input.@keys": 5},
wanted: checker.CostEstimate{Min: 2, Max: 42},
},
{
name: "comprehension variable shadowing",
program: `input.all(k, input[k].all(k, true) && k.contains(k))`,
decls: []*exprpb.Decl{
decls.NewVar("input", nestedMap),
},
hints: map[string]int64{"input": 2, "input.@values": 2, "input.@keys": 5},
wanted: checker.CostEstimate{Min: 2, Max: 42},
},
{
name: "list concat",
program: `(list1 + list2).all(x, true)`,
decls: []*exprpb.Decl{
decls.NewVar("list1", decls.NewListType(decls.Int)),
decls.NewVar("list2", decls.NewListType(decls.Int)),
},
hints: map[string]int64{"list1": 10, "list2": 10},
wanted: checker.CostEstimate{Min: 3, Max: 85},
},
{
name: "str concat",
program: `"abcdefg".contains(str1 + str2)`,
decls: []*exprpb.Decl{
decls.NewVar("str1", decls.String),
decls.NewVar("str2", decls.String),
},
hints: map[string]int64{"str1": 10, "str2": 10},
wanted: checker.CostEstimate{Min: 2, Max: 6},
},
}

for _, tc := range cases {
Expand Down
143 changes: 109 additions & 34 deletions checker/cost.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package checker

import (
Expand All @@ -32,7 +33,14 @@ type CostEstimator interface {
EstimateSize(element AstNode) *SizeEstimate
// EstimateCallCost returns the estimated cost of an invocation, or nil if
// the estimator has no estimate to provide.
EstimateCallCost(overloadId string, target *AstNode, args []AstNode) *CostEstimate
EstimateCallCost(overloadId string, target *AstNode, args []AstNode) *CallEstimate
}

// CallEstimate includes a CostEstimate for the call, and an optional estimate of the result object size.
// The ResultSize should only be provided if the call results in a map, list, string or bytes.
type CallEstimate struct {
CostEstimate
ResultSize *SizeEstimate
}

// AstNode represents an AST node for the purpose of cost estimations.
Expand All @@ -43,15 +51,19 @@ type AstNode interface {
Path() []string
// Type returns the deduced type of the AstNode.
Type() *exprpb.Type
// LiteralSize returns the size of the AstNode if it is a literal defined inline in CEL, or
// nil if the AstNodes size is not statically known.
LiteralSize() *uint64
// Expr returns the expression of the AstNode.
Expr() *exprpb.Expr
// ComputedSize returns a size estimate of the AstNode derived from information available in the CEL expression.
// For constants and inline list and map declarations, the exact size is returned. For concatenated list, strings
// and bytes, the size is derived from the size estimates of the operands.
ComputedSize() *SizeEstimate
}

type astNode struct {
path []string
t *exprpb.Type
expr *exprpb.Expr
path []string
t *exprpb.Type
expr *exprpb.Expr
derivedSize *SizeEstimate
}

func (e astNode) Path() []string {
Expand All @@ -62,7 +74,14 @@ func (e astNode) Type() *exprpb.Type {
return e.t
}

func (e astNode) LiteralSize() *uint64 {
func (e astNode) Expr() *exprpb.Expr {
return e.expr
}

func (e astNode) ComputedSize() *SizeEstimate {
if e.derivedSize != nil {
return e.derivedSize
}
var v uint64
switch ek := e.expr.ExprKind.(type) {
case *exprpb.Expr_ConstExpr:
Expand All @@ -84,15 +103,24 @@ func (e astNode) LiteralSize() *uint64 {
return nil
}

return &v
return &SizeEstimate{Min: v, Max: v}
}

// SizeEstimate represents an estimated size of a variable length string, bytes, map or list.
type SizeEstimate struct {
Min, Max uint64
}

// Multiply multiplies by another SizeEstimate and returns the sum.
// Add adds to another SizeEstimate and returns the sum.
// If add would result in an uint64 overflow, the result is Maxuint64.
func (se SizeEstimate) Add(sizeEstimate SizeEstimate) SizeEstimate {
return SizeEstimate{
addUint64NoOverflow(se.Min, sizeEstimate.Min),
addUint64NoOverflow(se.Max, sizeEstimate.Max),
}
}

// Multiply multiplies by another SizeEstimate and returns the product.
// If multiply would result in an uint64 overflow, the result is Maxuint64.
func (se SizeEstimate) Multiply(sizeEstimate SizeEstimate) SizeEstimate {
return SizeEstimate{
Expand All @@ -110,7 +138,7 @@ func (se SizeEstimate) MultiplyByCostFactor(costPerUnit float64) CostEstimate {
}
}

// MultiplyByCost multiplies by the cost and returns the sum.
// MultiplyByCost multiplies by the cost and returns the product.
// If multiply would result in an uint64 overflow, the result is Maxuint64.
func (se SizeEstimate) MultiplyByCost(cost CostEstimate) CostEstimate {
return CostEstimate{
Expand All @@ -119,6 +147,18 @@ func (se SizeEstimate) MultiplyByCost(cost CostEstimate) CostEstimate {
}
}

// Union returns a SizeEstimate that encompasses both input the SizeEstimate.
func (se SizeEstimate) Union(size SizeEstimate) SizeEstimate {
result := se
if size.Min < result.Min {
result.Min = size.Min
}
if size.Max > result.Max {
result.Max = size.Max
}
return result
}

// CostEstimate represents an estimated cost range and provides add and multiply operations
// that do not overflow.
type CostEstimate struct {
Expand All @@ -134,7 +174,7 @@ func (ce CostEstimate) Add(cost CostEstimate) CostEstimate {
}
}

// Multiply multiplies by the cost and returns the sum.
// Multiply multiplies by the cost and returns the product.
// If multiply would result in an uint64 overflow, the result is Maxuint64.
func (ce CostEstimate) Multiply(cost CostEstimate) CostEstimate {
return CostEstimate{
Expand All @@ -152,6 +192,18 @@ func (ce CostEstimate) MultiplyByCostFactor(costPerUnit float64) CostEstimate {
}
}

// Union returns a CostEstimate that encompasses both input the CostEstimates.
func (ce CostEstimate) Union(size CostEstimate) CostEstimate {
result := ce
if size.Min < result.Min {
result.Min = size.Min
}
if size.Max > result.Max {
result.Max = size.Max
}
return result
}

// addUint64NoOverflow adds non-negative ints. If the result is exceeds math.MaxUint64, math.MaxUint64
// is returned.
func addUint64NoOverflow(x, y uint64) uint64 {
Expand Down Expand Up @@ -193,9 +245,11 @@ type coster struct {
// exprPath maps from Expr Id to field path.
exprPath map[int64][]string
// iterRanges tracks the iterRange of each iterVar.
iterRanges iterRangeScopes
checkedExpr *exprpb.CheckedExpr
estimator CostEstimator
iterRanges iterRangeScopes
// computedSizes tracks the computed sizes of call results.
computedSizes map[int64]SizeEstimate
checkedExpr *exprpb.CheckedExpr
estimator CostEstimator
}

// Use a stack of iterVar -> iterRange Expr Ids to handle shadowed variable names.
Expand All @@ -221,10 +275,11 @@ func (vs iterRangeScopes) peek(varName string) (int64, bool) {
// Cost estimates the cost of the parsed and type checked CEL expression.
func Cost(checker *exprpb.CheckedExpr, estimator CostEstimator) CostEstimate {
c := coster{
checkedExpr: checker,
estimator: estimator,
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
checkedExpr: checker,
estimator: estimator,
exprPath: map[int64][]string{},
iterRanges: map[string][]int64{},
computedSizes: map[int64]SizeEstimate{},
}
return c.cost(checker.GetExpr())
}
Expand Down Expand Up @@ -303,7 +358,6 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate {
for i, arg := range args {
// TODO: && || operators short circuit, so min cost should only include 1st arg eval
// unless exhaustive evaluation is enabled
// TODO: ternary operator also short circuits, Min cost should be cond + min(a, b) within <cond> ? a : b
sum = sum.Add(c.cost(arg))
argTypes[i] = c.newAstNode(arg)
}
Expand All @@ -321,13 +375,17 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate {
}
// Pick a cost estimate range that covers all the overload cost estimation ranges
fnCost := CostEstimate{Min: uint64(math.MaxUint64), Max: 0}
var resultSize *SizeEstimate
for _, overload := range ref.GetOverloadId() {
overloadCost := c.functionCost(overload, &targetType, argTypes)
if overloadCost.Max > fnCost.Max {
fnCost.Max = overloadCost.Max
}
if overloadCost.Min < fnCost.Min {
fnCost.Min = overloadCost.Min
fnCost = fnCost.Union(overloadCost.CostEstimate)
if overloadCost.ResultSize != nil {
if resultSize == nil {
resultSize = overloadCost.ResultSize
} else {
size := resultSize.Union(*overloadCost.ResultSize)
resultSize = &size
}
}
// build and track the field path for index operations
switch overload {
Expand All @@ -341,6 +399,9 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate {
}
}
}
if resultSize != nil {
c.computedSizes[e.GetId()] = *resultSize
}
return sum.Add(fnCost)
}

Expand Down Expand Up @@ -405,30 +466,30 @@ func (c *coster) costComprehension(e *exprpb.Expr) CostEstimate {
}

func (c *coster) sizeEstimate(t AstNode) SizeEstimate {
if l := t.LiteralSize(); l != nil {
return SizeEstimate{Min: *l, Max: *l}
if l := t.ComputedSize(); l != nil {
return *l
}
if l := c.estimator.EstimateSize(t); l != nil {
return *l
}
return SizeEstimate{Min: 0, Max: math.MaxUint64}
}

func (c *coster) functionCost(overloadId string, target *AstNode, args []AstNode) CostEstimate {
func (c *coster) functionCost(overloadId string, target *AstNode, args []AstNode) CallEstimate {
if est := c.estimator.EstimateCallCost(overloadId, target, args); est != nil {
return *est
}
switch overloadId {
// O(n) functions
case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString:
if len(args) == 1 {
return c.sizeEstimate(args[0]).MultiplyByCostFactor(0.1)
return CallEstimate{CostEstimate: c.sizeEstimate(args[0]).MultiplyByCostFactor(0.1)}
}
case overloads.InList:
// If a list is composed entirely of constant values this is O(1), but we don't account for that here.
// We just assume all list containment checks are O(n).
if len(args) == 2 {
return c.sizeEstimate(args[1]).MultiplyByCostFactor(1)
return CallEstimate{CostEstimate: c.sizeEstimate(args[1]).MultiplyByCostFactor(1)}
}
// O(nm) functions
case overloads.MatchesString:
Expand All @@ -441,19 +502,29 @@ func (c *coster) functionCost(overloadId string, target *AstNode, args []AstNode
// For now, we're making a guess that each expression in a regex is typically at least 4 chars
// in length.
regexCost := c.sizeEstimate(args[0]).MultiplyByCostFactor(0.25)
return strCost.Multiply(regexCost)
return CallEstimate{CostEstimate: strCost.Multiply(regexCost)}
}
case overloads.ContainsString:
if target != nil && len(args) == 1 {
strCost := c.sizeEstimate(*target).MultiplyByCostFactor(0.1)
substrCost := c.sizeEstimate(args[0]).MultiplyByCostFactor(0.1)
return strCost.Multiply(substrCost)
return CallEstimate{CostEstimate: strCost.Multiply(substrCost)}
}
case overloads.Conditional:
size := c.sizeEstimate(args[1]).Union(c.sizeEstimate(args[2]))
return CallEstimate{CostEstimate: CostEstimate{Min: 1, Max: 1}, ResultSize: &size}
case overloads.AddString, overloads.AddBytes, overloads.AddList:
if len(args) == 2 {
lhsSize := c.sizeEstimate(args[0])
rhsSize := c.sizeEstimate(args[1])
resultSize := lhsSize.Add(rhsSize)
return CallEstimate{CostEstimate: resultSize.MultiplyByCostFactor(0.1), ResultSize: &resultSize}
}
}
// O(1) functions
// Benchmarks suggest that most of the other operations take +/- 50% of a base cost unit
// which on an Intel xeon 2.20GHz CPU is 50ns.
return CostEstimate{Min: 1, Max: 1}
return CallEstimate{CostEstimate: CostEstimate{Min: 1, Max: 1}}
}

func (c *coster) getType(e *exprpb.Expr) *exprpb.Type {
Expand All @@ -474,5 +545,9 @@ func (c *coster) newAstNode(e *exprpb.Expr) *astNode {
// only provide paths to root vars; omit accumulator vars
path = nil
}
return &astNode{path: path, t: c.getType(e), expr: e}
var derivedSize *SizeEstimate
if size, ok := c.computedSizes[e.GetId()]; ok {
derivedSize = &size
}
return &astNode{path: path, t: c.getType(e), expr: e, derivedSize: derivedSize}
}