Skip to content

Commit

Permalink
Make decls/type.go equivalent to existing type singletons in types pa…
Browse files Browse the repository at this point in the history
…ckage (#743)
  • Loading branch information
TristonianJones authored Jun 16, 2023
1 parent d69f12a commit 10a4b58
Show file tree
Hide file tree
Showing 13 changed files with 577 additions and 298 deletions.
17 changes: 9 additions & 8 deletions cel/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"

exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
Expand Down Expand Up @@ -110,21 +111,21 @@ var (
// function references for instantiating new types.

// ListType creates an instances of a list type value with the provided element type.
ListType = decls.ListType
ListType = decls.NewListType
// MapType creates an instance of a map type value with the provided key and value types.
MapType = decls.MapType
MapType = decls.NewMapType
// NullableType creates an instance of a nullable type with the provided wrapped type.
//
// Note: only primitive types are supported as wrapped types.
NullableType = decls.NullableType
NullableType = decls.NewNullableType
// OptionalType creates an abstract parameterized type instance corresponding to CEL's notion of optional.
OptionalType = decls.OptionalType
OptionalType = decls.NewOptionalType
// OpaqueType creates an abstract parameterized type with a given name.
OpaqueType = decls.OpaqueType
OpaqueType = decls.NewOpaqueType
// ObjectType creates a type references to an externally defined type, e.g. a protobuf message type.
ObjectType = decls.ObjectType
ObjectType = decls.NewObjectType
// TypeParamType creates a parameterized type instance.
TypeParamType = decls.TypeParamType
TypeParamType = decls.NewTypeParamType
)

// Type holds a reference to a runtime type with an optional type-checked set of type parameters.
Expand Down Expand Up @@ -338,7 +339,7 @@ func ExprDeclToDeclaration(d *exprpb.Decl) (EnvOption, error) {
}
}

func typeValueToKind(tv *types.TypeValue) (Kind, error) {
func typeValueToKind(tv ref.Type) (Kind, error) {
switch tv {
case types.BoolType:
return BoolKind, nil
Expand Down
2 changes: 1 addition & 1 deletion cel/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ func RefValueToValue(res ref.Val) (*exprpb.Value, error) {
}

var (
typeNameToTypeValue = map[string]*types.TypeValue{
typeNameToTypeValue = map[string]ref.Val{
"bool": types.BoolType,
"bytes": types.BytesType,
"double": types.DoubleType,
Expand Down
6 changes: 3 additions & 3 deletions cel/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,9 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
}
// Enable compile-time checking of syntax/cardinality for string.format calls.
if p.evalOpts&OptCheckStringFormat == OptCheckStringFormat {
var isValidType func(id int64, validTypes ...*types.TypeValue) (bool, error)
var isValidType func(id int64, validTypes ...ref.Type) (bool, error)
if ast.IsChecked() {
isValidType = func(id int64, validTypes ...*types.TypeValue) (bool, error) {
isValidType = func(id int64, validTypes ...ref.Type) (bool, error) {
t, err := ExprTypeToType(ast.typeMap[id])
if err != nil {
return false, err
Expand All @@ -231,7 +231,7 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) {
}
} else {
// if the AST isn't type-checked, short-circuit validation
isValidType = func(id int64, validTypes ...*types.TypeValue) (bool, error) {
isValidType = func(id int64, validTypes ...ref.Type) (bool, error) {
return true, nil
}
}
Expand Down
7 changes: 6 additions & 1 deletion common/decls/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,11 @@ func VariableDeclToExprDecl(v *VariableDecl) (*exprpb.Decl, error) {
return chkdecls.NewVar(v.Name, varType), nil
}

// TypeVariable creates a new type identifier for use within a ref.TypeProvider
func TypeVariable(t *Type) *VariableDecl {
return NewVariable(t.TypeName(), NewTypeTypeWithParam(t))
}

// FunctionDeclToExprDecl converts a go-native function declaration into a protobuf-typed function declaration.
func FunctionDeclToExprDecl(f *FunctionDecl) (*exprpb.Decl, error) {
overloads := make([]*exprpb.Decl_FunctionDecl_Overload, len(f.Overloads))
Expand Down Expand Up @@ -707,7 +712,7 @@ func FunctionDeclToExprDecl(f *FunctionDecl) (*exprpb.Decl, error) {

func collectParamNames(paramNames map[string]struct{}, arg *Type) {
if arg.Kind == TypeParamKind {
paramNames[arg.RuntimeTypeName()] = struct{}{}
paramNames[arg.TypeName()] = struct{}{}
}
for _, param := range arg.Parameters {
collectParamNames(paramNames, param)
Expand Down
76 changes: 38 additions & 38 deletions common/decls/decls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (

func TestFunctionBindings(t *testing.T) {
sizeFunc, err := NewFunction("size",
MemberOverload("list_size", []*Type{ListType(TypeParamType("T"))}, IntType),
MemberOverload("list_size", []*Type{NewListType(NewTypeParamType("T"))}, IntType),
)
if err != nil {
t.Fatalf("NewFunction() failed: %v", err)
Expand All @@ -44,7 +44,7 @@ func TestFunctionBindings(t *testing.T) {
t.Errorf("sizeFunc.Bindings() produced %d bindings, wanted none", len(bindings))
}
sizeFuncDef, err := NewFunction("size",
MemberOverload("list_size", []*Type{ListType(TypeParamType("T"))}, IntType,
MemberOverload("list_size", []*Type{NewListType(NewTypeParamType("T"))}, IntType,
UnaryBinding(func(list ref.Val) ref.Val {
sizer := list.(traits.Sizer)
return sizer.Size()
Expand Down Expand Up @@ -86,18 +86,18 @@ func TestFunctionVariableArgBindings(t *testing.T) {
return types.DefaultTypeAdapter.NativeToValue(strings.SplitN(str, delim, int(count)))
}
splitFunc, err := NewFunction("split",
MemberOverload("string_split", []*Type{StringType}, ListType(StringType),
MemberOverload("string_split", []*Type{StringType}, NewListType(StringType),
UnaryBinding(func(str ref.Val) ref.Val {
s := str.(types.String)
return splitImpl(string(s), "", -1)
})),
MemberOverload("string_split_string", []*Type{StringType, StringType}, ListType(StringType),
MemberOverload("string_split_string", []*Type{StringType, StringType}, NewListType(StringType),
BinaryBinding(func(str, sep ref.Val) ref.Val {
s := str.(types.String)
delim := sep.(types.String)
return splitImpl(string(s), string(delim), -1)
})),
MemberOverload("string_split_string_int", []*Type{StringType, StringType, IntType}, ListType(StringType),
MemberOverload("string_split_string_int", []*Type{StringType, StringType, IntType}, NewListType(StringType),
FunctionBinding(func(args ...ref.Val) ref.Val {
s := args[0].(types.String)
delim := args[1].(types.String)
Expand Down Expand Up @@ -196,11 +196,11 @@ func TestFunctionSingletonBinding(t *testing.T) {
// doesn't actually give much additional benefit. The drawback is that invalid signatures
// at type-check might be valid at runtime.
DisableTypeGuards(true),
Overload("size_map", []*Type{MapType(TypeParamType("K"), TypeParamType("V"))}, IntType),
Overload("size_list", []*Type{ListType(TypeParamType("V"))}, IntType),
Overload("size_map", []*Type{NewMapType(NewTypeParamType("K"), NewTypeParamType("V"))}, IntType),
Overload("size_list", []*Type{NewListType(NewTypeParamType("V"))}, IntType),
Overload("size_string", []*Type{StringType}, IntType),
MemberOverload("map_size", []*Type{MapType(TypeParamType("K"), TypeParamType("V"))}, IntType),
MemberOverload("list_size", []*Type{ListType(TypeParamType("V"))}, IntType),
MemberOverload("map_size", []*Type{NewMapType(NewTypeParamType("K"), NewTypeParamType("V"))}, IntType),
MemberOverload("list_size", []*Type{NewListType(NewTypeParamType("V"))}, IntType),
MemberOverload("string_size", []*Type{StringType}, IntType),
SingletonUnaryBinding(func(arg ref.Val) ref.Val {
return arg.(traits.Sizer).Size()
Expand Down Expand Up @@ -232,8 +232,8 @@ func TestFunctionSingletonBinding(t *testing.T) {

func TestFunctionMerge(t *testing.T) {
sizeFunc, err := NewFunction("size",
MemberOverload("list_size", []*Type{ListType(TypeParamType("T"))}, IntType),
MemberOverload("map_size", []*Type{MapType(TypeParamType("K"), TypeParamType("V"))}, IntType),
MemberOverload("list_size", []*Type{NewListType(NewTypeParamType("T"))}, IntType),
MemberOverload("map_size", []*Type{NewMapType(NewTypeParamType("K"), NewTypeParamType("V"))}, IntType),
)
if err != nil {
t.Fatalf("NewFunction() failed: %v", err)
Expand All @@ -246,7 +246,7 @@ func TestFunctionMerge(t *testing.T) {
t.Errorf("sizeFunc.Merge(sizeFunc) != sizeFunc: %v", out)
}
sizeVecFunc, err := NewFunction("size",
MemberOverload("vector_size", []*Type{OpaqueType("vector", TypeParamType("T"))}, IntType),
MemberOverload("vector_size", []*Type{NewOpaqueType("vector", NewTypeParamType("T"))}, IntType),
SingletonUnaryBinding(func(sizer ref.Val) ref.Val {
return sizer.(traits.Sizer).Size()
}, traits.SizerType),
Expand Down Expand Up @@ -279,13 +279,13 @@ func TestFunctionMerge(t *testing.T) {

func TestFunctionMergeWrongName(t *testing.T) {
sizeFunc, err := NewFunction("size",
MemberOverload("list_size", []*Type{ListType(TypeParamType("T"))}, IntType),
MemberOverload("list_size", []*Type{NewListType(NewTypeParamType("T"))}, IntType),
)
if err != nil {
t.Fatalf("NewFunction() failed: %v", err)
}
sizeVecFunc, err := NewFunction("sizeN",
MemberOverload("vector_size", []*Type{OpaqueType("vector", TypeParamType("T"))}, IntType),
MemberOverload("vector_size", []*Type{NewOpaqueType("vector", NewTypeParamType("T"))}, IntType),
)
if err != nil {
t.Fatalf("NewFunction() failed: %v", err)
Expand All @@ -298,13 +298,13 @@ func TestFunctionMergeWrongName(t *testing.T) {

func TestFunctionMergeOverloadCollision(t *testing.T) {
sizeFunc, err := NewFunction("size",
MemberOverload("list_size", []*Type{ListType(TypeParamType("T"))}, IntType),
MemberOverload("list_size", []*Type{NewListType(NewTypeParamType("T"))}, IntType),
)
if err != nil {
t.Fatalf("NewFunction() failed: %v", err)
}
sizeVecFunc, err := NewFunction("size",
MemberOverload("list_size2", []*Type{ListType(TypeParamType("K"))}, IntType),
MemberOverload("list_size2", []*Type{NewListType(NewTypeParamType("K"))}, IntType),
)
if err != nil {
t.Fatalf("NewFunction() failed: %v", err)
Expand All @@ -317,13 +317,13 @@ func TestFunctionMergeOverloadCollision(t *testing.T) {

func TestFunctionMergeOverloadArgCountRedefinition(t *testing.T) {
sizeFunc, err := NewFunction("size",
MemberOverload("list_size", []*Type{ListType(TypeParamType("T"))}, IntType),
MemberOverload("list_size", []*Type{NewListType(NewTypeParamType("T"))}, IntType),
)
if err != nil {
t.Fatalf("NewFunction() failed: %v", err)
}
sizeVecFunc, err := NewFunction("size",
MemberOverload("list_size", []*Type{ListType(TypeParamType("T")), IntType}, IntType),
MemberOverload("list_size", []*Type{NewListType(NewTypeParamType("T")), IntType}, IntType),
)
if err != nil {
t.Fatalf("NewFunction() failed: %v", err)
Expand All @@ -336,13 +336,13 @@ func TestFunctionMergeOverloadArgCountRedefinition(t *testing.T) {

func TestFunctionMergeOverloadArgTypeRedefinition(t *testing.T) {
sizeFunc, err := NewFunction("size",
MemberOverload("arg_size", []*Type{ListType(TypeParamType("T"))}, IntType),
MemberOverload("arg_size", []*Type{NewListType(NewTypeParamType("T"))}, IntType),
)
if err != nil {
t.Fatalf("NewFunction() failed: %v", err)
}
sizeVecFunc, err := NewFunction("size",
MemberOverload("arg_size", []*Type{MapType(IntType, StringType)}, IntType),
MemberOverload("arg_size", []*Type{NewMapType(IntType, StringType)}, IntType),
)
if err != nil {
t.Fatalf("NewFunction() failed: %v", err)
Expand All @@ -355,7 +355,7 @@ func TestFunctionMergeOverloadArgTypeRedefinition(t *testing.T) {

func TestFunctionMergeSingletonRedefinition(t *testing.T) {
sizeFunc, err := NewFunction("size",
MemberOverload("list_size", []*Type{ListType(TypeParamType("T"))}, IntType),
MemberOverload("list_size", []*Type{NewListType(NewTypeParamType("T"))}, IntType),
SingletonUnaryBinding(func(ref.Val) ref.Val {
return types.IntZero
}),
Expand Down Expand Up @@ -549,12 +549,12 @@ func TestOverloadFunctionBindingRedefinition(t *testing.T) {
func TestOverloadIsNonStrict(t *testing.T) {
fn, err := NewFunction("getOrDefault",
MemberOverload("get",
[]*Type{MapType(
TypeParamType("K"), TypeParamType("V")),
TypeParamType("K"),
TypeParamType("V"),
[]*Type{NewMapType(
NewTypeParamType("K"), NewTypeParamType("V")),
NewTypeParamType("K"),
NewTypeParamType("V"),
},
TypeParamType("V"),
NewTypeParamType("V"),
OverloadOperandTrait(traits.ContainerType|traits.IndexerType),
OverloadIsNonStrict(),
FunctionBinding(func(args ...ref.Val) ref.Val {
Expand Down Expand Up @@ -596,12 +596,12 @@ func TestOverloadIsNonStrict(t *testing.T) {
func TestOverloadOperandTrait(t *testing.T) {
fn, err := NewFunction("getOrDefault",
MemberOverload("get",
[]*Type{MapType(
TypeParamType("K"), TypeParamType("V")),
TypeParamType("K"),
TypeParamType("V"),
[]*Type{NewMapType(
NewTypeParamType("K"), NewTypeParamType("V")),
NewTypeParamType("K"),
NewTypeParamType("V"),
},
TypeParamType("V"),
NewTypeParamType("V"),
OverloadOperandTrait(traits.ContainerType|traits.IndexerType),
FunctionBinding(func(args ...ref.Val) ref.Val {
container := args[0].(traits.Container)
Expand Down Expand Up @@ -641,7 +641,7 @@ func TestFunctionDisableDeclaration(t *testing.T) {
fn, err := NewFunction("in",
DisableDeclaration(true),
Overload("in_list",
[]*Type{ListType(TypeParamType("K")), TypeParamType("K")},
[]*Type{NewListType(NewTypeParamType("K")), NewTypeParamType("K")},
BoolType,
),
)
Expand All @@ -657,7 +657,7 @@ func TestFunctionEnableDeclaration(t *testing.T) {
fn, err := NewFunction("in",
DisableDeclaration(false),
Overload("in_list",
[]*Type{ListType(TypeParamType("K")), TypeParamType("K")},
[]*Type{NewListType(NewTypeParamType("K")), NewTypeParamType("K")},
BoolType),
)
if err != nil {
Expand All @@ -669,7 +669,7 @@ func TestFunctionEnableDeclaration(t *testing.T) {
fn2, err := NewFunction("in",
DisableDeclaration(true),
Overload("in_list",
[]*Type{ListType(TypeParamType("K")), TypeParamType("K")},
[]*Type{NewListType(NewTypeParamType("K")), NewTypeParamType("K")},
BoolType),
)
if err != nil {
Expand Down Expand Up @@ -704,7 +704,7 @@ func TestFunctionDeclToExprDecl(t *testing.T) {
}{
{
fn: testFunction(t, "equals",
Overload("equals_value_value", []*Type{TypeParamType("T"), TypeParamType("T")}, BoolType)),
Overload("equals_value_value", []*Type{NewTypeParamType("T"), NewTypeParamType("T")}, BoolType)),
exDecl: &exprpb.Decl{
Name: "equals",
DeclKind: &exprpb.Decl_Function{
Expand All @@ -726,7 +726,7 @@ func TestFunctionDeclToExprDecl(t *testing.T) {
},
{
fn: testFunction(t, "equals",
MemberOverload("value_equals_value", []*Type{TypeParamType("T"), TypeParamType("T")}, BoolType)),
MemberOverload("value_equals_value", []*Type{NewTypeParamType("T"), NewTypeParamType("T")}, BoolType)),
exDecl: &exprpb.Decl{
Name: "equals",
DeclKind: &exprpb.Decl_Function{
Expand Down Expand Up @@ -793,8 +793,8 @@ func TestFunctionDeclToExprDecl(t *testing.T) {
{
fn: testFunction(t, "equals",
MemberOverload("list_optional_value_equals_list_optional_value", []*Type{
ListType(OptionalType(TypeParamType("T"))),
ListType(OptionalType(TypeParamType("T"))),
NewListType(NewOptionalType(NewTypeParamType("T"))),
NewListType(NewOptionalType(NewTypeParamType("T"))),
}, BoolType)),
exDecl: &exprpb.Decl{
Name: "equals",
Expand Down
Loading

0 comments on commit 10a4b58

Please sign in to comment.