Skip to content

Commit

Permalink
Support for multi-mapped attribute trails
Browse files Browse the repository at this point in the history
Additional test cases provided for unknown string formatting
as well as for handling unknown field selections across numeric
types.

This change also revealed an issues with how custom attribute
factorie qualifiers were not being tracked correctly during
state-tracking, nor was the type information correctly collected
for array index operations.
  • Loading branch information
TristonianJones committed Jul 17, 2023
1 parent 744cd19 commit d54183c
Show file tree
Hide file tree
Showing 6 changed files with 297 additions and 71 deletions.
90 changes: 77 additions & 13 deletions common/types/unknown.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package types

import (
"fmt"
"math"
"reflect"
"strings"
"unicode"
Expand Down Expand Up @@ -51,13 +52,42 @@ func (a *AttributeTrail) Equal(other *AttributeTrail) bool {
}
for i, q := range a.QualifierPath() {
qual := other.QualifierPath()[i]
if q != qual {
if !qualifiersEqual(q, qual) {
return false
}
}
return true
}

func qualifiersEqual(a, b any) bool {
if a == b {
return true
}
switch numA := a.(type) {
case int64:
numB, ok := b.(uint64)
if !ok {
return false
}
return intUintEqual(numA, numB)
case uint64:
numB, ok := b.(int64)
if !ok {
return false
}
return intUintEqual(numB, numA)
default:
return false
}
}

func intUintEqual(i int64, u uint64) bool {
if i < 0 || u > math.MaxInt64 {
return false
}
return i == int64(u)
}

// Variable returns the variable name associated with the attribute.
func (a *AttributeTrail) Variable() string {
return a.variable
Expand Down Expand Up @@ -115,7 +145,7 @@ func QualifyAttribute[T AttributeQualifier](attr *AttributeTrail, qualifier T) *

// Unknown type which collects expression ids which caused the current value to become unknown.
type Unknown struct {
attributeTrails map[int64]*AttributeTrail
attributeTrails map[int64][]*AttributeTrail
}

// NewUnknown creates a new unknown at a given expression id for an attribute.
Expand All @@ -126,17 +156,29 @@ func NewUnknown(id int64, attr *AttributeTrail) *Unknown {
attr = unspecifiedAttribute
}
return &Unknown{
attributeTrails: map[int64]*AttributeTrail{id: attr},
attributeTrails: map[int64][]*AttributeTrail{id: {attr}},
}
}

// Contains returns true if the input unknown is a subset of the current unknown.
func (u *Unknown) Contains(other *Unknown) bool {
for id, trail := range other.attributeTrails {
t, found := u.attributeTrails[id]
if !found || !t.Equal(trail) {
for id, otherTrails := range other.attributeTrails {
trails, found := u.attributeTrails[id]
if !found || len(otherTrails) != len(trails) {
return false
}
for _, ot := range otherTrails {
found := false
for _, t := range trails {
if t.Equal(ot) {
found = true
break
}
}
if !found {
return false
}
}
}
return true
}
Expand All @@ -159,11 +201,15 @@ func (u *Unknown) Equal(other ref.Val) ref.Val {
// String implements the Stringer interface
func (u *Unknown) String() string {
var str strings.Builder
for id, attr := range u.attributeTrails {
for id, attrs := range u.attributeTrails {
if str.Len() != 0 {
str.WriteString(", ")
}
str.WriteString(fmt.Sprintf("%v (%d)", attr, id))
if len(attrs) == 1 {
str.WriteString(fmt.Sprintf("%v (%d)", attrs[0], id))
} else {
str.WriteString(fmt.Sprintf("%v (%d)", attrs, id))
}
}
return str.String()
}
Expand Down Expand Up @@ -214,13 +260,31 @@ func MergeUnknowns(unk1, unk2 *Unknown) *Unknown {
return unk1
}
out := &Unknown{
attributeTrails: make(map[int64]*AttributeTrail, len(unk1.attributeTrails)+len(unk2.attributeTrails)),
attributeTrails: make(map[int64][]*AttributeTrail, len(unk1.attributeTrails)+len(unk2.attributeTrails)),
}
for id, at := range unk1.attributeTrails {
out.attributeTrails[id] = at
for id, ats := range unk1.attributeTrails {
out.attributeTrails[id] = ats
}
for id, at := range unk2.attributeTrails {
out.attributeTrails[id] = at
for id, ats := range unk2.attributeTrails {
existing, found := out.attributeTrails[id]
if !found {
out.attributeTrails[id] = ats
continue
}

for _, at := range ats {
found := false
for _, et := range existing {
if at.Equal(et) {
found = true
break
}
}
if !found {
existing = append(existing, at)
}
}
out.attributeTrails[id] = existing
}
return out
}
76 changes: 72 additions & 4 deletions common/types/unknown_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package types

import (
"fmt"
"math"
"strings"
"testing"

"github.com/google/cel-go/common/types/ref"
Expand Down Expand Up @@ -70,11 +72,51 @@ func TestAttributeEquals(t *testing.T) {
b: QualifyAttribute[int64](NewAttributeTrail("a"), 1),
equal: false,
},
{
a: QualifyAttribute[int64](NewAttributeTrail("a"), 1),
b: QualifyAttribute[string](NewAttributeTrail("a"), "1"),
equal: false,
},
{
a: QualifyAttribute[uint64](NewAttributeTrail("a"), 1),
b: QualifyAttribute[string](NewAttributeTrail("a"), "1"),
equal: false,
},
{
a: QualifyAttribute[string](NewAttributeTrail("a"), "b"),
b: QualifyAttribute[string](NewAttributeTrail("a"), "b"),
equal: true,
},
{
a: QualifyAttribute[int64](NewAttributeTrail("a"), 20),
b: QualifyAttribute[uint64](NewAttributeTrail("a"), 20),
equal: true,
},
{
a: QualifyAttribute[uint64](NewAttributeTrail("a"), 20),
b: QualifyAttribute[int64](NewAttributeTrail("a"), 20),
equal: true,
},
{
a: QualifyAttribute[uint64](NewAttributeTrail("a"), 21),
b: QualifyAttribute[int64](NewAttributeTrail("a"), 20),
equal: false,
},
{
a: QualifyAttribute[int64](NewAttributeTrail("a"), 20),
b: QualifyAttribute[uint64](NewAttributeTrail("a"), 21),
equal: false,
},
{
a: QualifyAttribute[int64](NewAttributeTrail("a"), -1),
b: QualifyAttribute[uint64](NewAttributeTrail("a"), 0),
equal: false,
},
{
a: QualifyAttribute[int64](NewAttributeTrail("a"), 1),
b: QualifyAttribute[uint64](NewAttributeTrail("a"), math.MaxInt64+1),
equal: false,
},
}
for i, tst := range tests {
tc := tst
Expand Down Expand Up @@ -189,7 +231,7 @@ func TestUnknownContains(t *testing.T) {
func TestUnknownString(t *testing.T) {
tests := []struct {
unk *Unknown
out string
out any
}{
{
unk: NewUnknown(1, nil),
Expand All @@ -212,16 +254,42 @@ func TestUnknownString(t *testing.T) {
NewUnknown(3, QualifyAttribute[bool](NewAttributeTrail("a"), true)),
NewUnknown(4, QualifyAttribute[string](NewAttributeTrail("a"), "b")),
),
out: "a[true] (3), a.b (4)",
out: []string{"a[true] (3)", "a.b (4)"},
},
{
// this case might occur in a logical condition where the attributes are equal.
unk: MergeUnknowns(
NewUnknown(3, QualifyAttribute[int64](NewAttributeTrail("a"), 0)),
NewUnknown(3, QualifyAttribute[int64](NewAttributeTrail("a"), 0)),
),
out: "a[0] (3)",
},
{
// this case might occur if attribute tracking through comprehensions is supported
unk: MergeUnknowns(
NewUnknown(3, QualifyAttribute[int64](NewAttributeTrail("a"), 0)),
NewUnknown(3, QualifyAttribute[int64](NewAttributeTrail("a"), 1)),
),
out: "[a[0] a[1]] (3)",
},
}
for i, tst := range tests {
tc := tst
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
out := tc.unk.String()
if out != tc.out {
t.Errorf("%v.String() got %v, wanted %v", tc.unk, out, tc.out)
switch want := tc.out.(type) {
case string:
if out != want {
t.Errorf("%v.String() got %v, wanted %v", tc.unk, out, want)
}
case []string:
for _, w := range want {
if !strings.Contains(out, w) {
t.Errorf("%v.String() got %v, wanted it to contain %v", tc.unk, out, w)
}
}
}

})
}
}
Expand Down
16 changes: 11 additions & 5 deletions interpreter/attributes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1191,20 +1191,30 @@ type custAttrFactory struct {

func (r *custAttrFactory) NewQualifier(objType *types.Type, qualID int64, val any, opt bool) (Qualifier, error) {
if objType.Kind() == types.StructKind && objType.TypeName() == "google.expr.proto3.test.TestAllTypes.NestedMessage" {
return &nestedMsgQualifier{id: qualID, field: val.(string)}, nil
switch v := val.(type) {
case string:
return &nestedMsgQualifier{id: qualID, field: v, opt: opt}, nil
case types.String:
return &nestedMsgQualifier{id: qualID, field: string(v), opt: opt}, nil
}
}
return r.AttributeFactory.NewQualifier(objType, qualID, val, opt)
}

type nestedMsgQualifier struct {
id int64
field string
opt bool
}

func (q *nestedMsgQualifier) ID() int64 {
return q.id
}

func (q *nestedMsgQualifier) IsOptional() bool {
return q.opt
}

func (q *nestedMsgQualifier) Qualify(vars Activation, obj any) (any, error) {
pb := obj.(*proto3pb.TestAllTypes_NestedMessage)
return pb.GetBb(), nil
Expand All @@ -1218,10 +1228,6 @@ func (q *nestedMsgQualifier) QualifyIfPresent(vars Activation, obj any, presence
return pb.GetBb(), true, nil
}

func (q *nestedMsgQualifier) IsOptional() bool {
return false
}

func addQualifier(t testing.TB, attr Attribute, qual Qualifier) Attribute {
t.Helper()
_, err := attr.AddQualifier(qual)
Expand Down
Loading

0 comments on commit d54183c

Please sign in to comment.