Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DefaultTypeAdapter: Add support for missing custom scalars #893

Merged
merged 1 commit into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions common/types/int.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,18 @@ func (i Int) ConvertToNative(typeDesc reflect.Type) (any, error) {
return nil, err
}
return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
case reflect.Int8:
v, err := int64ToInt8Checked(int64(i))
if err != nil {
return nil, err
}
return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
case reflect.Int16:
v, err := int64ToInt16Checked(int64(i))
if err != nil {
return nil, err
}
return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
case reflect.Int64:
return reflect.ValueOf(i).Convert(typeDesc).Interface(), nil
case reflect.Ptr:
Expand Down
30 changes: 30 additions & 0 deletions common/types/int_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,36 @@ func TestIntConvertToNative_Error(t *testing.T) {
}
}

func TestIntConvertToNative_Int8(t *testing.T) {
val, err := Int(127).ConvertToNative(reflect.TypeOf(int8(0)))
if err != nil {
t.Fatalf("Int.ConvertToNative(int8) failed: %v", err)
}
if val.(int8) != 127 {
t.Errorf("Got '%v', expected 20050", val)
}
val, err = Int(math.MaxInt8 + 1).ConvertToNative(reflect.TypeOf(int8(0)))
if err == nil {
t.Errorf("(MaxInt+1).ConvertToNative(int8) did not error, got: %v", val)
} else if !strings.Contains(err.Error(), "integer overflow") {
t.Errorf("ConvertToNative(int8) returned unexpected error: %v, wanted integer overflow", err)
}
}
func TestIntConvertToNative_Int16(t *testing.T) {
val, err := Int(20050).ConvertToNative(reflect.TypeOf(int16(0)))
if err != nil {
t.Fatalf("Int.ConvertToNative(int16) failed: %v", err)
}
if val.(int16) != 20050 {
t.Errorf("Got '%v', expected 20050", val)
}
val, err = Int(math.MaxInt16 + 1).ConvertToNative(reflect.TypeOf(int16(0)))
if err == nil {
t.Errorf("(MaxInt+1).ConvertToNative(int16) did not error, got: %v", val)
} else if !strings.Contains(err.Error(), "integer overflow") {
t.Errorf("ConvertToNative(int32) returned unexpected error: %v, wanted integer overflow", err)
}
}
func TestIntConvertToNative_Int32(t *testing.T) {
val, err := Int(20050).ConvertToNative(reflect.TypeOf(int32(0)))
if err != nil {
Expand Down
40 changes: 40 additions & 0 deletions common/types/overflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,26 @@ func int64ToUint64Checked(v int64) (uint64, error) {
return uint64(v), nil
}

// int64ToInt8Checked converts an int64 to an int8 value.
//
// If the conversion fails due to overflow the error return value will be non-nil.
func int64ToInt8Checked(v int64) (int8, error) {
TristonianJones marked this conversation as resolved.
Show resolved Hide resolved
if v < math.MinInt8 || v > math.MaxInt8 {
return 0, errIntOverflow
}
return int8(v), nil
}

// int64ToInt16Checked converts an int64 to an int16 value.
//
// If the conversion fails due to overflow the error return value will be non-nil.
func int64ToInt16Checked(v int64) (int16, error) {
if v < math.MinInt16 || v > math.MaxInt16 {
return 0, errIntOverflow
}
return int16(v), nil
}

// int64ToInt32Checked converts an int64 to an int32 value.
//
// If the conversion fails due to overflow the error return value will be non-nil.
Expand All @@ -336,6 +356,26 @@ func int64ToInt32Checked(v int64) (int32, error) {
return int32(v), nil
}

// uint64ToUint8Checked converts a uint64 to a uint8 value.
//
// If the conversion fails due to overflow the error return value will be non-nil.
func uint64ToUint8Checked(v uint64) (uint8, error) {
if v > math.MaxUint8 {
return 0, errUintOverflow
}
return uint8(v), nil
}

// uint64ToUint16Checked converts a uint64 to a uint16 value.
//
// If the conversion fails due to overflow the error return value will be non-nil.
func uint64ToUint16Checked(v uint64) (uint16, error) {
if v > math.MaxUint16 {
return 0, errUintOverflow
}
return uint16(v), nil
}

// uint64ToUint32Checked converts a uint64 to a uint32 value.
//
// If the conversion fails due to overflow the error return value will be non-nil.
Expand Down
24 changes: 24 additions & 0 deletions common/types/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -590,12 +590,33 @@ func nativeToValue(a Adapter, value any) (ref.Val, bool) {
return NewDynamicMap(a, v), true
// type aliases of primitive types cannot be asserted as that type, but rather need
// to be downcast to int32 before being converted to a CEL representation.
case reflect.Bool:
boolTupe := reflect.TypeOf(false)
return Bool(refValue.Convert(boolTupe).Interface().(bool)), true
case reflect.Int:
intType := reflect.TypeOf(int(0))
return Int(refValue.Convert(intType).Interface().(int)), true
case reflect.Int8:
intType := reflect.TypeOf(int8(0))
return Int(refValue.Convert(intType).Interface().(int8)), true
case reflect.Int16:
intType := reflect.TypeOf(int16(0))
return Int(refValue.Convert(intType).Interface().(int16)), true
case reflect.Int32:
intType := reflect.TypeOf(int32(0))
return Int(refValue.Convert(intType).Interface().(int32)), true
case reflect.Int64:
intType := reflect.TypeOf(int64(0))
return Int(refValue.Convert(intType).Interface().(int64)), true
case reflect.Uint:
uintType := reflect.TypeOf(uint(0))
return Uint(refValue.Convert(uintType).Interface().(uint)), true
case reflect.Uint8:
uintType := reflect.TypeOf(uint8(0))
return Uint(refValue.Convert(uintType).Interface().(uint8)), true
case reflect.Uint16:
uintType := reflect.TypeOf(uint16(0))
return Uint(refValue.Convert(uintType).Interface().(uint16)), true
case reflect.Uint32:
uintType := reflect.TypeOf(uint32(0))
return Uint(refValue.Convert(uintType).Interface().(uint32)), true
Expand All @@ -608,6 +629,9 @@ func nativeToValue(a Adapter, value any) (ref.Val, bool) {
case reflect.Float64:
doubleType := reflect.TypeOf(float64(0))
return Double(refValue.Convert(doubleType).Interface().(float64)), true
case reflect.String:
stringType := reflect.TypeOf("")
return String(refValue.Convert(stringType).Interface().(string)), true
}
}
return nil, false
Expand Down
32 changes: 31 additions & 1 deletion common/types/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,21 @@ func TestConvertToNative(t *testing.T) {
// Proto conversion tests.
parsedExpr := &exprpb.ParsedExpr{}
expectValueToNative(t, reg.NativeToValue(parsedExpr), parsedExpr)

// Custom scalars
expectValueToNative(t, Int(1), testInt(1))
expectValueToNative(t, Int(1), testInt8(1))
expectValueToNative(t, Int(1), testInt16(1))
expectValueToNative(t, Int(1), testInt32(1))
expectValueToNative(t, Int(1), testInt64(1))
expectValueToNative(t, Uint(1), testUint(1))
expectValueToNative(t, Uint(1), testUint8(1))
expectValueToNative(t, Uint(1), testUint16(1))
expectValueToNative(t, Uint(1), testUint32(1))
expectValueToNative(t, Uint(1), testUint64(1))
expectValueToNative(t, Double(4.5), testFloat32(4.5))
expectValueToNative(t, Double(-5.1), testFloat64(-5.1))
expectValueToNative(t, String("foo"), testString("foo"))
}

func TestNativeToValue_Any(t *testing.T) {
Expand Down Expand Up @@ -758,12 +773,19 @@ func TestNativeToValue_Primitive(t *testing.T) {
expectNativeToValue(t, &rBytes, rBytes)

// Extensions to core types.
expectNativeToValue(t, testInt(1), Int(1))
TristonianJones marked this conversation as resolved.
Show resolved Hide resolved
expectNativeToValue(t, testInt8(1), Int(1))
expectNativeToValue(t, testInt16(1), Int(1))
expectNativeToValue(t, testInt32(1), Int(1))
expectNativeToValue(t, testInt64(-100), Int(-100))
expectNativeToValue(t, testUint(1), Uint(1))
expectNativeToValue(t, testUint8(1), Uint(1))
expectNativeToValue(t, testUint16(1), Uint(1))
expectNativeToValue(t, testUint32(2), Uint(2))
expectNativeToValue(t, testUint64(3), Uint(3))
expectNativeToValue(t, testFloat32(4.5), Double(4.5))
expectNativeToValue(t, testFloat64(-5.1), Double(-5.1))
expectNativeToValue(t, testString("foo"), String("foo"))

// Null conversion test.
expectNativeToValue(t, nil, NullValue)
Expand Down Expand Up @@ -795,7 +817,7 @@ func expectValueToNative(t *testing.T, in ref.Val, out any) {
}
if !equals {
t.Errorf("Unexpected conversion from expr to proto.\n"+
"expected: %T, actual: %T", val, out)
"expected: %T, actual: %T", out, val)
}
}
}
Expand Down Expand Up @@ -870,12 +892,20 @@ func BenchmarkTypeProviderCopy(b *testing.B) {
type nonConvertible struct {
Field string
}
type testBool bool
type testInt int
type testInt8 int8
type testInt16 int16
type testInt32 int32
type testInt64 int64
type testUint uint
type testUint8 uint8
type testUint16 uint16
type testUint32 uint32
type testUint64 uint64
type testFloat32 float32
type testFloat64 float64
type testString string

func newTestRegistry(t *testing.T, types ...proto.Message) *Registry {
t.Helper()
Expand Down
5 changes: 1 addition & 4 deletions common/types/string.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,7 @@ func (s String) Compare(other ref.Val) ref.Val {
func (s String) ConvertToNative(typeDesc reflect.Type) (any, error) {
switch typeDesc.Kind() {
case reflect.String:
if reflect.TypeOf(s).AssignableTo(typeDesc) {
return s, nil
}
return s.Value(), nil
return reflect.ValueOf(s).Convert(typeDesc).Interface(), nil
case reflect.Ptr:
switch typeDesc {
case anyValueType:
Expand Down
11 changes: 11 additions & 0 deletions common/types/string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,17 @@ func TestStringConvertToNative_String(t *testing.T) {
}
}

type customString string

func TestStringConvertToNative_CustomString(t *testing.T) {
val, err := String("hello").ConvertToNative(reflect.TypeOf(customString("")))
if err != nil {
t.Error(err)
} else if v, ok := val.(customString); !ok || v != "hello" {
t.Errorf("Got %T with val '%v', expected %T with val 'hello'", val, v, customString(""))
}
}

func TestStringConvertToNative_Wrapper(t *testing.T) {
val, err := String("hello").ConvertToNative(stringWrapperType)
if err != nil {
Expand Down
12 changes: 12 additions & 0 deletions common/types/uint.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,18 @@ func (i Uint) ConvertToNative(typeDesc reflect.Type) (any, error) {
return 0, err
}
return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
case reflect.Uint8:
v, err := uint64ToUint8Checked(uint64(i))
if err != nil {
return 0, err
}
return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
case reflect.Uint16:
v, err := uint64ToUint16Checked(uint64(i))
if err != nil {
return 0, err
}
return reflect.ValueOf(v).Convert(typeDesc).Interface(), nil
case reflect.Uint64:
return reflect.ValueOf(i).Convert(typeDesc).Interface(), nil
case reflect.Ptr:
Expand Down
32 changes: 32 additions & 0 deletions common/types/uint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,38 @@ func TestUintConvertToNative_Json(t *testing.T) {
}
}

func TestUintConvertToNative_Uint8(t *testing.T) {
val, err := Uint(128).ConvertToNative(reflect.TypeOf(uint8(0)))
if err != nil {
t.Fatalf("Uint.ConvertToNative(uint8) failed: %v", err)
}
if val.(uint8) != 128 {
t.Errorf("Got '%v', expected 128", val)
}
val, err = Uint(math.MaxUint8 + 1).ConvertToNative(reflect.TypeOf(uint8(0)))
if err == nil {
t.Errorf("(MaxUint+1).ConvertToNative(uint8) did not error, got: %v", val)
} else if !strings.Contains(err.Error(), "unsigned integer overflow") {
t.Errorf("ConvertToNative(uint8) returned unexpected error: %v, wanted unsigned integer overflow", err)
}
}

func TestUintConvertToNative_Uint16(t *testing.T) {
val, err := Uint(20050).ConvertToNative(reflect.TypeOf(uint16(0)))
if err != nil {
t.Fatalf("Uint.ConvertToNative(uint16) failed: %v", err)
}
if val.(uint16) != 20050 {
t.Errorf("Got '%v', expected 20050", val)
}
val, err = Uint(math.MaxUint16 + 1).ConvertToNative(reflect.TypeOf(uint16(0)))
if err == nil {
t.Errorf("(MaxUint+1).ConvertToNative(uint16) did not error, got: %v", val)
} else if !strings.Contains(err.Error(), "unsigned integer overflow") {
t.Errorf("ConvertToNative(uint16) returned unexpected error: %v, wanted unsigned integer overflow", err)
}
}

func TestUintConvertToNative_Uint32(t *testing.T) {
val, err := Uint(20050).ConvertToNative(reflect.TypeOf(uint32(0)))
if err != nil {
Expand Down