From d54183c3b340777606b6486ecc16a0bc7257b8b2 Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Mon, 17 Jul 2023 13:38:09 -0700 Subject: [PATCH] Support for multi-mapped attribute trails Additional test cases provided for unknown string formatting as well as for handling unknown field selections across numeric types. This change also revealed an issues with how custom attribute factorie qualifiers were not being tracked correctly during state-tracking, nor was the type information correctly collected for array index operations. --- common/types/unknown.go | 90 ++++++++++++++++++++++---- common/types/unknown_test.go | 76 ++++++++++++++++++++-- interpreter/attributes_test.go | 16 +++-- interpreter/interpretable.go | 73 +++++++++++++++++++-- interpreter/interpreter_test.go | 111 ++++++++++++++++++++------------ interpreter/planner.go | 2 +- 6 files changed, 297 insertions(+), 71 deletions(-) diff --git a/common/types/unknown.go b/common/types/unknown.go index bd15da1b0..4b6c80221 100644 --- a/common/types/unknown.go +++ b/common/types/unknown.go @@ -16,6 +16,7 @@ package types import ( "fmt" + "math" "reflect" "strings" "unicode" @@ -51,13 +52,42 @@ func (a *AttributeTrail) Equal(other *AttributeTrail) bool { } for i, q := range a.QualifierPath() { qual := other.QualifierPath()[i] - if q != qual { + if !qualifiersEqual(q, qual) { return false } } return true } +func qualifiersEqual(a, b any) bool { + if a == b { + return true + } + switch numA := a.(type) { + case int64: + numB, ok := b.(uint64) + if !ok { + return false + } + return intUintEqual(numA, numB) + case uint64: + numB, ok := b.(int64) + if !ok { + return false + } + return intUintEqual(numB, numA) + default: + return false + } +} + +func intUintEqual(i int64, u uint64) bool { + if i < 0 || u > math.MaxInt64 { + return false + } + return i == int64(u) +} + // Variable returns the variable name associated with the attribute. func (a *AttributeTrail) Variable() string { return a.variable @@ -115,7 +145,7 @@ func QualifyAttribute[T AttributeQualifier](attr *AttributeTrail, qualifier T) * // Unknown type which collects expression ids which caused the current value to become unknown. type Unknown struct { - attributeTrails map[int64]*AttributeTrail + attributeTrails map[int64][]*AttributeTrail } // NewUnknown creates a new unknown at a given expression id for an attribute. @@ -126,17 +156,29 @@ func NewUnknown(id int64, attr *AttributeTrail) *Unknown { attr = unspecifiedAttribute } return &Unknown{ - attributeTrails: map[int64]*AttributeTrail{id: attr}, + attributeTrails: map[int64][]*AttributeTrail{id: {attr}}, } } // Contains returns true if the input unknown is a subset of the current unknown. func (u *Unknown) Contains(other *Unknown) bool { - for id, trail := range other.attributeTrails { - t, found := u.attributeTrails[id] - if !found || !t.Equal(trail) { + for id, otherTrails := range other.attributeTrails { + trails, found := u.attributeTrails[id] + if !found || len(otherTrails) != len(trails) { return false } + for _, ot := range otherTrails { + found := false + for _, t := range trails { + if t.Equal(ot) { + found = true + break + } + } + if !found { + return false + } + } } return true } @@ -159,11 +201,15 @@ func (u *Unknown) Equal(other ref.Val) ref.Val { // String implements the Stringer interface func (u *Unknown) String() string { var str strings.Builder - for id, attr := range u.attributeTrails { + for id, attrs := range u.attributeTrails { if str.Len() != 0 { str.WriteString(", ") } - str.WriteString(fmt.Sprintf("%v (%d)", attr, id)) + if len(attrs) == 1 { + str.WriteString(fmt.Sprintf("%v (%d)", attrs[0], id)) + } else { + str.WriteString(fmt.Sprintf("%v (%d)", attrs, id)) + } } return str.String() } @@ -214,13 +260,31 @@ func MergeUnknowns(unk1, unk2 *Unknown) *Unknown { return unk1 } out := &Unknown{ - attributeTrails: make(map[int64]*AttributeTrail, len(unk1.attributeTrails)+len(unk2.attributeTrails)), + attributeTrails: make(map[int64][]*AttributeTrail, len(unk1.attributeTrails)+len(unk2.attributeTrails)), } - for id, at := range unk1.attributeTrails { - out.attributeTrails[id] = at + for id, ats := range unk1.attributeTrails { + out.attributeTrails[id] = ats } - for id, at := range unk2.attributeTrails { - out.attributeTrails[id] = at + for id, ats := range unk2.attributeTrails { + existing, found := out.attributeTrails[id] + if !found { + out.attributeTrails[id] = ats + continue + } + + for _, at := range ats { + found := false + for _, et := range existing { + if at.Equal(et) { + found = true + break + } + } + if !found { + existing = append(existing, at) + } + } + out.attributeTrails[id] = existing } return out } diff --git a/common/types/unknown_test.go b/common/types/unknown_test.go index 1a698263e..d0d83d6a4 100644 --- a/common/types/unknown_test.go +++ b/common/types/unknown_test.go @@ -16,6 +16,8 @@ package types import ( "fmt" + "math" + "strings" "testing" "github.com/google/cel-go/common/types/ref" @@ -70,11 +72,51 @@ func TestAttributeEquals(t *testing.T) { b: QualifyAttribute[int64](NewAttributeTrail("a"), 1), equal: false, }, + { + a: QualifyAttribute[int64](NewAttributeTrail("a"), 1), + b: QualifyAttribute[string](NewAttributeTrail("a"), "1"), + equal: false, + }, + { + a: QualifyAttribute[uint64](NewAttributeTrail("a"), 1), + b: QualifyAttribute[string](NewAttributeTrail("a"), "1"), + equal: false, + }, { a: QualifyAttribute[string](NewAttributeTrail("a"), "b"), b: QualifyAttribute[string](NewAttributeTrail("a"), "b"), equal: true, }, + { + a: QualifyAttribute[int64](NewAttributeTrail("a"), 20), + b: QualifyAttribute[uint64](NewAttributeTrail("a"), 20), + equal: true, + }, + { + a: QualifyAttribute[uint64](NewAttributeTrail("a"), 20), + b: QualifyAttribute[int64](NewAttributeTrail("a"), 20), + equal: true, + }, + { + a: QualifyAttribute[uint64](NewAttributeTrail("a"), 21), + b: QualifyAttribute[int64](NewAttributeTrail("a"), 20), + equal: false, + }, + { + a: QualifyAttribute[int64](NewAttributeTrail("a"), 20), + b: QualifyAttribute[uint64](NewAttributeTrail("a"), 21), + equal: false, + }, + { + a: QualifyAttribute[int64](NewAttributeTrail("a"), -1), + b: QualifyAttribute[uint64](NewAttributeTrail("a"), 0), + equal: false, + }, + { + a: QualifyAttribute[int64](NewAttributeTrail("a"), 1), + b: QualifyAttribute[uint64](NewAttributeTrail("a"), math.MaxInt64+1), + equal: false, + }, } for i, tst := range tests { tc := tst @@ -189,7 +231,7 @@ func TestUnknownContains(t *testing.T) { func TestUnknownString(t *testing.T) { tests := []struct { unk *Unknown - out string + out any }{ { unk: NewUnknown(1, nil), @@ -212,16 +254,42 @@ func TestUnknownString(t *testing.T) { NewUnknown(3, QualifyAttribute[bool](NewAttributeTrail("a"), true)), NewUnknown(4, QualifyAttribute[string](NewAttributeTrail("a"), "b")), ), - out: "a[true] (3), a.b (4)", + out: []string{"a[true] (3)", "a.b (4)"}, + }, + { + // this case might occur in a logical condition where the attributes are equal. + unk: MergeUnknowns( + NewUnknown(3, QualifyAttribute[int64](NewAttributeTrail("a"), 0)), + NewUnknown(3, QualifyAttribute[int64](NewAttributeTrail("a"), 0)), + ), + out: "a[0] (3)", + }, + { + // this case might occur if attribute tracking through comprehensions is supported + unk: MergeUnknowns( + NewUnknown(3, QualifyAttribute[int64](NewAttributeTrail("a"), 0)), + NewUnknown(3, QualifyAttribute[int64](NewAttributeTrail("a"), 1)), + ), + out: "[a[0] a[1]] (3)", }, } for i, tst := range tests { tc := tst t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { out := tc.unk.String() - if out != tc.out { - t.Errorf("%v.String() got %v, wanted %v", tc.unk, out, tc.out) + switch want := tc.out.(type) { + case string: + if out != want { + t.Errorf("%v.String() got %v, wanted %v", tc.unk, out, want) + } + case []string: + for _, w := range want { + if !strings.Contains(out, w) { + t.Errorf("%v.String() got %v, wanted it to contain %v", tc.unk, out, w) + } + } } + }) } } diff --git a/interpreter/attributes_test.go b/interpreter/attributes_test.go index 10e34f452..e70f6ebaf 100644 --- a/interpreter/attributes_test.go +++ b/interpreter/attributes_test.go @@ -1191,7 +1191,12 @@ type custAttrFactory struct { func (r *custAttrFactory) NewQualifier(objType *types.Type, qualID int64, val any, opt bool) (Qualifier, error) { if objType.Kind() == types.StructKind && objType.TypeName() == "google.expr.proto3.test.TestAllTypes.NestedMessage" { - return &nestedMsgQualifier{id: qualID, field: val.(string)}, nil + switch v := val.(type) { + case string: + return &nestedMsgQualifier{id: qualID, field: v, opt: opt}, nil + case types.String: + return &nestedMsgQualifier{id: qualID, field: string(v), opt: opt}, nil + } } return r.AttributeFactory.NewQualifier(objType, qualID, val, opt) } @@ -1199,12 +1204,17 @@ func (r *custAttrFactory) NewQualifier(objType *types.Type, qualID int64, val an type nestedMsgQualifier struct { id int64 field string + opt bool } func (q *nestedMsgQualifier) ID() int64 { return q.id } +func (q *nestedMsgQualifier) IsOptional() bool { + return q.opt +} + func (q *nestedMsgQualifier) Qualify(vars Activation, obj any) (any, error) { pb := obj.(*proto3pb.TestAllTypes_NestedMessage) return pb.GetBb(), nil @@ -1218,10 +1228,6 @@ func (q *nestedMsgQualifier) QualifyIfPresent(vars Activation, obj any, presence return pb.GetBb(), true, nil } -func (q *nestedMsgQualifier) IsOptional() bool { - return false -} - func addQualifier(t testing.TB, attr Attribute, qual Qualifier) Attribute { t.Helper() _, err := attr.AddQualifier(qual) diff --git a/interpreter/interpretable.go b/interpreter/interpretable.go index f1f899a57..12e158149 100644 --- a/interpreter/interpretable.go +++ b/interpreter/interpretable.go @@ -861,18 +861,40 @@ type evalWatchAttr struct { // AddQualifier creates a wrapper over the incoming qualifier which observes the qualification // result. func (e *evalWatchAttr) AddQualifier(q Qualifier) (Attribute, error) { - cq, isConst := q.(ConstantQualifier) - if isConst { + switch qual := q.(type) { + // By default, the qualifier is either a constant or an attribute + // There may be some custom cases where the attribute is neither. + case ConstantQualifier: + // Expose a method to test whether the qualifier matches the input pattern. q = &evalWatchConstQual{ - ConstantQualifier: cq, + ConstantQualifier: qual, observer: e.observer, - adapter: e.InterpretableAttribute.Adapter(), + adapter: e.Adapter(), } - } else { + case *evalWatchAttr: + // Unwrap the evalWatchAttr since the observation will be applied during Qualify or + // QualifyIfPresent rather than Eval. + q = &evalWatchAttrQual{ + Attribute: qual.InterpretableAttribute, + observer: e.observer, + adapter: e.Adapter(), + } + case Attribute: + // Expose methods which intercept the qualification prior to being applied as a qualifier. + // Using this interface ensures that the qualifier is converted to a constant value one + // time during attribute pattern matching as the method embeds the Attribute interface + // needed to trip the conversion to a constant. + q = &evalWatchAttrQual{ + Attribute: qual, + observer: e.observer, + adapter: e.Adapter(), + } + default: + // This is likely a custom qualifier type. q = &evalWatchQual{ - Qualifier: q, + Qualifier: qual, observer: e.observer, - adapter: e.InterpretableAttribute.Adapter(), + adapter: e.Adapter(), } } _, err := e.InterpretableAttribute.AddQualifier(q) @@ -930,6 +952,43 @@ func (e *evalWatchConstQual) QualifierValueEquals(value any) bool { return ok && qve.QualifierValueEquals(value) } +// evalWatchAttrQual observes the qualification of an object by a value computed at runtime. +type evalWatchAttrQual struct { + Attribute + observer EvalObserver + adapter ref.TypeAdapter +} + +// Qualify observes the qualification of a object via a value computed at runtime. +func (e *evalWatchAttrQual) Qualify(vars Activation, obj any) (any, error) { + out, err := e.Attribute.Qualify(vars, obj) + var val ref.Val + if err != nil { + val = types.WrapErr(err) + } else { + val = e.adapter.NativeToValue(out) + } + e.observer(e.ID(), e.Attribute, val) + return out, err +} + +// QualifyIfPresent conditionally qualifies the variable and only records a value if one is present. +func (e *evalWatchAttrQual) QualifyIfPresent(vars Activation, obj any, presenceOnly bool) (any, bool, error) { + out, present, err := e.Attribute.QualifyIfPresent(vars, obj, presenceOnly) + var val ref.Val + if err != nil { + val = types.WrapErr(err) + } else if out != nil { + val = e.adapter.NativeToValue(out) + } else if presenceOnly { + val = types.Bool(present) + } + if present || presenceOnly { + e.observer(e.ID(), e.Attribute, val) + } + return out, present, err +} + // evalWatchQual observes the qualification of an object by a value computed at runtime. type evalWatchQual struct { Qualifier diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index e4061138b..51c7453f0 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -59,7 +59,7 @@ type testCase struct { unchecked bool extraOpts []InterpretableDecorator - in map[string]any + in any out any err string progErr string @@ -1215,8 +1215,8 @@ func testData(t testing.TB) []testCase { attrs: &custAttrFactory{ AttributeFactory: NewAttributeFactory( testContainer("google.expr.proto3.test"), - types.DefaultTypeAdapter, - types.NewEmptyRegistry(), + newTestRegistry(t, &proto3pb.TestAllTypes_NestedMessage{}), + newTestRegistry(t, &proto3pb.TestAllTypes_NestedMessage{}), ), }, in: map[string]any{ @@ -1226,6 +1226,29 @@ func testData(t testing.TB) []testCase { }, out: types.True, }, + { + name: "select_custom_pb3_optional_field", + expr: `a.?bb`, + container: "google.expr.proto3.test", + types: []proto.Message{&proto3pb.TestAllTypes_NestedMessage{}}, + vars: []*decls.VariableDecl{ + decls.NewVariable("a", + types.NewObjectType("google.expr.proto3.test.TestAllTypes.NestedMessage")), + }, + attrs: &custAttrFactory{ + AttributeFactory: NewAttributeFactory( + testContainer("google.expr.proto3.test"), + newTestRegistry(t, &proto3pb.TestAllTypes_NestedMessage{}), + newTestRegistry(t, &proto3pb.TestAllTypes_NestedMessage{}), + ), + }, + in: map[string]any{ + "a": &proto3pb.TestAllTypes_NestedMessage{ + Bb: 101, + }, + }, + out: types.OptionalOf(types.Int(101)), + }, { name: "select_relative", expr: `json('{"hi":"world"}').hi == 'world'`, @@ -1391,6 +1414,36 @@ func testData(t testing.TB) []testCase { unchecked: true, err: `no such attribute(s): goog.pkg.mylistundef, pkg.mylistundef`, }, + { + name: "unknown_attribute", + expr: `a[0]`, + vars: []*decls.VariableDecl{ + decls.NewVariable("a", + types.NewMapType(types.IntType, types.BoolType)), + }, + attrs: NewPartialAttributeFactory(testContainer(""), types.DefaultTypeAdapter, types.NewEmptyRegistry()), + in: newTestPartialActivation(t, map[string]any{ + "a": map[int64]any{ + 1: true, + }, + }, NewAttributePattern("a").QualInt(0)), + out: types.NewUnknown(2, types.QualifyAttribute[int64](types.NewAttributeTrail("a"), 0)), + }, + { + name: "unknown_attribute_mixed_qualifier", + expr: `a[dyn(0u)]`, + vars: []*decls.VariableDecl{ + decls.NewVariable("a", + types.NewMapType(types.IntType, types.BoolType)), + }, + attrs: NewPartialAttributeFactory(testContainer(""), types.DefaultTypeAdapter, types.NewEmptyRegistry()), + in: newTestPartialActivation(t, map[string]any{ + "a": map[int64]any{ + 1: true, + }, + }, NewAttributePattern("a").QualInt(0)), + out: types.NewUnknown(2, types.QualifyAttribute[uint64](types.NewAttributeTrail("a"), 0)), + }, } } @@ -1965,7 +2018,10 @@ func program(ctx testing.TB, tst *testCase, opts ...InterpretableDecorator) (Int // Configure the program input. vars := EmptyActivation() if tst.in != nil { - vars, _ = NewActivation(tst.in) + vars, err = NewActivation(tst.in) + if err != nil { + ctx.Fatalf("NewActivation(%v) failed: %v", tst.in, err) + } } // Adapt the test output, if needed. if tst.out != nil { @@ -2069,7 +2125,7 @@ func newTestEnv(t testing.TB, cont *containers.Container, reg ref.TypeRegistry) return env } -func newTestRegistry(t *testing.T, msgs ...proto.Message) ref.TypeRegistry { +func newTestRegistry(t testing.TB, msgs ...proto.Message) ref.TypeRegistry { t.Helper() reg, err := types.NewRegistry(msgs...) if err != nil { @@ -2078,6 +2134,15 @@ func newTestRegistry(t *testing.T, msgs ...proto.Message) ref.TypeRegistry { return reg } +func newTestPartialActivation(t testing.TB, in any, unknowns ...*AttributePattern) any { + t.Helper() + vars, err := NewPartialActivation(in, unknowns...) + if err != nil { + t.Fatalf("NewPartialActivation(%v) failed: %v", in, err) + } + return vars +} + // newStandardInterpreter builds a Dispatcher and TypeProvider with support for all of the CEL // builtins defined in the language definition. func newStandardInterpreter(t *testing.T, @@ -2126,39 +2191,3 @@ func funcBindings(t testing.TB, funcs ...*decls.FunctionDecl) []*functions.Overl } return bindings } - -func funcExprDecl(t testing.TB, fn *decls.FunctionDecl) *exprpb.Decl { - t.Helper() - d, err := decls.FunctionDeclToExprDecl(fn) - if err != nil { - t.Fatalf("decls.FunctionDeclToExprDecl(%v) failed: %v", fn, err) - } - return d -} - -func funcExprDecls(t testing.TB, funcs ...*decls.FunctionDecl) []*exprpb.Decl { - t.Helper() - d := make([]*exprpb.Decl, 0, len(funcs)) - for _, fn := range funcs { - d = append(d, funcExprDecl(t, fn)) - } - return d -} - -func varExprDecl(t testing.TB, v *decls.VariableDecl) *exprpb.Decl { - t.Helper() - d, err := decls.VariableDeclToExprDecl(v) - if err != nil { - t.Fatalf("decls.VariableDeclToExprDecl(%v) failed: %v", v, err) - } - return d -} - -func varExprDecls(t testing.TB, vars ...*decls.VariableDecl) []*exprpb.Decl { - t.Helper() - d := make([]*exprpb.Decl, 0, len(vars)) - for _, v := range vars { - d = append(d, varExprDecl(t, v)) - } - return d -} diff --git a/interpreter/planner.go b/interpreter/planner.go index 04771a3bf..ae1a6bb21 100644 --- a/interpreter/planner.go +++ b/interpreter/planner.go @@ -473,7 +473,7 @@ func (p *planner) planCallConditional(expr *exprpb.Expr, args []Interpretable) ( func (p *planner) planCallIndex(expr *exprpb.Expr, args []Interpretable, optional bool) (Interpretable, error) { op := args[0] ind := args[1] - opType := p.typeMap[expr.GetCallExpr().GetTarget().GetId()] + opType := p.typeMap[op.ID()] // Establish the attribute reference. var err error