From 241abf68c08b39ade9b0aae59a240817bed5b33d Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Thu, 1 Jun 2023 12:31:05 -0700 Subject: [PATCH 1/3] Introduce an expression ID to the common.Error value and surface the common.Error type as cel.Error --- cel/cel_test.go | 7 ++--- cel/env.go | 7 +++-- cel/library.go | 8 ++--- cel/macro.go | 11 ++++--- checker/checker.go | 17 +++++------ checker/checker_test.go | 4 +-- checker/errors.go | 66 +++++++++++++++++++++++++++-------------- common/error.go | 8 ++++- common/errors.go | 10 +++---- ext/BUILD.bazel | 2 -- ext/bindings.go | 8 ++--- ext/encoders.go | 5 ---- ext/guards.go | 1 + ext/math.go | 33 ++++++--------------- ext/protos.go | 12 +++----- parser/errors.go | 4 +++ parser/helper.go | 15 ++++++++-- parser/macro.go | 17 +++++------ parser/parser.go | 4 +-- parser/parser_test.go | 8 ++--- server/BUILD.bazel | 1 - server/server.go | 5 ++-- 22 files changed, 128 insertions(+), 125 deletions(-) diff --git a/cel/cel_test.go b/cel/cel_test.go index 11ade5a2..3c4afc0d 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -29,7 +29,6 @@ import ( "google.golang.org/protobuf/reflect/protoreflect" "github.com/google/cel-go/checker" - "github.com/google/cel-go/common" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" @@ -773,7 +772,7 @@ func TestMacroSubset(t *testing.T) { func TestCustomMacro(t *testing.T) { joinMacro := NewReceiverMacro("join", 1, - func(meh MacroExprHelper, iterRange *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { + func(meh MacroExprHelper, iterRange *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { delim := args[0] iterIdent := meh.Ident("__iter__") accuIdent := meh.AccuIdent() @@ -820,7 +819,7 @@ func TestCustomExistsMacro(t *testing.T) { Variable("attr", MapType(StringType, BoolType)), Macros( NewGlobalVarArgMacro("kleeneOr", - func(meh MacroExprHelper, unused *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { + func(meh MacroExprHelper, unused *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { inputs := meh.NewList(args...) eqOne, err := ExistsMacroExpander(meh, inputs, []*exprpb.Expr{ meh.Ident("__iter__"), @@ -850,7 +849,7 @@ func TestCustomExistsMacro(t *testing.T) { }, ), NewGlobalMacro("kleeneEq", 2, - func(meh MacroExprHelper, unused *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { + func(meh MacroExprHelper, unused *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { attr := args[0] value := args[1] hasAttr, err := HasMacroExpander(meh, nil, []*exprpb.Expr{meh.Copy(attr)}) diff --git a/cel/env.go b/cel/env.go index d9c2ef63..60bade61 100644 --- a/cel/env.go +++ b/cel/env.go @@ -596,6 +596,9 @@ func (e *Env) maybeApplyFeature(feature int, option EnvOption) (*Env, error) { return e, nil } +// Error type which references an expression id, a location within source, and a message. +type Error = common.Error + // Issues defines methods for inspecting the error details of parse and check calls. // // Note: in the future, non-fatal warnings and notices may be inspectable via the Issues struct. @@ -622,9 +625,9 @@ func (i *Issues) Err() error { } // Errors returns the collection of errors encountered in more granular detail. -func (i *Issues) Errors() []common.Error { +func (i *Issues) Errors() []*Error { if i == nil { - return []common.Error{} + return []*Error{} } return i.errs.GetErrors() } diff --git a/cel/library.go b/cel/library.go index bcfd44f7..87678e29 100644 --- a/cel/library.go +++ b/cel/library.go @@ -20,7 +20,6 @@ import ( "time" "github.com/google/cel-go/checker" - "github.com/google/cel-go/common" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" @@ -206,17 +205,14 @@ func (optionalLibrary) CompileOptions() []EnvOption { } } -func optMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func optMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { varIdent := args[0] varName := "" switch varIdent.GetExprKind().(type) { case *exprpb.Expr_IdentExpr: varName = varIdent.GetIdentExpr().GetName() default: - return nil, &common.Error{ - Message: "optMap() variable name must be a simple identifier", - Location: meh.OffsetLocation(varIdent.GetId()), - } + return nil, meh.NewError(varIdent.GetId(), "optMap() variable name must be a simple identifier") } mapExpr := args[1] return meh.GlobalCall( diff --git a/cel/macro.go b/cel/macro.go index e48c5bf8..1eb414c8 100644 --- a/cel/macro.go +++ b/cel/macro.go @@ -15,7 +15,6 @@ package cel import ( - "github.com/google/cel-go/common" "github.com/google/cel-go/parser" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" @@ -63,21 +62,21 @@ func NewReceiverVarArgMacro(function string, expander MacroExpander) Macro { } // HasMacroExpander expands the input call arguments into a presence test, e.g. has(.field) -func HasMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func HasMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { return parser.MakeHas(meh, target, args) } // ExistsMacroExpander expands the input call arguments into a comprehension that returns true if any of the // elements in the range match the predicate expressions: // .exists(, ) -func ExistsMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func ExistsMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { return parser.MakeExists(meh, target, args) } // ExistsOneMacroExpander expands the input call arguments into a comprehension that returns true if exactly // one of the elements in the range match the predicate expressions: // .exists_one(, ) -func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { return parser.MakeExistsOne(meh, target, args) } @@ -91,14 +90,14 @@ func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*ex // // In the second form only iterVar values which return true when provided to the predicate expression // are transformed. -func MapMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func MapMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { return parser.MakeMap(meh, target, args) } // FilterMacroExpander expands the input call arguments into a comprehension which produces a list which contains // only elements which match the provided predicate expression: // .filter(, ) -func FilterMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func FilterMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { return parser.MakeFilter(meh, target, args) } diff --git a/checker/checker.go b/checker/checker.go index 257cffec..505d9616 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -18,7 +18,6 @@ package checker import ( "fmt" - "reflect" "github.com/google/cel-go/checker/decls" "github.com/google/cel-go/common" @@ -50,9 +49,10 @@ type checker struct { func Check(parsedExpr *exprpb.ParsedExpr, source common.Source, env *Env) (*exprpb.CheckedExpr, *common.Errors) { + errs := common.NewErrors(source) c := checker{ env: env, - errors: &typeErrors{common.NewErrors(source)}, + errors: &typeErrors{errs: errs}, mappings: newMapping(), freeTypeVarCounter: 0, sourceInfo: parsedExpr.GetSourceInfo(), @@ -73,7 +73,7 @@ func Check(parsedExpr *exprpb.ParsedExpr, SourceInfo: parsedExpr.GetSourceInfo(), TypeMap: m, ReferenceMap: c.references, - }, c.errors.Errors + }, errs } func (c *checker) check(e *exprpb.Expr) { @@ -113,8 +113,7 @@ func (c *checker) check(e *exprpb.Expr) { case *exprpb.Expr_ComprehensionExpr: c.checkComprehension(e) default: - c.errors.ReportError( - c.location(e), "Unrecognized ast type: %v", reflect.TypeOf(e)) + c.errors.unexpectedASTType(c.location(e), e) } } @@ -200,7 +199,7 @@ func (c *checker) checkOptSelect(e *exprpb.Expr) { field := call.GetArgs()[1] fieldName, isString := maybeUnwrapString(field) if !isString { - c.errors.ReportError(c.location(field), "unsupported optional field selection: %v", field) + c.errors.notAnOptionalFieldSelection(c.location(field), field) return } @@ -637,8 +636,7 @@ func (c *checker) lookupFieldType(l common.Location, messageType string, fieldNa func (c *checker) setType(e *exprpb.Expr, t *exprpb.Type) { if old, found := c.types[e.GetId()]; found && !proto.Equal(old, t) { - c.errors.ReportError(c.location(e), - "(Incompatible) Type already exists for expression: %v(%d) old:%v, new:%v", e, e.GetId(), old, t) + c.errors.incompatibleType(c.location(e), e, old, t) return } c.types[e.GetId()] = t @@ -650,8 +648,7 @@ func (c *checker) getType(e *exprpb.Expr) *exprpb.Type { func (c *checker) setReference(e *exprpb.Expr, r *exprpb.Reference) { if old, found := c.references[e.GetId()]; found && !proto.Equal(old, r) { - c.errors.ReportError(c.location(e), - "Reference already exists for expression: %v(%d) old:%v, new:%v", e, e.GetId(), old, r) + c.errors.referenceRedefinition(c.location(e), e, old, r) return } c.references[e.GetId()] = r diff --git a/checker/checker_test.go b/checker/checker_test.go index d070766d..af43573c 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -351,7 +351,7 @@ _!=_(_-_(_+_(1~double, _*_(2~double, 3~double)~double^multiply_double) }, }, err: ` - ERROR: :1:2: [internal] unexpected failed resolution of 'google.expr.proto3.test.Proto2Message' + ERROR: :1:2: unexpected failed resolution of 'google.expr.proto3.test.Proto2Message' | x.single_int32 != null | .^ `, @@ -1667,7 +1667,7 @@ _&&_(_==_(list~type(list(dyn))^list, { in: `[].length`, err: ` - ERROR: :1:3: type 'list_type:{elem_type:{type_param:"_var0"}}' does not support field selection + ERROR: :1:3: type 'list(_var0)' does not support field selection | [].length | ..^ `, diff --git a/checker/errors.go b/checker/errors.go index 0014f9ab..94717182 100644 --- a/checker/errors.go +++ b/checker/errors.go @@ -15,6 +15,8 @@ package checker import ( + "reflect" + "github.com/google/cel-go/common" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" @@ -22,53 +24,71 @@ import ( // typeErrors is a specialization of Errors. type typeErrors struct { - *common.Errors + errs *common.Errors } -func (e *typeErrors) undeclaredReference(l common.Location, container string, name string) { - e.ReportError(l, "undeclared reference to '%s' (in container '%s')", name, container) -} - -func (e *typeErrors) typeDoesNotSupportFieldSelection(l common.Location, t *exprpb.Type) { - e.ReportError(l, "type '%s' does not support field selection", t) +func (e *typeErrors) fieldTypeMismatch(l common.Location, name string, field *exprpb.Type, value *exprpb.Type) { + e.errs.ReportError(l, "expected type of field '%s' is '%s' but provided type is '%s'", + name, FormatCheckedType(field), FormatCheckedType(value)) } -func (e *typeErrors) undefinedField(l common.Location, field string) { - e.ReportError(l, "undefined field '%s'", field) +func (e *typeErrors) incompatibleType(l common.Location, ex *exprpb.Expr, prev, next *exprpb.Type) { + e.errs.ReportError(l, + "incompatible type already exists for expression: %v(%d) old:%v, new:%v", ex, ex.GetId(), prev, next) } func (e *typeErrors) noMatchingOverload(l common.Location, name string, args []*exprpb.Type, isInstance bool) { signature := formatFunction(nil, args, isInstance) - e.ReportError(l, "found no matching overload for '%s' applied to '%s'", name, signature) + e.errs.ReportError(l, "found no matching overload for '%s' applied to '%s'", name, signature) } -func (e *typeErrors) notAType(l common.Location, t *exprpb.Type) { - e.ReportError(l, "'%s(%v)' is not a type", FormatCheckedType(t), t) +func (e *typeErrors) notAComprehensionRange(l common.Location, t *exprpb.Type) { + e.errs.ReportError(l, "expression of type '%s' cannot be range of a comprehension (must be list, map, or dynamic)", + FormatCheckedType(t)) } -func (e *typeErrors) notAMessageType(l common.Location, t *exprpb.Type) { - e.ReportError(l, "'%s' is not a message type", FormatCheckedType(t)) +func (e *typeErrors) notAnOptionalFieldSelection(l common.Location, field *exprpb.Expr) { + e.errs.ReportError(l, "unsupported optional field selection: %v", field) } -func (e *typeErrors) fieldTypeMismatch(l common.Location, name string, field *exprpb.Type, value *exprpb.Type) { - e.ReportError(l, "expected type of field '%s' is '%s' but provided type is '%s'", - name, FormatCheckedType(field), FormatCheckedType(value)) +func (e *typeErrors) notAType(l common.Location, t *exprpb.Type) { + e.errs.ReportError(l, "'%s' is not a type", FormatCheckedType(t)) } -func (e *typeErrors) unexpectedFailedResolution(l common.Location, typeName string) { - e.ReportError(l, "[internal] unexpected failed resolution of '%s'", typeName) +func (e *typeErrors) notAMessageType(l common.Location, t *exprpb.Type) { + e.errs.ReportError(l, "'%s' is not a message type", FormatCheckedType(t)) } -func (e *typeErrors) notAComprehensionRange(l common.Location, t *exprpb.Type) { - e.ReportError(l, "expression of type '%s' cannot be range of a comprehension (must be list, map, or dynamic)", - FormatCheckedType(t)) +func (e *typeErrors) referenceRedefinition(l common.Location, ex *exprpb.Expr, prev, next *exprpb.Reference) { + e.errs.ReportError(l, + "reference already exists for expression: %v(%d) old:%v, new:%v", ex, ex.GetId(), prev, next) +} + +func (e *typeErrors) typeDoesNotSupportFieldSelection(l common.Location, t *exprpb.Type) { + e.errs.ReportError(l, "type '%s' does not support field selection", FormatCheckedType(t)) } func (e *typeErrors) typeMismatch(l common.Location, expected *exprpb.Type, actual *exprpb.Type) { - e.ReportError(l, "expected type '%s' but found '%s'", + e.errs.ReportError(l, "expected type '%s' but found '%s'", FormatCheckedType(expected), FormatCheckedType(actual)) } +func (e *typeErrors) undefinedField(l common.Location, field string) { + e.errs.ReportError(l, "undefined field '%s'", field) +} + +func (e *typeErrors) undeclaredReference(l common.Location, container string, name string) { + e.errs.ReportError(l, "undeclared reference to '%s' (in container '%s')", name, container) +} + +func (e *typeErrors) unexpectedFailedResolution(l common.Location, typeName string) { + e.errs.ReportError(l, "unexpected failed resolution of '%s'", typeName) +} + +func (e *typeErrors) unexpectedASTType(l common.Location, ex *exprpb.Expr) { + e.errs.ReportError(l, "unrecognized ast type: %v", reflect.TypeOf(ex)) +} + func formatFunction(resultType *exprpb.Type, argTypes []*exprpb.Type, isInstance bool) string { result := "" if isInstance { diff --git a/common/error.go b/common/error.go index f91f7f8d..774dcb5b 100644 --- a/common/error.go +++ b/common/error.go @@ -22,10 +22,16 @@ import ( "golang.org/x/text/width" ) -// Error type which references a location within source and a message. +// NewError creates an error associated with an expression id with the given message at the given location. +func NewError(id int64, message string, location Location) *Error { + return &Error{Message: message, Location: location, ExprID: id} +} + +// Error type which references an expression id, a location within source, and a message. type Error struct { Location Location Message string + ExprID int64 } const ( diff --git a/common/errors.go b/common/errors.go index 1565085a..d9f1e70b 100644 --- a/common/errors.go +++ b/common/errors.go @@ -22,7 +22,7 @@ import ( // Errors type which contains a list of errors observed during parsing. type Errors struct { - errors []Error + errors []*Error source Source numErrors int maxErrorsToReport int @@ -31,7 +31,7 @@ type Errors struct { // NewErrors creates a new instance of the Errors type. func NewErrors(source Source) *Errors { return &Errors{ - errors: []Error{}, + errors: []*Error{}, source: source, maxErrorsToReport: 100, } @@ -43,7 +43,7 @@ func (e *Errors) ReportError(l Location, format string, args ...any) { if e.numErrors > e.maxErrorsToReport { return } - err := Error{ + err := &Error{ Location: l, Message: fmt.Sprintf(format, args...), } @@ -51,12 +51,12 @@ func (e *Errors) ReportError(l Location, format string, args ...any) { } // GetErrors returns the list of observed errors. -func (e *Errors) GetErrors() []Error { +func (e *Errors) GetErrors() []*Error { return e.errors[:] } // Append creates a new Errors object with the current and input errors. -func (e *Errors) Append(errs []Error) *Errors { +func (e *Errors) Append(errs []*Error) *Errors { return &Errors{ errors: append(e.errors, errs...), source: e.source, diff --git a/ext/BUILD.bazel b/ext/BUILD.bazel index 4bcf8a28..ed3c4545 100644 --- a/ext/BUILD.bazel +++ b/ext/BUILD.bazel @@ -20,7 +20,6 @@ go_library( deps = [ "//cel:go_default_library", "//checker/decls:go_default_library", - "//common:go_default_library", "//common/overloads:go_default_library", "//common/types:go_default_library", "//common/types/pb:go_default_library", @@ -53,7 +52,6 @@ go_test( deps = [ "//cel:go_default_library", "//checker:go_default_library", - "//common:go_default_library", "//common/types:go_default_library", "//common/types/ref:go_default_library", "//common/types/traits:go_default_library", diff --git a/ext/bindings.go b/ext/bindings.go index 9cc3c3ef..e69d69f4 100644 --- a/ext/bindings.go +++ b/ext/bindings.go @@ -16,7 +16,6 @@ package ext import ( "github.com/google/cel-go/cel" - "github.com/google/cel-go/common" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) @@ -71,7 +70,7 @@ func (celBindings) ProgramOptions() []cel.ProgramOption { return []cel.ProgramOption{} } -func celBind(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func celBind(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) { if !macroTargetMatchesNamespace(celNamespace, target) { return nil, nil } @@ -81,10 +80,7 @@ func celBind(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) case *exprpb.Expr_IdentExpr: varName = varIdent.GetIdentExpr().GetName() default: - return nil, &common.Error{ - Message: "cel.bind() variable names must be simple identifers", - Location: meh.OffsetLocation(varIdent.GetId()), - } + return nil, meh.NewError(varIdent.GetId(), "cel.bind() variable names must be simple identifers") } varInit := args[1] resultExpr := args[2] diff --git a/ext/encoders.go b/ext/encoders.go index d9f9cb51..61ac0b77 100644 --- a/ext/encoders.go +++ b/ext/encoders.go @@ -16,7 +16,6 @@ package ext import ( "encoding/base64" - "reflect" "github.com/google/cel-go/cel" "github.com/google/cel-go/common/types" @@ -86,7 +85,3 @@ func base64DecodeString(str string) ([]byte, error) { func base64EncodeBytes(bytes []byte) (string, error) { return base64.StdEncoding.EncodeToString(bytes), nil } - -var ( - bytesListType = reflect.TypeOf([]byte{}) -) diff --git a/ext/guards.go b/ext/guards.go index 4c7786a6..785c8675 100644 --- a/ext/guards.go +++ b/ext/guards.go @@ -17,6 +17,7 @@ package ext import ( "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) diff --git a/ext/math.go b/ext/math.go index 1c8ad585..0b9a3610 100644 --- a/ext/math.go +++ b/ext/math.go @@ -19,10 +19,10 @@ import ( "strings" "github.com/google/cel-go/cel" - "github.com/google/cel-go/common" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) @@ -187,24 +187,18 @@ func (mathLib) ProgramOptions() []cel.ProgramOption { return []cel.ProgramOption{} } -func mathLeast(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func mathLeast(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) { if !macroTargetMatchesNamespace(mathNamespace, target) { return nil, nil } switch len(args) { case 0: - return nil, &common.Error{ - Message: "math.least() requires at least one argument", - Location: meh.OffsetLocation(target.GetId()), - } + return nil, meh.NewError(target.GetId(), "math.least() requires at least one argument") case 1: if isListLiteralWithValidArgs(args[0]) || isValidArgType(args[0]) { return meh.GlobalCall(minFunc, args[0]), nil } - return nil, &common.Error{ - Message: "math.least() invalid single argument value", - Location: meh.OffsetLocation(args[0].GetId()), - } + return nil, meh.NewError(args[0].GetId(), "math.least() invalid single argument value") case 2: err := checkInvalidArgs(meh, "math.least()", args) if err != nil { @@ -220,24 +214,18 @@ func mathLeast(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr } } -func mathGreatest(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func mathGreatest(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) { if !macroTargetMatchesNamespace(mathNamespace, target) { return nil, nil } switch len(args) { case 0: - return nil, &common.Error{ - Message: "math.greatest() requires at least one argument", - Location: meh.OffsetLocation(target.GetId()), - } + return nil, meh.NewError(target.GetId(), "math.greatest() requires at least one argument") case 1: if isListLiteralWithValidArgs(args[0]) || isValidArgType(args[0]) { return meh.GlobalCall(maxFunc, args[0]), nil } - return nil, &common.Error{ - Message: "math.greatest() invalid single argument value", - Location: meh.OffsetLocation(args[0].GetId()), - } + return nil, meh.NewError(args[0].GetId(), "math.greatest() invalid single argument value") case 2: err := checkInvalidArgs(meh, "math.greatest()", args) if err != nil { @@ -323,14 +311,11 @@ func maxList(numList ref.Val) ref.Val { } } -func checkInvalidArgs(meh cel.MacroExprHelper, funcName string, args []*exprpb.Expr) *common.Error { +func checkInvalidArgs(meh cel.MacroExprHelper, funcName string, args []*exprpb.Expr) *cel.Error { for _, arg := range args { err := checkInvalidArgLiteral(funcName, arg) if err != nil { - return &common.Error{ - Message: err.Error(), - Location: meh.OffsetLocation(arg.GetId()), - } + return meh.NewError(arg.GetId(), err.Error()) } } return nil diff --git a/ext/protos.go b/ext/protos.go index b905e710..a7ca27a6 100644 --- a/ext/protos.go +++ b/ext/protos.go @@ -16,7 +16,6 @@ package ext import ( "github.com/google/cel-go/cel" - "github.com/google/cel-go/common" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) @@ -86,7 +85,7 @@ func (protoLib) ProgramOptions() []cel.ProgramOption { } // hasProtoExt generates a test-only select expression for a fully-qualified extension name on a protobuf message. -func hasProtoExt(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func hasProtoExt(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) { if !macroTargetMatchesNamespace(protoNamespace, target) { return nil, nil } @@ -98,7 +97,7 @@ func hasProtoExt(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Ex } // getProtoExt generates a select expression for a fully-qualified extension name on a protobuf message. -func getProtoExt(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func getProtoExt(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) { if !macroTargetMatchesNamespace(protoNamespace, target) { return nil, nil } @@ -109,7 +108,7 @@ func getProtoExt(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Ex return meh.Select(args[0], extFieldName), nil } -func getExtFieldName(meh cel.MacroExprHelper, expr *exprpb.Expr) (string, *common.Error) { +func getExtFieldName(meh cel.MacroExprHelper, expr *exprpb.Expr) (string, *cel.Error) { isValid := false extensionField := "" switch expr.GetExprKind().(type) { @@ -117,10 +116,7 @@ func getExtFieldName(meh cel.MacroExprHelper, expr *exprpb.Expr) (string, *commo extensionField, isValid = validateIdentifier(expr) } if !isValid { - return "", &common.Error{ - Message: "invalid extension field", - Location: meh.OffsetLocation(expr.GetId()), - } + return "", meh.NewError(expr.GetId(), "invalid extension field") } return extensionField, nil } diff --git a/parser/errors.go b/parser/errors.go index ce49bb87..e518ad36 100644 --- a/parser/errors.go +++ b/parser/errors.go @@ -25,6 +25,10 @@ type parseErrors struct { *common.Errors } +func (e *parseErrors) internalError(message string) { + e.ReportError(common.NoLocation, message) +} + func (e *parseErrors) syntaxError(l common.Location, message string) { e.ReportError(l, fmt.Sprintf("Syntax error: %s", message)) } diff --git a/parser/helper.go b/parser/helper.go index 8f8f478e..0e514eb1 100644 --- a/parser/helper.go +++ b/parser/helper.go @@ -558,11 +558,22 @@ func (e *exprHelper) Select(operand *exprpb.Expr, field string) *exprpb.Expr { // OffsetLocation implements the ExprHelper interface method. func (e *exprHelper) OffsetLocation(exprID int64) common.Location { - offset := e.parserHelper.positions[exprID] - location, _ := e.parserHelper.source.OffsetLocation(offset) + offset, found := e.parserHelper.positions[exprID] + if !found { + return common.NoLocation + } + location, found := e.parserHelper.source.OffsetLocation(offset) + if !found { + return common.NoLocation + } return location } +// NewError associates an error message with a given expression id, populating the source offset location of the error if possible. +func (e *exprHelper) NewError(exprID int64, message string) *common.Error { + return common.NewError(exprID, message, e.OffsetLocation(exprID)) +} + var ( // Thread-safe pool of ExprHelper values to minimize alloc overhead of ExprHelper creations. exprHelperPool = &sync.Pool{ diff --git a/parser/macro.go b/parser/macro.go index 80e5c66c..6066e8ef 100644 --- a/parser/macro.go +++ b/parser/macro.go @@ -232,6 +232,9 @@ type ExprHelper interface { // OffsetLocation returns the Location of the expression identifier. OffsetLocation(exprID int64) common.Location + + // NewError associates an error message with a given expression id. + NewError(exprID int64, message string) *common.Error } var ( @@ -324,7 +327,7 @@ func MakeExistsOne(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*ex func MakeMap(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { v, found := extractIdent(args[0]) if !found { - return nil, &common.Error{Message: "argument is not an identifier"} + return nil, eh.NewError(args[0].GetId(), "argument is not an identifier") } var fn *exprpb.Expr @@ -355,7 +358,7 @@ func MakeMap(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.E func MakeFilter(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { v, found := extractIdent(args[0]) if !found { - return nil, &common.Error{Message: "argument is not an identifier"} + return nil, eh.NewError(args[0].GetId(), "argument is not an identifier") } filter := args[1] @@ -372,17 +375,13 @@ func MakeHas(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.E if s, ok := args[0].ExprKind.(*exprpb.Expr_SelectExpr); ok { return eh.PresenceTest(s.SelectExpr.GetOperand(), s.SelectExpr.GetField()), nil } - return nil, &common.Error{Message: "invalid argument to has() macro"} + return nil, eh.NewError(args[0].GetId(), "invalid argument to has() macro") } func makeQuantifier(kind quantifierKind, eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { v, found := extractIdent(args[0]) if !found { - location := eh.OffsetLocation(args[0].GetId()) - return nil, &common.Error{ - Message: "argument must be a simple name", - Location: location, - } + return nil, eh.NewError(args[0].GetId(), "argument must be a simple name") } var init *exprpb.Expr @@ -411,7 +410,7 @@ func makeQuantifier(kind quantifierKind, eh ExprHelper, target *exprpb.Expr, arg eh.GlobalCall(operators.Add, eh.AccuIdent(), oneExpr), eh.AccuIdent()) result = eh.GlobalCall(operators.Equals, eh.AccuIdent(), oneExpr) default: - return nil, &common.Error{Message: fmt.Sprintf("unrecognized quantifier '%v'", kind)} + return nil, eh.NewError(args[0].GetId(), fmt.Sprintf("unrecognized quantifier '%v'", kind)) } return eh.Fold(v, target, AccumulatorName, init, condition, step, result), nil } diff --git a/parser/parser.go b/parser/parser.go index e6f70f90..204e67c9 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -357,9 +357,9 @@ func (p *parser) parse(expr runes.Buffer, desc string) *exprpb.Expr { if val := recover(); val != nil { switch err := val.(type) { case *lookaheadLimitError: - p.errors.ReportError(common.NoLocation, err.Error()) + p.errors.internalError(err.Error()) case *recursionError: - p.errors.ReportError(common.NoLocation, err.Error()) + p.errors.internalError(err.Error()) case *tooManyErrors: // do nothing case *recoveryLimitError: diff --git a/parser/parser_test.go b/parser/parser_test.go index 04472c28..a0ce497c 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -1107,12 +1107,12 @@ var testCases = []testInfo{ }, { I: "[1, 2, 3].map(var, var * var)", - E: `ERROR: :1:14: argument is not an identifier - | [1, 2, 3].map(var, var * var) - | .............^ - ERROR: :1:15: reserved identifier: var + E: `ERROR: :1:15: reserved identifier: var | [1, 2, 3].map(var, var * var) | ..............^ + ERROR: :1:15: argument is not an identifier + | [1, 2, 3].map(var, var * var) + | ..............^ ERROR: :1:20: reserved identifier: var | [1, 2, 3].map(var, var * var) | ...................^ diff --git a/server/BUILD.bazel b/server/BUILD.bazel index c5e80977..a0ff2e4a 100644 --- a/server/BUILD.bazel +++ b/server/BUILD.bazel @@ -13,7 +13,6 @@ go_library( importpath = "github.com/google/cel-go/server", deps = [ "//cel:go_default_library", - "//common:go_default_library", "//common/types:go_default_library", "//common/types/ref:go_default_library", "@com_google_cel_spec//proto/test/v1/proto2:test_all_types_go_proto", diff --git a/server/server.go b/server/server.go index ad752d5b..e3f45a57 100644 --- a/server/server.go +++ b/server/server.go @@ -20,7 +20,6 @@ import ( "fmt" "github.com/google/cel-go/cel" - "github.com/google/cel-go/common" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" @@ -131,7 +130,7 @@ func (s *ConformanceServer) Eval(ctx context.Context, in *confpb.EvalRequest) (* // appendErrors converts the errors from errs to Status messages // and appends them to the list of issues. -func appendErrors(errs []common.Error, issues *[]*statuspb.Status) { +func appendErrors(errs []*cel.Error, issues *[]*statuspb.Status) { for _, e := range errs { status := ErrToStatus(e, confpb.IssueDetails_ERROR) *issues = append(*issues, status) @@ -139,7 +138,7 @@ func appendErrors(errs []common.Error, issues *[]*statuspb.Status) { } // ErrToStatus converts an Error to a Status message with the given severity. -func ErrToStatus(e common.Error, severity confpb.IssueDetails_Severity) *statuspb.Status { +func ErrToStatus(e *cel.Error, severity confpb.IssueDetails_Severity) *statuspb.Status { detail := &confpb.IssueDetails{ Severity: severity, Position: &confpb.SourcePosition{ From cbbf787ae8b55e37a64fdd6fe2466dd6e9c06ddc Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Thu, 1 Jun 2023 18:00:32 -0700 Subject: [PATCH 2/3] Ensure expression ids are included on parser / checker errors where possible --- checker/checker.go | 72 +++++++++++++++++++++------------------------- checker/errors.go | 56 ++++++++++++++++++------------------ checker/types.go | 14 +++------ common/errors.go | 6 ++++ parser/errors.go | 15 ++++++++-- parser/helper.go | 8 +++--- parser/parser.go | 16 +++++------ 7 files changed, 95 insertions(+), 92 deletions(-) diff --git a/checker/checker.go b/checker/checker.go index 505d9616..7d174738 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -113,7 +113,7 @@ func (c *checker) check(e *exprpb.Expr) { case *exprpb.Expr_ComprehensionExpr: c.checkComprehension(e) default: - c.errors.unexpectedASTType(c.location(e), e) + c.errors.unexpectedASTType(e.GetId(), c.location(e), e) } } @@ -157,8 +157,7 @@ func (c *checker) checkIdent(e *exprpb.Expr) { } c.setType(e, decls.Error) - c.errors.undeclaredReference( - c.location(e), c.env.container.Name(), identExpr.GetName()) + c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), identExpr.GetName()) } func (c *checker) checkSelect(e *exprpb.Expr) { @@ -199,7 +198,7 @@ func (c *checker) checkOptSelect(e *exprpb.Expr) { field := call.GetArgs()[1] fieldName, isString := maybeUnwrapString(field) if !isString { - c.errors.notAnOptionalFieldSelection(c.location(field), field) + c.errors.notAnOptionalFieldSelection(field.GetId(), c.location(field), field) return } @@ -227,7 +226,7 @@ func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, option // Objects yield their field type declaration as the selection result type, but only if // the field is defined. messageType := targetType - if fieldType, found := c.lookupFieldType(c.location(e), messageType.GetMessageType(), field); found { + if fieldType, found := c.lookupFieldType(e.GetId(), messageType.GetMessageType(), field); found { resultType = fieldType.Type } case kindTypeParam: @@ -241,7 +240,7 @@ func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, option // Dynamic / error values are treated as DYN type. Errors are handled this way as well // in order to allow forward progress on the check. if !isDynOrError(targetType) { - c.errors.typeDoesNotSupportFieldSelection(c.location(e), targetType) + c.errors.typeDoesNotSupportFieldSelection(e.GetId(), c.location(e), targetType) } resultType = decls.Dyn } @@ -276,15 +275,14 @@ func (c *checker) checkCall(e *exprpb.Expr) { // Check for the existence of the function. fn := c.env.LookupFunction(fnName) if fn == nil { - c.errors.undeclaredReference( - c.location(e), c.env.container.Name(), fnName) + c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), fnName) c.setType(e, decls.Error) return } // Overwrite the function name with its fully qualified resolved name. call.Function = fn.GetName() // Check to see whether the overload resolves. - c.resolveOverloadOrError(c.location(e), e, fn, nil, args) + c.resolveOverloadOrError(e, fn, nil, args) return } @@ -303,7 +301,7 @@ func (c *checker) checkCall(e *exprpb.Expr) { // Overwrite with fully-qualified resolved function name sans receiver target. call.Target = nil call.Function = fn.GetName() - c.resolveOverloadOrError(c.location(e), e, fn, nil, args) + c.resolveOverloadOrError(e, fn, nil, args) return } } @@ -313,19 +311,18 @@ func (c *checker) checkCall(e *exprpb.Expr) { fn := c.env.LookupFunction(fnName) // Function found, attempt overload resolution. if fn != nil { - c.resolveOverloadOrError(c.location(e), e, fn, target, args) + c.resolveOverloadOrError(e, fn, target, args) return } // Function name not declared, record error. - c.errors.undeclaredReference(c.location(e), c.env.container.Name(), fnName) + c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), fnName) } func (c *checker) resolveOverloadOrError( - loc common.Location, e *exprpb.Expr, fn *exprpb.Decl, target *exprpb.Expr, args []*exprpb.Expr) { // Attempt to resolve the overload. - resolution := c.resolveOverload(loc, fn, target, args) + resolution := c.resolveOverload(e, fn, target, args) // No such overload, error noted in the resolveOverload call, type recorded here. if resolution == nil { c.setType(e, decls.Error) @@ -337,7 +334,7 @@ func (c *checker) resolveOverloadOrError( } func (c *checker) resolveOverload( - loc common.Location, + call *exprpb.Expr, fn *exprpb.Decl, target *exprpb.Expr, args []*exprpb.Expr) *overloadResolution { var argTypes []*exprpb.Type @@ -395,8 +392,7 @@ func (c *checker) resolveOverload( for i, arg := range argTypes { argTypes[i] = substitute(c.mappings, arg, true) } - c.errors.noMatchingOverload(loc, fn.GetName(), argTypes, target != nil) - resultType = decls.Error + c.errors.noMatchingOverload(call.GetId(), c.location(call), fn.GetName(), argTypes, target != nil) return nil } @@ -418,10 +414,10 @@ func (c *checker) checkCreateList(e *exprpb.Expr) { var isOptional bool elemType, isOptional = maybeUnwrapOptional(elemType) if !isOptional && !isDyn(elemType) { - c.errors.typeMismatch(c.location(e), decls.NewOptionalType(elemType), elemType) + c.errors.typeMismatch(e.GetId(), c.location(e), decls.NewOptionalType(elemType), elemType) } } - elemsType = c.joinTypes(c.location(e), elemsType, elemType) + elemsType = c.joinTypes(e, elemsType, elemType) } if elemsType == nil { // If the list is empty, assign free type var to elem type. @@ -446,7 +442,7 @@ func (c *checker) checkCreateMap(e *exprpb.Expr) { for _, ent := range mapVal.GetEntries() { key := ent.GetMapKey() c.check(key) - mapKeyType = c.joinTypes(c.location(key), mapKeyType, c.getType(key)) + mapKeyType = c.joinTypes(key, mapKeyType, c.getType(key)) val := ent.GetValue() c.check(val) @@ -455,10 +451,10 @@ func (c *checker) checkCreateMap(e *exprpb.Expr) { var isOptional bool valType, isOptional = maybeUnwrapOptional(valType) if !isOptional && !isDyn(valType) { - c.errors.typeMismatch(c.location(val), decls.NewOptionalType(valType), valType) + c.errors.typeMismatch(val.GetId(), c.location(val), decls.NewOptionalType(valType), valType) } } - mapValueType = c.joinTypes(c.location(val), mapValueType, valType) + mapValueType = c.joinTypes(val, mapValueType, valType) } if mapKeyType == nil { // If the map is empty, assign free type variables to typeKey and value type. @@ -475,7 +471,7 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) { decl := c.env.LookupIdent(msgVal.GetMessageName()) if decl == nil { c.errors.undeclaredReference( - c.location(e), c.env.container.Name(), msgVal.GetMessageName()) + e.GetId(), c.location(e), c.env.container.Name(), msgVal.GetMessageName()) return } // Ensure the type name is fully qualified in the AST. @@ -485,11 +481,11 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) { identKind := kindOf(ident.GetType()) if identKind != kindError { if identKind != kindType { - c.errors.notAType(c.location(e), ident.GetType()) + c.errors.notAType(e.GetId(), c.location(e), ident.GetType()) } else { messageType = ident.GetType().GetType() if kindOf(messageType) != kindObject { - c.errors.notAMessageType(c.location(e), messageType) + c.errors.notAMessageType(e.GetId(), c.location(e), messageType) messageType = decls.Error } } @@ -507,7 +503,7 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) { c.check(value) fieldType := decls.Error - ft, found := c.lookupFieldType(c.locationByID(ent.GetId()), messageType.GetMessageType(), field) + ft, found := c.lookupFieldType(ent.GetId(), messageType.GetMessageType(), field) if found { fieldType = ft.Type } @@ -517,11 +513,11 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) { var isOptional bool valType, isOptional = maybeUnwrapOptional(valType) if !isOptional && !isDyn(valType) { - c.errors.typeMismatch(c.location(value), decls.NewOptionalType(valType), valType) + c.errors.typeMismatch(value.GetId(), c.location(value), decls.NewOptionalType(valType), valType) } } if !c.isAssignable(fieldType, valType) { - c.errors.fieldTypeMismatch(c.locationByID(ent.Id), field, fieldType, valType) + c.errors.fieldTypeMismatch(ent.GetId(), c.locationByID(ent.GetId()), field, fieldType, valType) } } } @@ -548,7 +544,7 @@ func (c *checker) checkComprehension(e *exprpb.Expr) { // Set the range iteration variable to type DYN as well. varType = decls.Dyn default: - c.errors.notAComprehensionRange(c.location(comp.GetIterRange()), rangeType) + c.errors.notAComprehensionRange(comp.GetIterRange().GetId(), c.location(comp.GetIterRange()), rangeType) varType = decls.Error } @@ -573,9 +569,7 @@ func (c *checker) checkComprehension(e *exprpb.Expr) { } // Checks compatibility of joined types, and returns the most general common type. -func (c *checker) joinTypes(loc common.Location, - previous *exprpb.Type, - current *exprpb.Type) *exprpb.Type { +func (c *checker) joinTypes(e *exprpb.Expr, previous, current *exprpb.Type) *exprpb.Type { if previous == nil { return current } @@ -585,7 +579,7 @@ func (c *checker) joinTypes(loc common.Location, if c.dynAggregateLiteralElementTypesEnabled() { return decls.Dyn } - c.errors.typeMismatch(loc, previous, current) + c.errors.typeMismatch(e.GetId(), c.location(e), previous, current) return decls.Error } @@ -619,10 +613,10 @@ func (c *checker) isAssignableList(l1 []*exprpb.Type, l2 []*exprpb.Type) bool { return false } -func (c *checker) lookupFieldType(l common.Location, messageType string, fieldName string) (*ref.FieldType, bool) { +func (c *checker) lookupFieldType(exprID int64, messageType, fieldName string) (*ref.FieldType, bool) { if _, found := c.env.provider.FindType(messageType); !found { // This should not happen, anyway, report an error. - c.errors.unexpectedFailedResolution(l, messageType) + c.errors.unexpectedFailedResolution(exprID, c.locationByID(exprID), messageType) return nil, false } @@ -630,13 +624,13 @@ func (c *checker) lookupFieldType(l common.Location, messageType string, fieldNa return ft, found } - c.errors.undefinedField(l, fieldName) + c.errors.undefinedField(exprID, c.locationByID(exprID), fieldName) return nil, false } func (c *checker) setType(e *exprpb.Expr, t *exprpb.Type) { if old, found := c.types[e.GetId()]; found && !proto.Equal(old, t) { - c.errors.incompatibleType(c.location(e), e, old, t) + c.errors.incompatibleType(e.GetId(), c.location(e), e, old, t) return } c.types[e.GetId()] = t @@ -648,7 +642,7 @@ func (c *checker) getType(e *exprpb.Expr) *exprpb.Type { func (c *checker) setReference(e *exprpb.Expr, r *exprpb.Reference) { if old, found := c.references[e.GetId()]; found && !proto.Equal(old, r) { - c.errors.referenceRedefinition(c.location(e), e, old, r) + c.errors.referenceRedefinition(e.GetId(), c.location(e), e, old, r) return } c.references[e.GetId()] = r @@ -656,7 +650,7 @@ func (c *checker) setReference(e *exprpb.Expr, r *exprpb.Reference) { func (c *checker) assertType(e *exprpb.Expr, t *exprpb.Type) { if !c.isAssignable(t, c.getType(e)) { - c.errors.typeMismatch(c.location(e), t, c.getType(e)) + c.errors.typeMismatch(e.GetId(), c.location(e), t, c.getType(e)) } } diff --git a/checker/errors.go b/checker/errors.go index 94717182..7f21644a 100644 --- a/checker/errors.go +++ b/checker/errors.go @@ -27,66 +27,66 @@ type typeErrors struct { errs *common.Errors } -func (e *typeErrors) fieldTypeMismatch(l common.Location, name string, field *exprpb.Type, value *exprpb.Type) { - e.errs.ReportError(l, "expected type of field '%s' is '%s' but provided type is '%s'", +func (e *typeErrors) fieldTypeMismatch(id int64, l common.Location, name string, field *exprpb.Type, value *exprpb.Type) { + e.errs.ReportErrorAtID(id, l, "expected type of field '%s' is '%s' but provided type is '%s'", name, FormatCheckedType(field), FormatCheckedType(value)) } -func (e *typeErrors) incompatibleType(l common.Location, ex *exprpb.Expr, prev, next *exprpb.Type) { - e.errs.ReportError(l, +func (e *typeErrors) incompatibleType(id int64, l common.Location, ex *exprpb.Expr, prev, next *exprpb.Type) { + e.errs.ReportErrorAtID(id, l, "incompatible type already exists for expression: %v(%d) old:%v, new:%v", ex, ex.GetId(), prev, next) } -func (e *typeErrors) noMatchingOverload(l common.Location, name string, args []*exprpb.Type, isInstance bool) { +func (e *typeErrors) noMatchingOverload(id int64, l common.Location, name string, args []*exprpb.Type, isInstance bool) { signature := formatFunction(nil, args, isInstance) - e.errs.ReportError(l, "found no matching overload for '%s' applied to '%s'", name, signature) + e.errs.ReportErrorAtID(id, l, "found no matching overload for '%s' applied to '%s'", name, signature) } -func (e *typeErrors) notAComprehensionRange(l common.Location, t *exprpb.Type) { - e.errs.ReportError(l, "expression of type '%s' cannot be range of a comprehension (must be list, map, or dynamic)", +func (e *typeErrors) notAComprehensionRange(id int64, l common.Location, t *exprpb.Type) { + e.errs.ReportErrorAtID(id, l, "expression of type '%s' cannot be range of a comprehension (must be list, map, or dynamic)", FormatCheckedType(t)) } -func (e *typeErrors) notAnOptionalFieldSelection(l common.Location, field *exprpb.Expr) { - e.errs.ReportError(l, "unsupported optional field selection: %v", field) +func (e *typeErrors) notAnOptionalFieldSelection(id int64, l common.Location, field *exprpb.Expr) { + e.errs.ReportErrorAtID(id, l, "unsupported optional field selection: %v", field) } -func (e *typeErrors) notAType(l common.Location, t *exprpb.Type) { - e.errs.ReportError(l, "'%s' is not a type", FormatCheckedType(t)) +func (e *typeErrors) notAType(id int64, l common.Location, t *exprpb.Type) { + e.errs.ReportErrorAtID(id, l, "'%s' is not a type", FormatCheckedType(t)) } -func (e *typeErrors) notAMessageType(l common.Location, t *exprpb.Type) { - e.errs.ReportError(l, "'%s' is not a message type", FormatCheckedType(t)) +func (e *typeErrors) notAMessageType(id int64, l common.Location, t *exprpb.Type) { + e.errs.ReportErrorAtID(id, l, "'%s' is not a message type", FormatCheckedType(t)) } -func (e *typeErrors) referenceRedefinition(l common.Location, ex *exprpb.Expr, prev, next *exprpb.Reference) { - e.errs.ReportError(l, +func (e *typeErrors) referenceRedefinition(id int64, l common.Location, ex *exprpb.Expr, prev, next *exprpb.Reference) { + e.errs.ReportErrorAtID(id, l, "reference already exists for expression: %v(%d) old:%v, new:%v", ex, ex.GetId(), prev, next) } -func (e *typeErrors) typeDoesNotSupportFieldSelection(l common.Location, t *exprpb.Type) { - e.errs.ReportError(l, "type '%s' does not support field selection", FormatCheckedType(t)) +func (e *typeErrors) typeDoesNotSupportFieldSelection(id int64, l common.Location, t *exprpb.Type) { + e.errs.ReportErrorAtID(id, l, "type '%s' does not support field selection", FormatCheckedType(t)) } -func (e *typeErrors) typeMismatch(l common.Location, expected *exprpb.Type, actual *exprpb.Type) { - e.errs.ReportError(l, "expected type '%s' but found '%s'", +func (e *typeErrors) typeMismatch(id int64, l common.Location, expected *exprpb.Type, actual *exprpb.Type) { + e.errs.ReportErrorAtID(id, l, "expected type '%s' but found '%s'", FormatCheckedType(expected), FormatCheckedType(actual)) } -func (e *typeErrors) undefinedField(l common.Location, field string) { - e.errs.ReportError(l, "undefined field '%s'", field) +func (e *typeErrors) undefinedField(id int64, l common.Location, field string) { + e.errs.ReportErrorAtID(id, l, "undefined field '%s'", field) } -func (e *typeErrors) undeclaredReference(l common.Location, container string, name string) { - e.errs.ReportError(l, "undeclared reference to '%s' (in container '%s')", name, container) +func (e *typeErrors) undeclaredReference(id int64, l common.Location, container string, name string) { + e.errs.ReportErrorAtID(id, l, "undeclared reference to '%s' (in container '%s')", name, container) } -func (e *typeErrors) unexpectedFailedResolution(l common.Location, typeName string) { - e.errs.ReportError(l, "unexpected failed resolution of '%s'", typeName) +func (e *typeErrors) unexpectedFailedResolution(id int64, l common.Location, typeName string) { + e.errs.ReportErrorAtID(id, l, "unexpected failed resolution of '%s'", typeName) } -func (e *typeErrors) unexpectedASTType(l common.Location, ex *exprpb.Expr) { - e.errs.ReportError(l, "unrecognized ast type: %v", reflect.TypeOf(ex)) +func (e *typeErrors) unexpectedASTType(id int64, l common.Location, ex *exprpb.Expr) { + e.errs.ReportErrorAtID(id, l, "unrecognized ast type: %v", reflect.TypeOf(ex)) } func formatFunction(resultType *exprpb.Type, argTypes []*exprpb.Type, isInstance bool) string { diff --git a/checker/types.go b/checker/types.go index 28d21c9d..b7ee7569 100644 --- a/checker/types.go +++ b/checker/types.go @@ -317,10 +317,7 @@ func internalIsAssignableAbstractType(m *mapping, a1 *exprpb.Type_AbstractType, func internalIsAssignableFunction(m *mapping, f1 *exprpb.Type_FunctionType, f2 *exprpb.Type_FunctionType) bool { f1ArgTypes := flattenFunctionTypes(f1) f2ArgTypes := flattenFunctionTypes(f2) - if internalIsAssignableList(m, f1ArgTypes, f2ArgTypes) { - return true - } - return false + return internalIsAssignableList(m, f1ArgTypes, f2ArgTypes) } // internalIsAssignableList returns true if the element types at each index in the list are @@ -340,12 +337,9 @@ func internalIsAssignableList(m *mapping, l1 []*exprpb.Type, l2 []*exprpb.Type) // internalIsAssignableMap returns true if map m1 may be assigned to map m2. func internalIsAssignableMap(m *mapping, m1 *exprpb.Type_MapType, m2 *exprpb.Type_MapType) bool { - if internalIsAssignableList(m, + return internalIsAssignableList(m, []*exprpb.Type{m1.GetKeyType(), m1.GetValueType()}, - []*exprpb.Type{m2.GetKeyType(), m2.GetValueType()}) { - return true - } - return false + []*exprpb.Type{m2.GetKeyType(), m2.GetValueType()}) } // internalIsAssignableNull returns true if the type is nullable. @@ -520,7 +514,7 @@ func flattenFunctionTypes(f *exprpb.Type_FunctionType) []*exprpb.Type { if len(argTypes) == 0 { return []*exprpb.Type{f.GetResultType()} } - flattend := make([]*exprpb.Type, len(argTypes)+1, len(argTypes)+1) + flattend := make([]*exprpb.Type, len(argTypes)+1) for i, at := range argTypes { flattend[i] = at } diff --git a/common/errors.go b/common/errors.go index d9f1e70b..63919714 100644 --- a/common/errors.go +++ b/common/errors.go @@ -39,11 +39,17 @@ func NewErrors(source Source) *Errors { // ReportError records an error at a source location. func (e *Errors) ReportError(l Location, format string, args ...any) { + e.ReportErrorAtID(0, l, format, args...) +} + +// ReportErrorAtID records an error at a source location and expression id. +func (e *Errors) ReportErrorAtID(id int64, l Location, format string, args ...any) { e.numErrors++ if e.numErrors > e.maxErrorsToReport { return } err := &Error{ + ExprID: id, Location: l, Message: fmt.Sprintf(format, args...), } diff --git a/parser/errors.go b/parser/errors.go index e518ad36..93ae7a3a 100644 --- a/parser/errors.go +++ b/parser/errors.go @@ -22,13 +22,22 @@ import ( // parseErrors is a specialization of Errors. type parseErrors struct { - *common.Errors + errs *common.Errors +} + +// errorCount indicates the number of errors reported. +func (e *parseErrors) errorCount() int { + return len(e.errs.GetErrors()) } func (e *parseErrors) internalError(message string) { - e.ReportError(common.NoLocation, message) + e.errs.ReportErrorAtID(0, common.NoLocation, message) } func (e *parseErrors) syntaxError(l common.Location, message string) { - e.ReportError(l, fmt.Sprintf("Syntax error: %s", message)) + e.errs.ReportErrorAtID(0, l, fmt.Sprintf("Syntax error: %s", message)) +} + +func (e *parseErrors) reportErrorAtID(id int64, l common.Location, message string, args ...any) { + e.errs.ReportErrorAtID(id, l, message, args...) } diff --git a/parser/helper.go b/parser/helper.go index 0e514eb1..4040db0a 100644 --- a/parser/helper.go +++ b/parser/helper.go @@ -193,15 +193,15 @@ func (p *parserHelper) newExpr(ctx any) *exprpb.Expr { func (p *parserHelper) id(ctx any) int64 { var location common.Location - switch ctx.(type) { + switch c := ctx.(type) { case antlr.ParserRuleContext: - token := (ctx.(antlr.ParserRuleContext)).GetStart() + token := c.GetStart() location = p.source.NewLocation(token.GetLine(), token.GetColumn()) case antlr.Token: - token := ctx.(antlr.Token) + token := c location = p.source.NewLocation(token.GetLine(), token.GetColumn()) case common.Location: - location = ctx.(common.Location) + location = c default: // This should only happen if the ctx is nil return -1 diff --git a/parser/parser.go b/parser/parser.go index 204e67c9..33652697 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -89,8 +89,9 @@ func mustNewParser(opts ...Option) *Parser { // Parse parses the expression represented by source and returns the result. func (p *Parser) Parse(source common.Source) (*exprpb.ParsedExpr, *common.Errors) { + errs := common.NewErrors(source) impl := parser{ - errors: &parseErrors{common.NewErrors(source)}, + errors: &parseErrors{errs}, helper: newParserHelper(source), macros: p.macros, maxRecursionDepth: p.maxRecursionDepth, @@ -115,7 +116,7 @@ func (p *Parser) Parse(source common.Source) (*exprpb.ParsedExpr, *common.Errors return &exprpb.ParsedExpr{ Expr: e, SourceInfo: impl.helper.getSourceInfo(), - }, impl.errors.Errors + }, errs } // reservedIds are not legal to use as variables. We exclude them post-parse, as they *are* valid @@ -449,7 +450,7 @@ func (p *parser) Visit(tree antlr.ParseTree) any { // Report at least one error if the parser reaches an unknown parse element. // Typically, this happens if the parser has already encountered a syntax error elsewhere. - if len(p.errors.GetErrors()) == 0 { + if p.errors.errorCount() == 0 { txt := "<>" if t != nil { txt = fmt.Sprintf("<<%T>>", t) @@ -869,16 +870,15 @@ func (p *parser) unquote(ctx any, value string, isBytes bool) string { func (p *parser) reportError(ctx any, format string, args ...any) *exprpb.Expr { var location common.Location - switch ctx.(type) { + err := p.helper.newExpr(ctx) + switch c := ctx.(type) { case common.Location: - location = ctx.(common.Location) + location = c case antlr.Token, antlr.ParserRuleContext: - err := p.helper.newExpr(ctx) location = p.helper.getLocation(err.GetId()) } - err := p.helper.newExpr(ctx) // Provide arguments to the report error. - p.errors.ReportError(location, format, args...) + p.errors.reportErrorAtID(err.GetId(), location, format, args...) return err } From f9181e597139aee6f215c2c9962d1cec6616418b Mon Sep 17 00:00:00 2001 From: TristonianJones Date: Fri, 2 Jun 2023 16:20:24 -0700 Subject: [PATCH 3/3] Additional tests which validate expression ids are attached to parse / check issues when possible --- checker/checker_test.go | 42 +++++++++++++++++++++++++++++++++++++---- parser/parser_test.go | 16 ++++++++++++++++ 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/checker/checker_test.go b/checker/checker_test.go index af43573c..c4e5068f 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -16,6 +16,7 @@ package checker import ( "fmt" + "strings" "testing" "github.com/google/cel-go/checker/decls" @@ -2247,7 +2248,7 @@ func TestCheck(t *testing.T) { t.Parallel() src := common.NewTextSource(tc.in) - expression, errors := p.Parse(src) + pAst, errors := p.Parse(src) if len(errors.GetErrors()) > 0 { t.Fatalf("Unexpected parse errors: %v", errors.ToDisplayString()) } @@ -2282,7 +2283,7 @@ func TestCheck(t *testing.T) { } } - semantics, errors := Check(expression, src, env) + cAst, errors := Check(pAst, src, env) if len(errors.GetErrors()) > 0 { errorString := errors.ToDisplayString() if tc.err != "" { @@ -2296,7 +2297,7 @@ func TestCheck(t *testing.T) { t.Errorf("Expected error not thrown: %s", tc.err) } - actual := semantics.TypeMap[expression.Expr.Id] + actual := cAst.TypeMap[pAst.Expr.Id] if tc.err == "" { if actual == nil || !proto.Equal(actual, tc.outType) { t.Error(test.DiffMessage("Type Error", actual, tc.outType)) @@ -2304,7 +2305,7 @@ func TestCheck(t *testing.T) { } if tc.out != "" { - actualStr := Print(expression.Expr, semantics) + actualStr := Print(pAst.Expr, cAst) if !test.Compare(actualStr, tc.out) { t.Error(test.DiffMessage("Structure error", actualStr, tc.out)) } @@ -2362,3 +2363,36 @@ func TestAddEquivalentDeclarations(t *testing.T) { t.Errorf("env.Add(optIndexEquiv) failed: %v", err) } } + +func TestCheckErrorExprID(t *testing.T) { + p, err := parser.NewParser( + parser.EnableOptionalSyntax(true), + parser.Macros(parser.AllMacros...), + ) + if err != nil { + t.Fatalf("parser.NewParser() failed: %v", err) + } + src := common.NewTextSource(`a || true`) + ast, iss := p.Parse(src) + if len(iss.GetErrors()) != 0 { + t.Fatalf("Parse() failed: %v", iss.ToDisplayString()) + } + + reg := newTestRegistry(t) + env, err := NewEnv(containers.DefaultContainer, reg) + if err != nil { + t.Fatalf("NewEnv(cont, reg) failed: %v", err) + } + env.Add(StandardDeclarations()...) + _, iss = Check(ast, src, env) + if len(iss.GetErrors()) != 1 { + t.Fatalf("Check() of a bad expression did produce a single error: %v", iss.ToDisplayString()) + } + celErr := iss.GetErrors()[0] + if celErr.ExprID != 1 { + t.Errorf("got exprID %v, wanted 1", celErr.ExprID) + } + if !strings.Contains(celErr.Message, "undeclared reference") { + t.Errorf("got message %v, wanted undeclared reference", celErr.Message) + } +} diff --git a/parser/parser_test.go b/parser/parser_test.go index a0ce497c..14e0d8ed 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -1990,6 +1990,22 @@ func BenchmarkParseParallel(b *testing.B) { }) } +func TestParseErrorData(t *testing.T) { + p := newTestParser(t) + src := common.NewTextSource(`a.?b`) + _, iss := p.Parse(src) + if len(iss.GetErrors()) != 1 { + t.Fatalf("Check() of a bad expression did produce a single error: %v", iss.ToDisplayString()) + } + celErr := iss.GetErrors()[0] + if celErr.ExprID != 2 { + t.Errorf("got exprID %v, wanted 2", celErr.ExprID) + } + if !strings.Contains(celErr.Message, "unsupported syntax") { + t.Errorf("got message %v, wanted unsupported syntax", celErr.Message) + } +} + func newTestParser(t *testing.T, options ...Option) *Parser { t.Helper() defaultOpts := []Option{