diff --git a/cel/decls.go b/cel/decls.go index 3055fef8..c4601af8 100644 --- a/cel/decls.go +++ b/cel/decls.go @@ -21,6 +21,7 @@ import ( "github.com/google/cel-go/common/decls" "github.com/google/cel-go/common/functions" "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) @@ -110,21 +111,21 @@ var ( // function references for instantiating new types. // ListType creates an instances of a list type value with the provided element type. - ListType = decls.ListType + ListType = decls.NewListType // MapType creates an instance of a map type value with the provided key and value types. - MapType = decls.MapType + MapType = decls.NewMapType // NullableType creates an instance of a nullable type with the provided wrapped type. // // Note: only primitive types are supported as wrapped types. - NullableType = decls.NullableType + NullableType = decls.NewNullableType // OptionalType creates an abstract parameterized type instance corresponding to CEL's notion of optional. - OptionalType = decls.OptionalType + OptionalType = decls.NewOptionalType // OpaqueType creates an abstract parameterized type with a given name. - OpaqueType = decls.OpaqueType + OpaqueType = decls.NewOpaqueType // ObjectType creates a type references to an externally defined type, e.g. a protobuf message type. - ObjectType = decls.ObjectType + ObjectType = decls.NewObjectType // TypeParamType creates a parameterized type instance. - TypeParamType = decls.TypeParamType + TypeParamType = decls.NewTypeParamType ) // Type holds a reference to a runtime type with an optional type-checked set of type parameters. @@ -338,7 +339,7 @@ func ExprDeclToDeclaration(d *exprpb.Decl) (EnvOption, error) { } } -func typeValueToKind(tv *types.TypeValue) (Kind, error) { +func typeValueToKind(tv ref.Type) (Kind, error) { switch tv { case types.BoolType: return BoolKind, nil diff --git a/cel/io.go b/cel/io.go index 93ded3cf..9e32ede8 100644 --- a/cel/io.go +++ b/cel/io.go @@ -202,7 +202,7 @@ func RefValueToValue(res ref.Val) (*exprpb.Value, error) { } var ( - typeNameToTypeValue = map[string]*types.TypeValue{ + typeNameToTypeValue = map[string]ref.Val{ "bool": types.BoolType, "bytes": types.BytesType, "double": types.DoubleType, diff --git a/cel/program.go b/cel/program.go index 91ac6552..f7d2d626 100644 --- a/cel/program.go +++ b/cel/program.go @@ -208,9 +208,9 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) { } // Enable compile-time checking of syntax/cardinality for string.format calls. if p.evalOpts&OptCheckStringFormat == OptCheckStringFormat { - var isValidType func(id int64, validTypes ...*types.TypeValue) (bool, error) + var isValidType func(id int64, validTypes ...ref.Type) (bool, error) if ast.IsChecked() { - isValidType = func(id int64, validTypes ...*types.TypeValue) (bool, error) { + isValidType = func(id int64, validTypes ...ref.Type) (bool, error) { t, err := ExprTypeToType(ast.typeMap[id]) if err != nil { return false, err @@ -231,7 +231,7 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) { } } else { // if the AST isn't type-checked, short-circuit validation - isValidType = func(id int64, validTypes ...*types.TypeValue) (bool, error) { + isValidType = func(id int64, validTypes ...ref.Type) (bool, error) { return true, nil } } diff --git a/common/decls/decls.go b/common/decls/decls.go index cde2d43b..8472ffe0 100644 --- a/common/decls/decls.go +++ b/common/decls/decls.go @@ -664,6 +664,11 @@ func VariableDeclToExprDecl(v *VariableDecl) (*exprpb.Decl, error) { return chkdecls.NewVar(v.Name, varType), nil } +// TypeVariable creates a new type identifier for use within a ref.TypeProvider +func TypeVariable(t *Type) *VariableDecl { + return NewVariable(t.TypeName(), NewTypeTypeWithParam(t)) +} + // FunctionDeclToExprDecl converts a go-native function declaration into a protobuf-typed function declaration. func FunctionDeclToExprDecl(f *FunctionDecl) (*exprpb.Decl, error) { overloads := make([]*exprpb.Decl_FunctionDecl_Overload, len(f.Overloads)) @@ -707,7 +712,7 @@ func FunctionDeclToExprDecl(f *FunctionDecl) (*exprpb.Decl, error) { func collectParamNames(paramNames map[string]struct{}, arg *Type) { if arg.Kind == TypeParamKind { - paramNames[arg.RuntimeTypeName()] = struct{}{} + paramNames[arg.TypeName()] = struct{}{} } for _, param := range arg.Parameters { collectParamNames(paramNames, param) diff --git a/common/decls/decls_test.go b/common/decls/decls_test.go index e59bba12..9ae1ff94 100644 --- a/common/decls/decls_test.go +++ b/common/decls/decls_test.go @@ -31,7 +31,7 @@ import ( func TestFunctionBindings(t *testing.T) { sizeFunc, err := NewFunction("size", - MemberOverload("list_size", []*Type{ListType(TypeParamType("T"))}, IntType), + MemberOverload("list_size", []*Type{NewListType(NewTypeParamType("T"))}, IntType), ) if err != nil { t.Fatalf("NewFunction() failed: %v", err) @@ -44,7 +44,7 @@ func TestFunctionBindings(t *testing.T) { t.Errorf("sizeFunc.Bindings() produced %d bindings, wanted none", len(bindings)) } sizeFuncDef, err := NewFunction("size", - MemberOverload("list_size", []*Type{ListType(TypeParamType("T"))}, IntType, + MemberOverload("list_size", []*Type{NewListType(NewTypeParamType("T"))}, IntType, UnaryBinding(func(list ref.Val) ref.Val { sizer := list.(traits.Sizer) return sizer.Size() @@ -86,18 +86,18 @@ func TestFunctionVariableArgBindings(t *testing.T) { return types.DefaultTypeAdapter.NativeToValue(strings.SplitN(str, delim, int(count))) } splitFunc, err := NewFunction("split", - MemberOverload("string_split", []*Type{StringType}, ListType(StringType), + MemberOverload("string_split", []*Type{StringType}, NewListType(StringType), UnaryBinding(func(str ref.Val) ref.Val { s := str.(types.String) return splitImpl(string(s), "", -1) })), - MemberOverload("string_split_string", []*Type{StringType, StringType}, ListType(StringType), + MemberOverload("string_split_string", []*Type{StringType, StringType}, NewListType(StringType), BinaryBinding(func(str, sep ref.Val) ref.Val { s := str.(types.String) delim := sep.(types.String) return splitImpl(string(s), string(delim), -1) })), - MemberOverload("string_split_string_int", []*Type{StringType, StringType, IntType}, ListType(StringType), + MemberOverload("string_split_string_int", []*Type{StringType, StringType, IntType}, NewListType(StringType), FunctionBinding(func(args ...ref.Val) ref.Val { s := args[0].(types.String) delim := args[1].(types.String) @@ -196,11 +196,11 @@ func TestFunctionSingletonBinding(t *testing.T) { // doesn't actually give much additional benefit. The drawback is that invalid signatures // at type-check might be valid at runtime. DisableTypeGuards(true), - Overload("size_map", []*Type{MapType(TypeParamType("K"), TypeParamType("V"))}, IntType), - Overload("size_list", []*Type{ListType(TypeParamType("V"))}, IntType), + Overload("size_map", []*Type{NewMapType(NewTypeParamType("K"), NewTypeParamType("V"))}, IntType), + Overload("size_list", []*Type{NewListType(NewTypeParamType("V"))}, IntType), Overload("size_string", []*Type{StringType}, IntType), - MemberOverload("map_size", []*Type{MapType(TypeParamType("K"), TypeParamType("V"))}, IntType), - MemberOverload("list_size", []*Type{ListType(TypeParamType("V"))}, IntType), + MemberOverload("map_size", []*Type{NewMapType(NewTypeParamType("K"), NewTypeParamType("V"))}, IntType), + MemberOverload("list_size", []*Type{NewListType(NewTypeParamType("V"))}, IntType), MemberOverload("string_size", []*Type{StringType}, IntType), SingletonUnaryBinding(func(arg ref.Val) ref.Val { return arg.(traits.Sizer).Size() @@ -232,8 +232,8 @@ func TestFunctionSingletonBinding(t *testing.T) { func TestFunctionMerge(t *testing.T) { sizeFunc, err := NewFunction("size", - MemberOverload("list_size", []*Type{ListType(TypeParamType("T"))}, IntType), - MemberOverload("map_size", []*Type{MapType(TypeParamType("K"), TypeParamType("V"))}, IntType), + MemberOverload("list_size", []*Type{NewListType(NewTypeParamType("T"))}, IntType), + MemberOverload("map_size", []*Type{NewMapType(NewTypeParamType("K"), NewTypeParamType("V"))}, IntType), ) if err != nil { t.Fatalf("NewFunction() failed: %v", err) @@ -246,7 +246,7 @@ func TestFunctionMerge(t *testing.T) { t.Errorf("sizeFunc.Merge(sizeFunc) != sizeFunc: %v", out) } sizeVecFunc, err := NewFunction("size", - MemberOverload("vector_size", []*Type{OpaqueType("vector", TypeParamType("T"))}, IntType), + MemberOverload("vector_size", []*Type{NewOpaqueType("vector", NewTypeParamType("T"))}, IntType), SingletonUnaryBinding(func(sizer ref.Val) ref.Val { return sizer.(traits.Sizer).Size() }, traits.SizerType), @@ -279,13 +279,13 @@ func TestFunctionMerge(t *testing.T) { func TestFunctionMergeWrongName(t *testing.T) { sizeFunc, err := NewFunction("size", - MemberOverload("list_size", []*Type{ListType(TypeParamType("T"))}, IntType), + MemberOverload("list_size", []*Type{NewListType(NewTypeParamType("T"))}, IntType), ) if err != nil { t.Fatalf("NewFunction() failed: %v", err) } sizeVecFunc, err := NewFunction("sizeN", - MemberOverload("vector_size", []*Type{OpaqueType("vector", TypeParamType("T"))}, IntType), + MemberOverload("vector_size", []*Type{NewOpaqueType("vector", NewTypeParamType("T"))}, IntType), ) if err != nil { t.Fatalf("NewFunction() failed: %v", err) @@ -298,13 +298,13 @@ func TestFunctionMergeWrongName(t *testing.T) { func TestFunctionMergeOverloadCollision(t *testing.T) { sizeFunc, err := NewFunction("size", - MemberOverload("list_size", []*Type{ListType(TypeParamType("T"))}, IntType), + MemberOverload("list_size", []*Type{NewListType(NewTypeParamType("T"))}, IntType), ) if err != nil { t.Fatalf("NewFunction() failed: %v", err) } sizeVecFunc, err := NewFunction("size", - MemberOverload("list_size2", []*Type{ListType(TypeParamType("K"))}, IntType), + MemberOverload("list_size2", []*Type{NewListType(NewTypeParamType("K"))}, IntType), ) if err != nil { t.Fatalf("NewFunction() failed: %v", err) @@ -317,13 +317,13 @@ func TestFunctionMergeOverloadCollision(t *testing.T) { func TestFunctionMergeOverloadArgCountRedefinition(t *testing.T) { sizeFunc, err := NewFunction("size", - MemberOverload("list_size", []*Type{ListType(TypeParamType("T"))}, IntType), + MemberOverload("list_size", []*Type{NewListType(NewTypeParamType("T"))}, IntType), ) if err != nil { t.Fatalf("NewFunction() failed: %v", err) } sizeVecFunc, err := NewFunction("size", - MemberOverload("list_size", []*Type{ListType(TypeParamType("T")), IntType}, IntType), + MemberOverload("list_size", []*Type{NewListType(NewTypeParamType("T")), IntType}, IntType), ) if err != nil { t.Fatalf("NewFunction() failed: %v", err) @@ -336,13 +336,13 @@ func TestFunctionMergeOverloadArgCountRedefinition(t *testing.T) { func TestFunctionMergeOverloadArgTypeRedefinition(t *testing.T) { sizeFunc, err := NewFunction("size", - MemberOverload("arg_size", []*Type{ListType(TypeParamType("T"))}, IntType), + MemberOverload("arg_size", []*Type{NewListType(NewTypeParamType("T"))}, IntType), ) if err != nil { t.Fatalf("NewFunction() failed: %v", err) } sizeVecFunc, err := NewFunction("size", - MemberOverload("arg_size", []*Type{MapType(IntType, StringType)}, IntType), + MemberOverload("arg_size", []*Type{NewMapType(IntType, StringType)}, IntType), ) if err != nil { t.Fatalf("NewFunction() failed: %v", err) @@ -355,7 +355,7 @@ func TestFunctionMergeOverloadArgTypeRedefinition(t *testing.T) { func TestFunctionMergeSingletonRedefinition(t *testing.T) { sizeFunc, err := NewFunction("size", - MemberOverload("list_size", []*Type{ListType(TypeParamType("T"))}, IntType), + MemberOverload("list_size", []*Type{NewListType(NewTypeParamType("T"))}, IntType), SingletonUnaryBinding(func(ref.Val) ref.Val { return types.IntZero }), @@ -549,12 +549,12 @@ func TestOverloadFunctionBindingRedefinition(t *testing.T) { func TestOverloadIsNonStrict(t *testing.T) { fn, err := NewFunction("getOrDefault", MemberOverload("get", - []*Type{MapType( - TypeParamType("K"), TypeParamType("V")), - TypeParamType("K"), - TypeParamType("V"), + []*Type{NewMapType( + NewTypeParamType("K"), NewTypeParamType("V")), + NewTypeParamType("K"), + NewTypeParamType("V"), }, - TypeParamType("V"), + NewTypeParamType("V"), OverloadOperandTrait(traits.ContainerType|traits.IndexerType), OverloadIsNonStrict(), FunctionBinding(func(args ...ref.Val) ref.Val { @@ -596,12 +596,12 @@ func TestOverloadIsNonStrict(t *testing.T) { func TestOverloadOperandTrait(t *testing.T) { fn, err := NewFunction("getOrDefault", MemberOverload("get", - []*Type{MapType( - TypeParamType("K"), TypeParamType("V")), - TypeParamType("K"), - TypeParamType("V"), + []*Type{NewMapType( + NewTypeParamType("K"), NewTypeParamType("V")), + NewTypeParamType("K"), + NewTypeParamType("V"), }, - TypeParamType("V"), + NewTypeParamType("V"), OverloadOperandTrait(traits.ContainerType|traits.IndexerType), FunctionBinding(func(args ...ref.Val) ref.Val { container := args[0].(traits.Container) @@ -641,7 +641,7 @@ func TestFunctionDisableDeclaration(t *testing.T) { fn, err := NewFunction("in", DisableDeclaration(true), Overload("in_list", - []*Type{ListType(TypeParamType("K")), TypeParamType("K")}, + []*Type{NewListType(NewTypeParamType("K")), NewTypeParamType("K")}, BoolType, ), ) @@ -657,7 +657,7 @@ func TestFunctionEnableDeclaration(t *testing.T) { fn, err := NewFunction("in", DisableDeclaration(false), Overload("in_list", - []*Type{ListType(TypeParamType("K")), TypeParamType("K")}, + []*Type{NewListType(NewTypeParamType("K")), NewTypeParamType("K")}, BoolType), ) if err != nil { @@ -669,7 +669,7 @@ func TestFunctionEnableDeclaration(t *testing.T) { fn2, err := NewFunction("in", DisableDeclaration(true), Overload("in_list", - []*Type{ListType(TypeParamType("K")), TypeParamType("K")}, + []*Type{NewListType(NewTypeParamType("K")), NewTypeParamType("K")}, BoolType), ) if err != nil { @@ -704,7 +704,7 @@ func TestFunctionDeclToExprDecl(t *testing.T) { }{ { fn: testFunction(t, "equals", - Overload("equals_value_value", []*Type{TypeParamType("T"), TypeParamType("T")}, BoolType)), + Overload("equals_value_value", []*Type{NewTypeParamType("T"), NewTypeParamType("T")}, BoolType)), exDecl: &exprpb.Decl{ Name: "equals", DeclKind: &exprpb.Decl_Function{ @@ -726,7 +726,7 @@ func TestFunctionDeclToExprDecl(t *testing.T) { }, { fn: testFunction(t, "equals", - MemberOverload("value_equals_value", []*Type{TypeParamType("T"), TypeParamType("T")}, BoolType)), + MemberOverload("value_equals_value", []*Type{NewTypeParamType("T"), NewTypeParamType("T")}, BoolType)), exDecl: &exprpb.Decl{ Name: "equals", DeclKind: &exprpb.Decl_Function{ @@ -793,8 +793,8 @@ func TestFunctionDeclToExprDecl(t *testing.T) { { fn: testFunction(t, "equals", MemberOverload("list_optional_value_equals_list_optional_value", []*Type{ - ListType(OptionalType(TypeParamType("T"))), - ListType(OptionalType(TypeParamType("T"))), + NewListType(NewOptionalType(NewTypeParamType("T"))), + NewListType(NewOptionalType(NewTypeParamType("T"))), }, BoolType)), exDecl: &exprpb.Decl{ Name: "equals", diff --git a/common/decls/types.go b/common/decls/types.go index 0970102f..f03e0efa 100644 --- a/common/decls/types.go +++ b/common/decls/types.go @@ -16,6 +16,7 @@ package decls import ( "fmt" + "reflect" "strings" chkdecls "github.com/google/cel-go/checker/decls" @@ -48,6 +49,9 @@ const ( // DurationKind represents a CEL duration type. DurationKind + // ErrorKind represents a CEL error type. + ErrorKind + // IntKind represents an integer type. IntKind @@ -80,71 +84,131 @@ const ( // UintKind represents a uint type. UintKind + + // UnknownKind represents an unknown value type. + UnknownKind ) var ( // AnyType represents the google.protobuf.Any type. AnyType = &Type{ - Kind: AnyKind, - runtimeType: types.NewTypeValue("google.protobuf.Any"), + Kind: AnyKind, + runtimeTypeName: "google.protobuf.Any", + traitMask: traits.FieldTesterType | + traits.IndexerType, } // BoolType represents the bool type. BoolType = &Type{ - Kind: BoolKind, - runtimeType: types.BoolType, + Kind: BoolKind, + runtimeTypeName: "bool", + traitMask: traits.ComparerType | + traits.NegatorType, } // BytesType represents the bytes type. BytesType = &Type{ - Kind: BytesKind, - runtimeType: types.BytesType, + Kind: BytesKind, + runtimeTypeName: "bytes", + traitMask: traits.AdderType | + traits.ComparerType | + traits.SizerType, } // DoubleType represents the double type. DoubleType = &Type{ - Kind: DoubleKind, - runtimeType: types.DoubleType, + Kind: DoubleKind, + runtimeTypeName: "double", + traitMask: traits.AdderType | + traits.ComparerType | + traits.DividerType | + traits.MultiplierType | + traits.NegatorType | + traits.SubtractorType, } // DurationType represents the CEL duration type. DurationType = &Type{ - Kind: DurationKind, - runtimeType: types.DurationType, + Kind: DurationKind, + runtimeTypeName: "google.protobuf.Duration", + traitMask: traits.AdderType | + traits.ComparerType | + traits.NegatorType | + traits.ReceiverType | + traits.SubtractorType, } // DynType represents a dynamic CEL type whose type will be determined at runtime from context. DynType = &Type{ - Kind: DynKind, - runtimeType: types.NewTypeValue("dyn"), + Kind: DynKind, + runtimeTypeName: "dyn", + } + // ErrorType represents a CEL error value. + ErrorType = &Type{ + Kind: ErrorKind, + runtimeTypeName: "error", } // IntType represents the int type. IntType = &Type{ - Kind: IntKind, - runtimeType: types.IntType, - } + Kind: IntKind, + runtimeTypeName: "int", + traitMask: traits.AdderType | + traits.ComparerType | + traits.DividerType | + traits.ModderType | + traits.MultiplierType | + traits.NegatorType | + traits.SubtractorType, + } + // ListType represents the runtime list type. + ListType = NewListType(nil) + // MapType represents the runtime map type. + MapType = NewMapType(nil, nil) // NullType represents the type of a null value. NullType = &Type{ - Kind: NullTypeKind, - runtimeType: types.NullType, + Kind: NullTypeKind, + runtimeTypeName: "null_type", } // StringType represents the string type. StringType = &Type{ - Kind: StringKind, - runtimeType: types.StringType, + Kind: StringKind, + runtimeTypeName: "string", + traitMask: traits.AdderType | + traits.ComparerType | + traits.MatcherType | + traits.ReceiverType | + traits.SizerType, } // TimestampType represents the time type. TimestampType = &Type{ - Kind: TimestampKind, - runtimeType: types.TimestampType, + Kind: TimestampKind, + runtimeTypeName: "google.protobuf.Timestamp", + traitMask: traits.AdderType | + traits.ComparerType | + traits.ReceiverType | + traits.SubtractorType, } // TypeType represents a CEL type TypeType = &Type{ - Kind: TypeKind, - runtimeType: types.TypeType, + Kind: TypeKind, + runtimeTypeName: "type", } // UintType represents a uint type. UintType = &Type{ - Kind: UintKind, - runtimeType: types.UintType, + Kind: UintKind, + runtimeTypeName: "uint", + traitMask: traits.AdderType | + traits.ComparerType | + traits.DividerType | + traits.ModderType | + traits.MultiplierType | + traits.SubtractorType, + } + // UnknownType represents an unknown value type. + UnknownType = &Type{ + Kind: UnknownKind, + runtimeTypeName: "unknown", } ) +var _ ref.Type = &Type{} +var _ ref.Val = &Type{} + // Type holds a reference to a runtime type with an optional type-checked set of type parameters. type Type struct { // Kind indicates general category of the type. @@ -153,8 +217,8 @@ type Type struct { // Parameters holds the optional type-checked set of type Parameters that are used during static analysis. Parameters []*Type - // runtimeType is the runtime type of the declaration. - runtimeType ref.Type + // runtimeTypeName indicates the runtime type name of the type. + runtimeTypeName string // isAssignableType function determines whether one type is assignable to this type. // A nil value for the isAssignableType function falls back to equality of kind, runtimeType, and parameters. @@ -163,6 +227,39 @@ type Type struct { // isAssignableRuntimeType function determines whether the runtime type (with erasure) is assignable to this type. // A nil value for the isAssignableRuntimeType function falls back to the equality of the type or type name. isAssignableRuntimeType func(other ref.Val) bool + + // traitMask is a mask of flags which indicate the capabilities of the type. + traitMask int +} + +// ConvertToNative implements ref.Val.ConvertToNative. +func (t *Type) ConvertToNative(typeDesc reflect.Type) (any, error) { + return nil, fmt.Errorf("type conversion not supported for 'type'") +} + +// ConvertToType implements ref.Val.ConvertToType. +func (t *Type) ConvertToType(typeVal ref.Type) ref.Val { + switch typeVal { + case TypeType: + return TypeType + case StringType: + return types.String(t.TypeName()) + } + return types.NewErr("type conversion error from '%s' to '%s'", TypeType, typeVal) +} + +// Equal indicates whether two types have the same runtime type name. +// +// The name Equal is a bit of a misnomer, but for historical reasons, this is the +// runtime behavior. For a more accurate definition see IsType(). +func (t *Type) Equal(other ref.Val) ref.Val { + otherType, ok := other.(ref.Type) + return types.Bool(ok && t.TypeName() == otherType.TypeName()) +} + +// HasTrait implements the ref.Type interface method. +func (t *Type) HasTrait(trait int) bool { + return trait&t.traitMask == trait } // IsType indicates whether two types have the same kind, type name, and parameters. @@ -200,26 +297,30 @@ func (t *Type) IsAssignableRuntimeType(val ref.Val) bool { return t.defaultIsAssignableRuntimeType(val) } -// DeclaredTypeName indicates the type-check type name associated with the type. +// DeclaredTypeName indicates the fully qualified and parameterized type-check type name. func (t *Type) DeclaredTypeName() string { // if the type itself is neither null, nor dyn, but is assignable to null, then it's a wrapper type. if t.Kind != NullTypeKind && !t.isDyn() && t.IsAssignableType(NullType) { - return fmt.Sprintf("wrapper(%s)", t.RuntimeTypeName()) + return fmt.Sprintf("wrapper(%s)", t.TypeName()) } - return t.RuntimeTypeName() + return t.TypeName() } -// RuntimeTypeName indicates the type-erased type name associated with the type. -func (t *Type) RuntimeTypeName() string { - if t.runtimeType == nil { - return "" - } - return t.runtimeType.TypeName() +// Type implements the ref.Val interface method. +func (t *Type) Type() ref.Type { + return TypeType +} + +// Value implements the ref.Val interface method. +func (t *Type) Value() any { + return t.TypeName() } -// TypeVariable creates a new type identifier for use within a ref.TypeProvider -func (t *Type) TypeVariable() *VariableDecl { - return NewVariable(t.RuntimeTypeName(), TypeTypeWithParam(t)) +// TypeName returns the type-erased fully qualified runtime type name. +// +// TypeName implements the ref.Type interface method. +func (t *Type) TypeName() string { + return t.runtimeTypeName } // String returns a human-readable definition of the type name. @@ -251,7 +352,7 @@ func (t *Type) defaultIsAssignableType(fromType *Type) bool { return true } if t.Kind != fromType.Kind || - t.runtimeType.TypeName() != fromType.runtimeType.TypeName() || + t.TypeName() != fromType.TypeName() || len(t.Parameters) != len(fromType.Parameters) { return false } @@ -268,22 +369,21 @@ func (t *Type) defaultIsAssignableType(fromType *Type) bool { // to determine whether a ref.Val is assignable to the declared type for a function signature. func (t *Type) defaultIsAssignableRuntimeType(val ref.Val) bool { valType := val.Type() - if !(t.runtimeType == valType || t.isDyn() || t.runtimeType.TypeName() == valType.TypeName()) { + // If the current type and value type don't agree, then return + if !(t.isDyn() || t.TypeName() == valType.TypeName()) { return false } - switch t.runtimeType { - case types.ListType: + switch t.Kind { + case ListKind: elemType := t.Parameters[0] l := val.(traits.Lister) if l.Size() == types.IntZero { return true } it := l.Iterator() - for it.HasNext() == types.True { - elemVal := it.Next() - return elemType.IsAssignableRuntimeType(elemVal) - } - case types.MapType: + elemVal := it.Next() + return elemType.IsAssignableRuntimeType(elemVal) + case MapKind: keyType := t.Parameters[0] elemType := t.Parameters[1] m := val.(traits.Mapper) @@ -291,41 +391,57 @@ func (t *Type) defaultIsAssignableRuntimeType(val ref.Val) bool { return true } it := m.Iterator() - for it.HasNext() == types.True { - keyVal := it.Next() - elemVal := m.Get(keyVal) - return keyType.IsAssignableRuntimeType(keyVal) && elemType.IsAssignableRuntimeType(elemVal) - } + keyVal := it.Next() + elemVal := m.Get(keyVal) + return keyType.IsAssignableRuntimeType(keyVal) && elemType.IsAssignableRuntimeType(elemVal) } return true } -// ListType creates an instances of a list type value with the provided element type. -func ListType(elemType *Type) *Type { - return &Type{ - Kind: ListKind, - runtimeType: types.ListType, - Parameters: []*Type{elemType}, - } +// NewListType creates an instances of a list type value with the provided element type. +func NewListType(elemType *Type) *Type { + t := &Type{ + Kind: ListKind, + Parameters: []*Type{}, + runtimeTypeName: "list", + traitMask: traits.AdderType | + traits.ContainerType | + traits.IndexerType | + traits.IterableType | + traits.SizerType, + } + if elemType != nil { + t.Parameters = append(t.Parameters, elemType) + } + return t } -// MapType creates an instance of a map type value with the provided key and value types. -func MapType(keyType, valueType *Type) *Type { - return &Type{ - Kind: MapKind, - runtimeType: types.MapType, - Parameters: []*Type{keyType, valueType}, - } +// NewMapType creates an instance of a map type value with the provided key and value types. +func NewMapType(keyType, valueType *Type) *Type { + t := &Type{ + Kind: MapKind, + Parameters: []*Type{}, + runtimeTypeName: "map", + traitMask: traits.ContainerType | + traits.IndexerType | + traits.IterableType | + traits.SizerType, + } + if keyType != nil && valueType != nil { + t.Parameters = append(t.Parameters, keyType, valueType) + } + return t } -// NullableType creates an instance of a nullable type with the provided wrapped type. +// NewNullableType creates an instance of a nullable type with the provided wrapped type. // // Note: only primitive types are supported as wrapped types. -func NullableType(wrapped *Type) *Type { +func NewNullableType(wrapped *Type) *Type { return &Type{ - Kind: wrapped.Kind, - runtimeType: wrapped.runtimeType, - Parameters: wrapped.Parameters, + Kind: wrapped.Kind, + Parameters: wrapped.Parameters, + runtimeTypeName: wrapped.runtimeTypeName, + traitMask: wrapped.traitMask, isAssignableType: func(other *Type) bool { return NullType.IsAssignableType(other) || wrapped.IsAssignableType(other) }, @@ -335,47 +451,73 @@ func NullableType(wrapped *Type) *Type { } } -// OptionalType creates an abstract parameterized type instance corresponding to CEL's notion of optional. -func OptionalType(param *Type) *Type { - return OpaqueType("optional", param) +// NewOptionalType creates an abstract parameterized type instance corresponding to CEL's notion of optional. +func NewOptionalType(param *Type) *Type { + return NewOpaqueType("optional", param) } -// OpaqueType creates an abstract parameterized type with a given name. -func OpaqueType(name string, params ...*Type) *Type { +// NewOpaqueType creates an abstract parameterized type with a given name. +func NewOpaqueType(name string, params ...*Type) *Type { return &Type{ - Kind: OpaqueKind, - runtimeType: types.NewTypeValue(name), - Parameters: params, + Kind: OpaqueKind, + Parameters: params, + runtimeTypeName: name, } } -// ObjectType creates a type references to an externally defined type, e.g. a protobuf message type. -func ObjectType(typeName string) *Type { +// NewObjectType creates a type reference to an externally defined type, e.g. a protobuf message type. +func NewObjectType(typeName string) *Type { // Function sanitizes object types on the fly if wkt, found := checkedWellKnowns[typeName]; found { return wkt } return &Type{ - Kind: StructKind, - runtimeType: types.NewObjectTypeValue(typeName), + Kind: StructKind, + Parameters: []*Type{}, + runtimeTypeName: typeName, + traitMask: traits.FieldTesterType | traits.IndexerType, + } +} + +// NewObjectTypeValue creates a type reference to an externally defined type. +// +// Deprecated: use cel.ObjectType(typeName) +func NewObjectTypeValue(typeName string) *Type { + return NewObjectType(typeName) +} + +// NewTypeValue creates an opaque type which has a set of optional type traits as defined in +// the common/types/traits package. +// +// Deprecated: use cel.OpaqueType(typeName) +func NewTypeValue(typeName string, traits ...int) *Type { + traitMask := 0 + for _, trait := range traits { + traitMask |= trait + } + return &Type{ + Kind: OpaqueKind, + Parameters: []*Type{}, + runtimeTypeName: typeName, + traitMask: traitMask, } } -// TypeParamType creates a parameterized type instance. -func TypeParamType(paramName string) *Type { +// NewTypeParamType creates a parameterized type instance. +func NewTypeParamType(paramName string) *Type { return &Type{ - Kind: TypeParamKind, - runtimeType: types.NewTypeValue(paramName), + Kind: TypeParamKind, + runtimeTypeName: paramName, } } -// TypeTypeWithParam creates a type with a type parameter. +// NewTypeTypeWithParam creates a type with a type parameter. // Used for type-checking purposes, but equivalent to TypeType otherwise. -func TypeTypeWithParam(param *Type) *Type { +func NewTypeTypeWithParam(param *Type) *Type { return &Type{ - Kind: TypeKind, - Parameters: []*Type{param}, - runtimeType: types.TypeType, + Kind: TypeKind, + runtimeTypeName: "type", + Parameters: []*Type{param}, } } @@ -429,15 +571,15 @@ func TypeToExprType(t *Type) (*exprpb.Type, error) { } params[i] = pt } - return chkdecls.NewAbstractType(t.RuntimeTypeName(), params...), nil + return chkdecls.NewAbstractType(t.TypeName(), params...), nil case StringKind: return maybeWrapper(t, chkdecls.String), nil case StructKind: - return chkdecls.NewObjectType(t.RuntimeTypeName()), nil + return chkdecls.NewObjectType(t.TypeName()), nil case TimestampKind: return chkdecls.Timestamp, nil case TypeParamKind: - return chkdecls.NewTypeParamType(t.RuntimeTypeName()), nil + return chkdecls.NewTypeParamType(t.TypeName()), nil case TypeKind: if len(t.Parameters) == 1 { p, err := TypeToExprType(t.Parameters[0]) @@ -467,13 +609,13 @@ func ExprTypeToType(t *exprpb.Type) (*Type, error) { } paramTypes[i] = pt } - return OpaqueType(t.GetAbstractType().GetName(), paramTypes...), nil + return NewOpaqueType(t.GetAbstractType().GetName(), paramTypes...), nil case *exprpb.Type_ListType_: et, err := ExprTypeToType(t.GetListType().GetElemType()) if err != nil { return nil, err } - return ListType(et), nil + return NewListType(et), nil case *exprpb.Type_MapType_: kt, err := ExprTypeToType(t.GetMapType().GetKeyType()) if err != nil { @@ -483,9 +625,9 @@ func ExprTypeToType(t *exprpb.Type) (*Type, error) { if err != nil { return nil, err } - return MapType(kt, vt), nil + return NewMapType(kt, vt), nil case *exprpb.Type_MessageType: - return ObjectType(t.GetMessageType()), nil + return NewObjectType(t.GetMessageType()), nil case *exprpb.Type_Null: return NullType, nil case *exprpb.Type_Primitive: @@ -506,14 +648,14 @@ func ExprTypeToType(t *exprpb.Type) (*Type, error) { return nil, fmt.Errorf("unsupported primitive type: %v", t) } case *exprpb.Type_TypeParam: - return TypeParamType(t.GetTypeParam()), nil + return NewTypeParamType(t.GetTypeParam()), nil case *exprpb.Type_Type: if t.GetType().GetTypeKind() != nil { p, err := ExprTypeToType(t.GetType()) if err != nil { return nil, err } - return TypeTypeWithParam(p), nil + return NewTypeTypeWithParam(p), nil } return TypeType, nil case *exprpb.Type_WellKnown: @@ -532,7 +674,7 @@ func ExprTypeToType(t *exprpb.Type) (*Type, error) { if err != nil { return nil, err } - return NullableType(t), nil + return NewNullableType(t), nil default: return nil, fmt.Errorf("unsupported type: %v", t) } @@ -548,23 +690,23 @@ func maybeWrapper(t *Type, pbType *exprpb.Type) *exprpb.Type { var ( checkedWellKnowns = map[string]*Type{ // Wrapper types. - "google.protobuf.BoolValue": NullableType(BoolType), - "google.protobuf.BytesValue": NullableType(BytesType), - "google.protobuf.DoubleValue": NullableType(DoubleType), - "google.protobuf.FloatValue": NullableType(DoubleType), - "google.protobuf.Int64Value": NullableType(IntType), - "google.protobuf.Int32Value": NullableType(IntType), - "google.protobuf.UInt64Value": NullableType(UintType), - "google.protobuf.UInt32Value": NullableType(UintType), - "google.protobuf.StringValue": NullableType(StringType), + "google.protobuf.BoolValue": NewNullableType(BoolType), + "google.protobuf.BytesValue": NewNullableType(BytesType), + "google.protobuf.DoubleValue": NewNullableType(DoubleType), + "google.protobuf.FloatValue": NewNullableType(DoubleType), + "google.protobuf.Int64Value": NewNullableType(IntType), + "google.protobuf.Int32Value": NewNullableType(IntType), + "google.protobuf.UInt64Value": NewNullableType(UintType), + "google.protobuf.UInt32Value": NewNullableType(UintType), + "google.protobuf.StringValue": NewNullableType(StringType), // Well-known types. "google.protobuf.Any": AnyType, "google.protobuf.Duration": DurationType, "google.protobuf.Timestamp": TimestampType, // Json types. - "google.protobuf.ListValue": ListType(DynType), + "google.protobuf.ListValue": NewListType(DynType), "google.protobuf.NullValue": NullType, - "google.protobuf.Struct": MapType(StringType, DynType), + "google.protobuf.Struct": NewMapType(StringType, DynType), "google.protobuf.Value": DynType, } ) diff --git a/common/decls/types_test.go b/common/decls/types_test.go index a23225d1..461c5190 100644 --- a/common/decls/types_test.go +++ b/common/decls/types_test.go @@ -15,6 +15,8 @@ package decls import ( + "errors" + "reflect" "strings" "testing" "time" @@ -24,6 +26,8 @@ import ( chkdecls "github.com/google/cel-go/checker/decls" "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) @@ -34,11 +38,11 @@ func TestTypeString(t *testing.T) { out string }{ { - in: ListType(IntType), + in: NewListType(IntType), out: "list(int)", }, { - in: MapType(UintType, DoubleType), + in: NewMapType(UintType, DoubleType), out: "map(uint, double)", }, { @@ -54,35 +58,35 @@ func TestTypeString(t *testing.T) { out: "null_type", }, { - in: NullableType(BoolType), + in: NewNullableType(BoolType), out: "wrapper(bool)", }, { - in: OptionalType(ListType(StringType)), + in: NewOptionalType(NewListType(StringType)), out: "optional(list(string))", }, { - in: ObjectType("my.type.Message"), + in: NewObjectType("my.type.Message"), out: "my.type.Message", }, { - in: ObjectType("google.protobuf.Int32Value"), + in: NewObjectType("google.protobuf.Int32Value"), out: "wrapper(int)", }, { - in: ObjectType("google.protobuf.UInt32Value"), + in: NewObjectType("google.protobuf.UInt32Value"), out: "wrapper(uint)", }, { - in: ObjectType("google.protobuf.Value"), + in: NewObjectType("google.protobuf.Value"), out: "dyn", }, { - in: TypeTypeWithParam(StringType), + in: NewTypeTypeWithParam(StringType), out: "type(string)", }, { - in: TypeParamType("T"), + in: NewTypeParamType("T"), out: "T", }, } @@ -110,28 +114,28 @@ func TestTypeIsType(t *testing.T) { isType: false, }, { - t1: OptionalType(StringType), - t2: OptionalType(IntType), + t1: NewOptionalType(StringType), + t2: NewOptionalType(IntType), isType: false, }, { - t1: OptionalType(UintType), - t2: OptionalType(UintType), + t1: NewOptionalType(UintType), + t2: NewOptionalType(UintType), isType: true, }, { - t1: MapType(BoolType, IntType), - t2: MapType(BoolType, IntType), + t1: NewMapType(BoolType, IntType), + t2: NewMapType(BoolType, IntType), isType: true, }, { - t1: MapType(TypeParamType("K1"), IntType), - t2: MapType(TypeParamType("K2"), IntType), + t1: NewMapType(NewTypeParamType("K1"), IntType), + t2: NewMapType(NewTypeParamType("K2"), IntType), isType: true, }, { - t1: MapType(TypeParamType("K1"), ObjectType("my.msg.First")), - t2: MapType(TypeParamType("K2"), ObjectType("my.msg.Last")), + t1: NewMapType(NewTypeParamType("K1"), NewObjectType("my.msg.First")), + t2: NewMapType(NewTypeParamType("K2"), NewObjectType("my.msg.Last")), isType: false, }, } @@ -149,24 +153,24 @@ func TestTypeTypeVariable(t *testing.T) { }{ { t: AnyType, - v: NewVariable("google.protobuf.Any", TypeTypeWithParam(AnyType)), + v: NewVariable("google.protobuf.Any", NewTypeTypeWithParam(AnyType)), }, { t: DynType, - v: NewVariable("dyn", TypeTypeWithParam(DynType)), + v: NewVariable("dyn", NewTypeTypeWithParam(DynType)), }, { - t: ObjectType("google.protobuf.Int32Value"), - v: NewVariable("int", TypeTypeWithParam(NullableType(IntType))), + t: NewObjectType("google.protobuf.Int32Value"), + v: NewVariable("int", NewTypeTypeWithParam(NewNullableType(IntType))), }, { - t: ObjectType("google.protobuf.Int32Value"), - v: NewVariable("int", TypeTypeWithParam(NullableType(IntType))), + t: NewObjectType("google.protobuf.Int32Value"), + v: NewVariable("int", NewTypeTypeWithParam(NewNullableType(IntType))), }, } for _, tst := range tests { - if !tst.t.TypeVariable().DeclarationEquals(tst.v) { - t.Errorf("got not equal %v.Equals(%v)", tst.t.TypeVariable(), tst.v) + if !TypeVariable(tst.t).DeclarationEquals(tst.v) { + t.Errorf("got not equal %v.Equals(%v)", TypeVariable(tst.t), tst.v) } } } @@ -178,58 +182,58 @@ func TestTypeIsAssignableType(t *testing.T) { isAssignable bool }{ { - t1: NullableType(DoubleType), + t1: NewNullableType(DoubleType), t2: NullType, isAssignable: true, }, { - t1: NullableType(DoubleType), + t1: NewNullableType(DoubleType), t2: DoubleType, isAssignable: true, }, { - t1: OpaqueType("vector", NullableType(DoubleType)), - t2: OpaqueType("vector", NullType), + t1: NewOpaqueType("vector", NewNullableType(DoubleType)), + t2: NewOpaqueType("vector", NullType), isAssignable: true, }, { - t1: OpaqueType("vector", NullableType(DoubleType)), - t2: OpaqueType("vector", DoubleType), + t1: NewOpaqueType("vector", NewNullableType(DoubleType)), + t2: NewOpaqueType("vector", DoubleType), isAssignable: true, }, { - t1: OpaqueType("vector", DynType), - t2: OpaqueType("vector", NullableType(IntType)), + t1: NewOpaqueType("vector", DynType), + t2: NewOpaqueType("vector", NewNullableType(IntType)), isAssignable: true, }, { - t1: ObjectType("my.msg.MsgName"), - t2: ObjectType("my.msg.MsgName"), + t1: NewObjectType("my.msg.MsgName"), + t2: NewObjectType("my.msg.MsgName"), isAssignable: true, }, { - t1: MapType(TypeParamType("K"), IntType), - t2: MapType(StringType, IntType), + t1: NewMapType(NewTypeParamType("K"), IntType), + t2: NewMapType(StringType, IntType), isAssignable: true, }, { - t1: MapType(StringType, IntType), - t2: MapType(TypeParamType("K"), IntType), + t1: NewMapType(StringType, IntType), + t2: NewMapType(NewTypeParamType("K"), IntType), isAssignable: false, }, { - t1: OpaqueType("vector", DoubleType), - t2: OpaqueType("vector", NullableType(IntType)), + t1: NewOpaqueType("vector", DoubleType), + t2: NewOpaqueType("vector", NewNullableType(IntType)), isAssignable: false, }, { - t1: OpaqueType("vector", NullableType(DoubleType)), - t2: OpaqueType("vector", DynType), + t1: NewOpaqueType("vector", NewNullableType(DoubleType)), + t2: NewOpaqueType("vector", DynType), isAssignable: false, }, { - t1: ObjectType("my.msg.MsgName"), - t2: ObjectType("my.msg.MsgName2"), + t1: NewObjectType("my.msg.MsgName"), + t2: NewObjectType("my.msg.MsgName2"), isAssignable: false, }, } @@ -241,9 +245,9 @@ func TestTypeIsAssignableType(t *testing.T) { } func TestTypeSignatureEquals(t *testing.T) { - paramA := TypeParamType("A") - paramB := TypeParamType("B") - mapOfAB := MapType(paramA, paramB) + paramA := NewTypeParamType("A") + paramB := NewTypeParamType("B") + mapOfAB := NewMapType(paramA, paramB) fn, err := NewFunction(overloads.Size, MemberOverload(overloads.SizeMapInst, []*Type{mapOfAB}, IntType), Overload(overloads.SizeMap, []*Type{mapOfAB}, IntType)) @@ -259,25 +263,25 @@ func TestTypeSignatureEquals(t *testing.T) { } func TestTypeIsAssignableRuntimeType(t *testing.T) { - if !NullableType(DoubleType).IsAssignableRuntimeType(types.NullValue) { + if !NewNullableType(DoubleType).IsAssignableRuntimeType(types.NullValue) { t.Error("nullable double cannot be assigned from null") } - if !NullableType(DoubleType).IsAssignableRuntimeType(types.Double(0.0)) { + if !NewNullableType(DoubleType).IsAssignableRuntimeType(types.Double(0.0)) { t.Error("nullable double cannot be assigned from double") } - if !MapType(StringType, DurationType).IsAssignableRuntimeType( + if !NewMapType(StringType, DurationType).IsAssignableRuntimeType( types.DefaultTypeAdapter.NativeToValue(map[string]time.Duration{})) { t.Error("map(string, duration) not assignable to map at runtime") } - if !MapType(StringType, DurationType).IsAssignableRuntimeType( + if !NewMapType(StringType, DurationType).IsAssignableRuntimeType( types.DefaultTypeAdapter.NativeToValue(map[string]time.Duration{"one": time.Duration(1)})) { t.Error("map(string, duration) not assignable to map at runtime") } - if !MapType(StringType, DynType).IsAssignableRuntimeType( + if !NewMapType(StringType, DynType).IsAssignableRuntimeType( types.DefaultTypeAdapter.NativeToValue(map[string]time.Duration{"one": time.Duration(1)})) { t.Error("map(string, dyn) not assignable to map at runtime") } - if MapType(StringType, DynType).IsAssignableRuntimeType( + if NewMapType(StringType, DynType).IsAssignableRuntimeType( types.DefaultTypeAdapter.NativeToValue(map[int64]time.Duration{1: time.Duration(1)})) { t.Error("map(string, dyn) must not be assignable to map(int, duration) at runtime") } @@ -290,7 +294,7 @@ func TestTypeToExprType(t *testing.T) { unidirectional bool }{ { - in: OpaqueType("vector", DoubleType, DoubleType), + in: NewOpaqueType("vector", DoubleType, DoubleType), out: chkdecls.NewAbstractType("vector", chkdecls.Double, chkdecls.Double), }, { @@ -322,11 +326,11 @@ func TestTypeToExprType(t *testing.T) { out: chkdecls.Int, }, { - in: ListType(TypeParamType("T")), + in: NewListType(NewTypeParamType("T")), out: chkdecls.NewListType(chkdecls.NewTypeParamType("T")), }, { - in: MapType(TypeParamType("K"), TypeParamType("V")), + in: NewMapType(NewTypeParamType("K"), NewTypeParamType("V")), out: chkdecls.NewMapType(chkdecls.NewTypeParamType("K"), chkdecls.NewTypeParamType("V")), }, { @@ -334,7 +338,7 @@ func TestTypeToExprType(t *testing.T) { out: chkdecls.Null, }, { - in: ObjectType("google.type.Expr"), + in: NewObjectType("google.type.Expr"), out: chkdecls.NewObjectType("google.type.Expr"), }, { @@ -354,105 +358,105 @@ func TestTypeToExprType(t *testing.T) { out: chkdecls.Uint, }, { - in: NullableType(BoolType), + in: NewNullableType(BoolType), out: chkdecls.NewWrapperType(chkdecls.Bool), }, { - in: NullableType(BytesType), + in: NewNullableType(BytesType), out: chkdecls.NewWrapperType(chkdecls.Bytes), }, { - in: NullableType(DoubleType), + in: NewNullableType(DoubleType), out: chkdecls.NewWrapperType(chkdecls.Double), }, { - in: NullableType(IntType), + in: NewNullableType(IntType), out: chkdecls.NewWrapperType(chkdecls.Int), }, { - in: NullableType(StringType), + in: NewNullableType(StringType), out: chkdecls.NewWrapperType(chkdecls.String), }, { - in: NullableType(UintType), + in: NewNullableType(UintType), out: chkdecls.NewWrapperType(chkdecls.Uint), }, { - in: TypeTypeWithParam(TypeTypeWithParam(DynType)), + in: NewTypeTypeWithParam(NewTypeTypeWithParam(DynType)), out: chkdecls.NewTypeType(chkdecls.NewTypeType(chkdecls.Dyn)), }, { - in: ObjectType("google.protobuf.Any"), + in: NewObjectType("google.protobuf.Any"), out: chkdecls.Any, unidirectional: true, }, { - in: ObjectType("google.protobuf.Duration"), + in: NewObjectType("google.protobuf.Duration"), out: chkdecls.Duration, unidirectional: true, }, { - in: ObjectType("google.protobuf.Timestamp"), + in: NewObjectType("google.protobuf.Timestamp"), out: chkdecls.Timestamp, unidirectional: true, }, { - in: ObjectType("google.protobuf.Value"), + in: NewObjectType("google.protobuf.Value"), out: chkdecls.Dyn, unidirectional: true, }, { - in: ObjectType("google.protobuf.ListValue"), + in: NewObjectType("google.protobuf.ListValue"), out: chkdecls.NewListType(chkdecls.Dyn), unidirectional: true, }, { - in: ObjectType("google.protobuf.Struct"), + in: NewObjectType("google.protobuf.Struct"), out: chkdecls.NewMapType(chkdecls.String, chkdecls.Dyn), unidirectional: true, }, { - in: ObjectType("google.protobuf.BoolValue"), + in: NewObjectType("google.protobuf.BoolValue"), out: chkdecls.NewWrapperType(chkdecls.Bool), unidirectional: true, }, { - in: ObjectType("google.protobuf.BytesValue"), + in: NewObjectType("google.protobuf.BytesValue"), out: chkdecls.NewWrapperType(chkdecls.Bytes), unidirectional: true, }, { - in: ObjectType("google.protobuf.DoubleValue"), + in: NewObjectType("google.protobuf.DoubleValue"), out: chkdecls.NewWrapperType(chkdecls.Double), unidirectional: true, }, { - in: ObjectType("google.protobuf.FloatValue"), + in: NewObjectType("google.protobuf.FloatValue"), out: chkdecls.NewWrapperType(chkdecls.Double), unidirectional: true, }, { - in: ObjectType("google.protobuf.Int32Value"), + in: NewObjectType("google.protobuf.Int32Value"), out: chkdecls.NewWrapperType(chkdecls.Int), unidirectional: true, }, { - in: ObjectType("google.protobuf.Int64Value"), + in: NewObjectType("google.protobuf.Int64Value"), out: chkdecls.NewWrapperType(chkdecls.Int), unidirectional: true, }, { - in: ObjectType("google.protobuf.StringValue"), + in: NewObjectType("google.protobuf.StringValue"), out: chkdecls.NewWrapperType(chkdecls.String), unidirectional: true, }, { - in: ObjectType("google.protobuf.UInt32Value"), + in: NewObjectType("google.protobuf.UInt32Value"), out: chkdecls.NewWrapperType(chkdecls.Uint), unidirectional: true, }, { - in: ObjectType("google.protobuf.UInt64Value"), + in: NewObjectType("google.protobuf.UInt64Value"), out: chkdecls.NewWrapperType(chkdecls.Uint), unidirectional: true, }, @@ -488,21 +492,21 @@ func TestTypeToExprTypeInvalid(t *testing.T) { out string }{ { - in: &Type{Kind: ListKind, runtimeType: types.ListType}, + in: &Type{Kind: ListKind, runtimeTypeName: "list"}, out: "invalid list", }, { in: &Type{ Kind: ListKind, Parameters: []*Type{ - {Kind: MapKind, runtimeType: types.MapType}, + {Kind: MapKind, runtimeTypeName: "map"}, }, - runtimeType: types.ListType, + runtimeTypeName: "list", }, out: "invalid map", }, { - in: &Type{Kind: MapKind, runtimeType: types.MapType}, + in: &Type{Kind: MapKind, runtimeTypeName: "map"}, out: "invalid map", }, { @@ -510,9 +514,9 @@ func TestTypeToExprTypeInvalid(t *testing.T) { Kind: MapKind, Parameters: []*Type{ StringType, - {Kind: MapKind, runtimeType: types.MapType}, + {Kind: MapKind, runtimeTypeName: "map"}, }, - runtimeType: types.MapType, + runtimeTypeName: "map", }, out: "invalid map", }, @@ -520,25 +524,25 @@ func TestTypeToExprTypeInvalid(t *testing.T) { in: &Type{ Kind: MapKind, Parameters: []*Type{ - {Kind: MapKind, runtimeType: types.MapType}, + {Kind: MapKind, runtimeTypeName: "map"}, StringType, }, - runtimeType: types.MapType, + runtimeTypeName: "map", }, out: "invalid map", }, { in: &Type{ - Kind: TypeKind, - Parameters: []*Type{{Kind: ListKind, runtimeType: types.ListType}}, - runtimeType: types.TypeType, + Kind: TypeKind, + Parameters: []*Type{{Kind: ListKind, runtimeTypeName: "list"}}, + runtimeTypeName: "type", }, out: "invalid list", }, { - in: OpaqueType("bad_list", &Type{ - Kind: ListKind, - runtimeType: types.ListType, + in: NewOpaqueType("bad_list", &Type{ + Kind: ListKind, + runtimeTypeName: "list", }), out: "invalid list", }, @@ -581,47 +585,47 @@ func TestExprTypeToType(t *testing.T) { }, { in: chkdecls.NewObjectType("google.protobuf.ListValue"), - out: ListType(DynType), + out: NewListType(DynType), }, { in: chkdecls.NewObjectType("google.protobuf.Struct"), - out: MapType(StringType, DynType), + out: NewMapType(StringType, DynType), }, { in: chkdecls.NewObjectType("google.protobuf.BoolValue"), - out: NullableType(BoolType), + out: NewNullableType(BoolType), }, { in: chkdecls.NewObjectType("google.protobuf.BytesValue"), - out: NullableType(BytesType), + out: NewNullableType(BytesType), }, { in: chkdecls.NewObjectType("google.protobuf.DoubleValue"), - out: NullableType(DoubleType), + out: NewNullableType(DoubleType), }, { in: chkdecls.NewObjectType("google.protobuf.FloatValue"), - out: NullableType(DoubleType), + out: NewNullableType(DoubleType), }, { in: chkdecls.NewObjectType("google.protobuf.Int32Value"), - out: NullableType(IntType), + out: NewNullableType(IntType), }, { in: chkdecls.NewObjectType("google.protobuf.Int64Value"), - out: NullableType(IntType), + out: NewNullableType(IntType), }, { in: chkdecls.NewObjectType("google.protobuf.StringValue"), - out: NullableType(StringType), + out: NewNullableType(StringType), }, { in: chkdecls.NewObjectType("google.protobuf.UInt32Value"), - out: NullableType(UintType), + out: NewNullableType(UintType), }, { in: chkdecls.NewObjectType("google.protobuf.UInt64Value"), - out: NullableType(UintType), + out: NewNullableType(UintType), }, } @@ -692,3 +696,126 @@ func TestExprTypeToTypeInvalid(t *testing.T) { }) } } + +func TestTypeHasTrait(t *testing.T) { + if !BoolType.HasTrait(traits.ComparerType) { + t.Error("BoolType.HasTrait(ComparerType) returned false") + } +} + +func TestTypeConvertToType(t *testing.T) { + _, err := BoolType.ConvertToNative(reflect.TypeOf(true)) + if err == nil { + t.Error("ConvertToNative() did not error") + } +} + +func TestTypeCommonTypeInterop(t *testing.T) { + tests := []struct { + commonType ref.Type + declType *Type + }{ + { + commonType: types.BoolType, + declType: BoolType, + }, + { + commonType: types.BytesType, + declType: BytesType, + }, + { + commonType: types.DoubleType, + declType: DoubleType, + }, + { + commonType: types.DurationType, + declType: DurationType, + }, + { + commonType: types.ErrType, + declType: ErrorType, + }, + { + commonType: types.IntType, + declType: IntType, + }, + { + commonType: types.ListType, + declType: ListType, + }, + { + commonType: types.MapType, + declType: MapType, + }, + { + commonType: types.NullType, + declType: NullType, + }, + { + commonType: types.StringType, + declType: StringType, + }, + { + commonType: types.TimestampType, + declType: TimestampType, + }, + { + commonType: types.TypeType, + declType: TypeType, + }, + { + commonType: types.UintType, + declType: UintType, + }, + { + commonType: types.UnknownType, + declType: UnknownType, + }, + { + commonType: types.NewObjectTypeValue("dev.cel.Expr"), + declType: NewObjectTypeValue("dev.cel.Expr"), + }, + { + commonType: types.NewTypeValue("vector", traits.AdderType), + declType: NewTypeValue("vector", traits.AdderType), + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.commonType.TypeName(), func(t *testing.T) { + if tc.commonType.TypeName() != tc.declType.TypeName() { + t.Errorf("type names not equal: got %v, wanted %v", tc.declType.TypeName(), tc.commonType.TypeName()) + } + if !tc.commonType.HasTrait(tc.declType.traitMask) { + t.Errorf("trait masks not equal: mask %v", tc.declType.traitMask) + } + ctVal := tc.commonType.(ref.Val) + if ctVal.Equal(tc.declType) != types.True || + tc.declType.Equal(ctVal) != types.True { + t.Error("types not runtime equal") + } + if ctVal.Type() != types.TypeType { + t.Errorf("type not marked as a type: %v", ctVal.Type()) + } + if tc.declType.Type() != TypeType { + t.Errorf("type not marked as a type: %v", tc.declType.Type()) + } + if ctVal.Value() != tc.declType.Value() { + t.Errorf("type values not equal: got %v, wanted %v", tc.declType.Value(), ctVal.Value()) + } + if ctVal.ConvertToType(types.StringType). + Equal(tc.declType.ConvertToType(StringType)) != types.True { + t.Error("type values did not convert to same string values") + } + if ctVal.ConvertToType(types.TypeType). + Equal(tc.declType.ConvertToType(TypeType)) != types.True { + t.Error("type values did not convert to same type values") + } + if !errors.Is( + ctVal.ConvertToType(types.ErrType).(*types.Err), + tc.declType.ConvertToType(ErrorType).(*types.Err)) { + t.Error("type values did not convert to same error values") + } + }) + } +} diff --git a/common/stdlib/standard.go b/common/stdlib/standard.go index 6804c764..5a4dc87f 100644 --- a/common/stdlib/standard.go +++ b/common/stdlib/standard.go @@ -35,24 +35,24 @@ var ( ) func init() { - paramA := decls.TypeParamType("A") - paramB := decls.TypeParamType("B") - listOfA := decls.ListType(paramA) - mapOfAB := decls.MapType(paramA, paramB) + paramA := decls.NewTypeParamType("A") + paramB := decls.NewTypeParamType("B") + listOfA := decls.NewListType(paramA) + mapOfAB := decls.NewMapType(paramA, paramB) stdTypes = []*decls.VariableDecl{ - decls.BoolType.TypeVariable(), - decls.BytesType.TypeVariable(), - decls.DoubleType.TypeVariable(), - decls.DurationType.TypeVariable(), - decls.IntType.TypeVariable(), - listOfA.TypeVariable(), - mapOfAB.TypeVariable(), - decls.NullType.TypeVariable(), - decls.StringType.TypeVariable(), - decls.TimestampType.TypeVariable(), - decls.TypeType.TypeVariable(), - decls.UintType.TypeVariable(), + decls.TypeVariable(decls.BoolType), + decls.TypeVariable(decls.BytesType), + decls.TypeVariable(decls.DoubleType), + decls.TypeVariable(decls.DurationType), + decls.TypeVariable(decls.IntType), + decls.TypeVariable(listOfA), + decls.TypeVariable(mapOfAB), + decls.TypeVariable(decls.NullType), + decls.TypeVariable(decls.StringType), + decls.TypeVariable(decls.TimestampType), + decls.TypeVariable(decls.TypeType), + decls.TypeVariable(decls.UintType), } stdTypeDecls = make([]*exprpb.Decl, 0, len(stdTypes)) @@ -386,7 +386,7 @@ func init() { // Type conversions function(overloads.TypeConvertType, - decls.Overload(overloads.TypeConvertType, argTypes(paramA), decls.TypeTypeWithParam(paramA)), + decls.Overload(overloads.TypeConvertType, argTypes(paramA), decls.NewTypeTypeWithParam(paramA)), decls.SingletonUnaryBinding(convertToType(types.TypeType))), // Bool conversions diff --git a/common/types/object.go b/common/types/object.go index 9955e2dc..884730c7 100644 --- a/common/types/object.go +++ b/common/types/object.go @@ -32,7 +32,7 @@ type protoObj struct { ref.TypeAdapter value proto.Message typeDesc *pb.TypeDescription - typeValue *TypeValue + typeValue ref.Val } // NewObject returns an object based on a proto.Message value which handles @@ -44,7 +44,7 @@ type protoObj struct { // then this will result in an error within the type adapter / provider. func NewObject(adapter ref.TypeAdapter, typeDesc *pb.TypeDescription, - typeValue *TypeValue, + typeValue ref.Val, value proto.Message) ref.Val { return &protoObj{ TypeAdapter: adapter, @@ -157,7 +157,7 @@ func (o *protoObj) Get(index ref.Val) ref.Val { } func (o *protoObj) Type() ref.Type { - return o.typeValue + return o.typeValue.(ref.Type) } func (o *protoObj) Value() any { diff --git a/common/types/object_test.go b/common/types/object_test.go index 85b06a8c..14171a90 100644 --- a/common/types/object_test.go +++ b/common/types/object_test.go @@ -190,7 +190,7 @@ func TestProtoObjectConvertToType(t *testing.T) { } reg := newTestRegistry(t, msg) objVal := reg.NativeToValue(msg) - tv := objVal.Type().(*TypeValue) + tv := objVal.Type().(ref.Val) if objVal.ConvertToType(TypeType).Equal(tv) != True { t.Errorf("got non-type value: %v, wanted objet type", objVal.ConvertToType(TypeType)) } diff --git a/common/types/provider.go b/common/types/provider.go index e66951f5..a6b4e6ff 100644 --- a/common/types/provider.go +++ b/common/types/provider.go @@ -218,7 +218,7 @@ func (p *protoTypeRegistry) NativeToValue(value any) ref.Val { if !found { return NewErr("unknown type: '%s'", typeName) } - return NewObject(p, td, typeVal.(*TypeValue), v) + return NewObject(p, td, typeVal, v) case *pb.Map: return NewProtoMap(p, v) case protoreflect.List: diff --git a/common/types/type_test.go b/common/types/type_test.go index 5c1b21b6..a63d1705 100644 --- a/common/types/type_test.go +++ b/common/types/type_test.go @@ -14,10 +14,14 @@ package types -import "testing" +import ( + "testing" + + "github.com/google/cel-go/common/types/ref" +) func TestType_ConvertToType(t *testing.T) { - stdTypes := []*TypeValue{ + stdTypes := []ref.Val{ BoolType, BytesType, DoubleType, diff --git a/interpreter/formatting.go b/interpreter/formatting.go index 6a98f6fa..e3f75337 100644 --- a/interpreter/formatting.go +++ b/interpreter/formatting.go @@ -25,7 +25,7 @@ import ( "github.com/google/cel-go/common/types/ref" ) -type typeVerifier func(int64, ...*types.TypeValue) (bool, error) +type typeVerifier func(int64, ...ref.Type) (bool, error) // InterpolateFormattedString checks the syntax and cardinality of any string.format calls present in the expression and reports // any errors at compile time.