Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit number of errors returned (when variables are used) #1034

Merged
merged 6 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,27 @@ lazy val core = project
description := "Scala GraphQL implementation",
mimaPreviousArtifacts := Set("org.sangria-graphql" %% "sangria-core" % "4.0.0"),
mimaBinaryIssueFilters ++= Seq(
ProblemFilters.exclude[DirectMissingMethodProblem]("sangria.execution.Executor.apply"),
ProblemFilters.exclude[DirectMissingMethodProblem]("sangria.execution.Executor.copy"),
ProblemFilters.exclude[DirectMissingMethodProblem]("sangria.execution.Executor.this"),
ProblemFilters.exclude[DirectMissingMethodProblem]("sangria.execution.Executor.execute"),
ProblemFilters.exclude[DirectMissingMethodProblem]("sangria.execution.Executor.prepare"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"sangria.validation.RuleBasedQueryValidator.this"),
"sangria.execution.QueryReducerExecutor.reduceQueryWithoutVariables"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"sangria.execution.ValueCoercionHelper.isValidValue"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"sangria.execution.ValueCoercionHelper.getVariableValue"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"sangria.execution.batch.BatchExecutor.executeBatch"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"sangria.schema.ResolverBasedAstSchemaBuilder.validateSchema"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"sangria.validation.QueryValidator.validateQuery"),
ProblemFilters.exclude[ReversedMissingMethodProblem](
"sangria.validation.QueryValidator.validateQuery"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"sangria.validation.RuleBasedQueryValidator.validateQuery"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"sangria.validation.ValidationContext.this")
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import sangria.validation.{QueryValidator, RuleBasedQueryValidator, Violation}
@State(Scope.Thread)
class OverlappingFieldsCanBeMergedBenchmark {

val validator: QueryValidator = RuleBasedQueryValidator(
val validator: QueryValidator = new RuleBasedQueryValidator(
List(new rules.OverlappingFieldsCanBeMerged))

val schema: Schema[_, _] =
Expand Down Expand Up @@ -98,7 +98,7 @@ class OverlappingFieldsCanBeMergedBenchmark {
bh.consume(doValidate(validator, deepAbstractConcrete))

private def doValidate(validator: QueryValidator, document: Document): Vector[Violation] = {
val result = validator.validateQuery(schema, document)
val result = validator.validateQuery(schema, document, None)
require(result.isEmpty)
result
}
Expand Down
27 changes: 18 additions & 9 deletions modules/core/src/main/scala/sangria/execution/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ case class Executor[Ctx, Root](
deprecationTracker: DeprecationTracker = DeprecationTracker.empty,
middleware: List[Middleware[Ctx]] = Nil,
maxQueryDepth: Option[Int] = None,
queryReducers: List[QueryReducer[Ctx, _]] = Nil
queryReducers: List[QueryReducer[Ctx, _]] = Nil,
errorsLimit: Option[Int] = None
yanns marked this conversation as resolved.
Show resolved Hide resolved
)(implicit executionContext: ExecutionContext) {
def prepare[Input](
queryAst: ast.Document,
Expand All @@ -29,7 +30,7 @@ case class Executor[Ctx, Root](
variables: Input = emptyMapVars
)(implicit um: InputUnmarshaller[Input]): Future[PreparedQuery[Ctx, Root, Input]] = {
val (violations, validationTiming) =
TimeMeasurement.measure(queryValidator.validateQuery(schema, queryAst))
TimeMeasurement.measure(queryValidator.validateQuery(schema, queryAst, errorsLimit))

if (violations.nonEmpty)
Future.failed(ValidationError(violations, exceptionHandler))
Expand All @@ -49,7 +50,9 @@ case class Executor[Ctx, Root](
operation <- Executor.getOperation(exceptionHandler, queryAst, operationName)
unmarshalledVariables <- valueCollector.getVariableValues(
operation.variables,
scalarMiddleware)
scalarMiddleware,
errorsLimit
)
fieldCollector = new FieldCollector[Ctx, Root](
schema,
queryAst,
Expand Down Expand Up @@ -141,7 +144,7 @@ case class Executor[Ctx, Root](
um: InputUnmarshaller[Input],
scheme: ExecutionScheme): scheme.Result[Ctx, marshaller.Node] = {
val (violations, validationTiming) =
TimeMeasurement.measure(queryValidator.validateQuery(schema, queryAst))
TimeMeasurement.measure(queryValidator.validateQuery(schema, queryAst, errorsLimit))

if (violations.nonEmpty)
scheme.failed(ValidationError(violations, exceptionHandler))
Expand All @@ -161,7 +164,9 @@ case class Executor[Ctx, Root](
operation <- Executor.getOperation(exceptionHandler, queryAst, operationName)
unmarshalledVariables <- valueCollector.getVariableValues(
operation.variables,
scalarMiddleware)
scalarMiddleware,
errorsLimit
)
fieldCollector = new FieldCollector[Ctx, Root](
schema,
queryAst,
Expand Down Expand Up @@ -324,7 +329,8 @@ object Executor {
deprecationTracker: DeprecationTracker = DeprecationTracker.empty,
middleware: List[Middleware[Ctx]] = Nil,
maxQueryDepth: Option[Int] = None,
queryReducers: List[QueryReducer[Ctx, _]] = Nil
queryReducers: List[QueryReducer[Ctx, _]] = Nil,
errorsLimit: Option[Int] = None
)(implicit
executionContext: ExecutionContext,
marshaller: ResultMarshaller,
Expand All @@ -338,7 +344,8 @@ object Executor {
deprecationTracker,
middleware,
maxQueryDepth,
queryReducers)
queryReducers,
errorsLimit)
.execute(queryAst, userContext, root, operationName, variables)

def prepare[Ctx, Root, Input](
Expand All @@ -354,7 +361,8 @@ object Executor {
deprecationTracker: DeprecationTracker = DeprecationTracker.empty,
middleware: List[Middleware[Ctx]] = Nil,
maxQueryDepth: Option[Int] = None,
queryReducers: List[QueryReducer[Ctx, _]] = Nil
queryReducers: List[QueryReducer[Ctx, _]] = Nil,
errorsLimit: Option[Int] = None
)(implicit
executionContext: ExecutionContext,
um: InputUnmarshaller[Input]): Future[PreparedQuery[Ctx, Root, Input]] =
Expand All @@ -366,7 +374,8 @@ object Executor {
deprecationTracker,
middleware,
maxQueryDepth,
queryReducers)
queryReducers,
errorsLimit)
.prepare(queryAst, userContext, root, operationName, variables)

def getOperationRootType[Ctx, Root](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ object QueryReducerExecutor {
queryValidator: QueryValidator = QueryValidator.default,
exceptionHandler: ExceptionHandler = ExceptionHandler.empty,
deprecationTracker: DeprecationTracker = DeprecationTracker.empty,
middleware: List[Middleware[Ctx]] = Nil
middleware: List[Middleware[Ctx]] = Nil,
errorsLimit: Option[Int] = None
)(implicit executionContext: ExecutionContext): Future[(Ctx, TimeMeasurement)] = {
val violations = queryValidator.validateQuery(schema, queryAst)
val violations = queryValidator.validateQuery(schema, queryAst, errorsLimit)

if (violations.nonEmpty)
Future.failed(ValidationError(violations, exceptionHandler))
Expand Down
189 changes: 105 additions & 84 deletions modules/core/src/main/scala/sangria/execution/ValueCoercionHelper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -580,107 +580,128 @@ class ValueCoercionHelper[Ctx](
nodeLocation.toList ++ firstValue.toList
}

def isValidValue[In](tpe: InputType[_], input: Option[In])(implicit
um: InputUnmarshaller[In]): Vector[Violation] = (tpe, input) match {
case (OptionInputType(ofType), Some(value)) if um.isDefined(value) =>
isValidValue(ofType, Some(value))
case (OptionInputType(_), _) => Vector.empty
case (_, None) => Vector(NotNullValueIsNullViolation(sourceMapper, Nil))

case (ListInputType(ofType), Some(values)) if um.isListNode(values) =>
um.getListValue(values)
.toVector
.flatMap(v =>
isValidValue(
ofType,
v match {
case opt: Option[In @unchecked] => opt
case other => Option(other)
}).map(ListValueViolation(0, _, sourceMapper, Nil)))

case (ListInputType(ofType), Some(value)) if um.isDefined(value) =>
isValidValue(
ofType,
value match {
case opt: Option[In @unchecked] => opt
case other => Option(other)
}).map(ListValueViolation(0, _, sourceMapper, Nil))

case (objTpe: InputObjectType[_], Some(valueMap)) if um.isMapNode(valueMap) =>
val unknownFields = um.getMapKeys(valueMap).toVector.collect {
case f if !objTpe.fieldsByName.contains(f) =>
UnknownInputObjectFieldViolation(
SchemaRenderer.renderTypeName(objTpe, true),
f,
sourceMapper,
Nil)
}
private def isValidValue[In](
inputType: InputType[_],
input: Option[In],
errorsLimit: Option[Int])(implicit um: InputUnmarshaller[In]): Vector[Violation] = {

val fieldViolations =
objTpe.fields.toVector.flatMap(f =>
isValidValue(f.fieldType, um.getMapValue(valueMap, f.name))
.map(MapValueViolation(f.name, _, sourceMapper, Nil)))
// keeping track of the number of errors
var errors = 0
Copy link
Contributor Author

@bc-dima-pasieka bc-dima-pasieka Jul 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using mutable variable was the only viable option here to avoid making huge refactoring and keeping the same signature of the original method (by returning Vector[Violation]).

The alternative 1 was to use ListBuffer[Violation] but it would require making signature changes (probably just to return Unit) + it didn't work well with .map(ListValueViolation(0, _, sourceMapper, Nil)).

The alternative 2 was to propagate the number of errors back to the recursion function but I was not able to make it work properly 🙂

def addViolation(violation: Violation): Vector[Violation] = {
errors += 1
Vector(violation)
}

fieldViolations ++ unknownFields
def isValidValueRec(tpe: InputType[_], in: Option[In])(implicit
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I created a separate recursion function (basically the old method migrated here for the most part) to be called internally so it has easy access to mutable variable

um: InputUnmarshaller[In]): Vector[Violation] =
// early termination if errors limit is defined and the current number of violations exceeds the limit
if (errorsLimit.exists(_ <= errors)) Vector.empty
else
(tpe, in) match {
case (OptionInputType(ofType), Some(value)) if um.isDefined(value) =>
isValidValueRec(ofType, Some(value))
case (OptionInputType(_), _) => Vector.empty
case (_, None) => addViolation(NotNullValueIsNullViolation(sourceMapper, Nil))

case (ListInputType(ofType), Some(values)) if um.isListNode(values) =>
um.getListValue(values)
.toVector
.flatMap(v =>
isValidValueRec(
ofType,
v match {
case opt: Option[In @unchecked] => opt
case other => Option(other)
}).map(ListValueViolation(0, _, sourceMapper, Nil)))

case (ListInputType(ofType), Some(value)) if um.isDefined(value) =>
isValidValueRec(
ofType,
value match {
case opt: Option[In @unchecked] => opt
case other => Option(other)
}).map(ListValueViolation(0, _, sourceMapper, Nil))

case (objTpe: InputObjectType[_], Some(valueMap)) if um.isMapNode(valueMap) =>
val unknownFields = um.getMapKeys(valueMap).toVector.collect {
case f if !objTpe.fieldsByName.contains(f) =>
addViolation(
UnknownInputObjectFieldViolation(
SchemaRenderer.renderTypeName(objTpe, true),
f,
sourceMapper,
Nil)).head
}

case (objTpe: InputObjectType[_], _) =>
Vector(
InputObjectIsOfWrongTypeMissingViolation(
SchemaRenderer.renderTypeName(objTpe, true),
sourceMapper,
Nil))
val fieldViolations =
objTpe.fields.toVector.flatMap(f =>
isValidValueRec(f.fieldType, um.getMapValue(valueMap, f.name))
.map(MapValueViolation(f.name, _, sourceMapper, Nil)))

case (scalar: ScalarType[_], Some(value)) if um.isScalarNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => scalar.coerceInput(node)
case other => scalar.coerceUserInput(other)
}
fieldViolations ++ unknownFields

coerced match {
case Left(violation) => Vector(violation)
case _ => Vector.empty
}
case (objTpe: InputObjectType[_], _) =>
addViolation(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So instead of returning Vector(violation), we call addViolation(violation) that mutates number of errors and returns the same Vector(violation)

InputObjectIsOfWrongTypeMissingViolation(
SchemaRenderer.renderTypeName(objTpe, true),
sourceMapper,
Nil))

case (scalar: ScalarAlias[_, _], Some(value)) if um.isScalarNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => scalar.aliasFor.coerceInput(node)
case other => scalar.aliasFor.coerceUserInput(other)
}
case (scalar: ScalarType[_], Some(value)) if um.isScalarNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => scalar.coerceInput(node)
case other => scalar.coerceUserInput(other)
}

coerced match {
case Left(violation) => Vector(violation)
case Right(v) =>
scalar.fromScalar(v) match {
case Left(violation) => Vector(violation)
case _ => Vector.empty
}
}
coerced match {
case Left(violation) => addViolation(violation)
case _ => Vector.empty
}

case (enumT: EnumType[_], Some(value)) if um.isEnumNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => enumT.coerceInput(node)
case other => enumT.coerceUserInput(other)
}
case (scalar: ScalarAlias[_, _], Some(value)) if um.isScalarNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => scalar.aliasFor.coerceInput(node)
case other => scalar.aliasFor.coerceUserInput(other)
}

coerced match {
case Left(violation) => Vector(violation)
case _ => Vector.empty
}
coerced match {
case Left(violation) => addViolation(violation)
case Right(v) =>
scalar.fromScalar(v) match {
case Left(violation) => addViolation(violation)
case _ => Vector.empty
}
}

case (enumT: EnumType[_], Some(value)) =>
Vector(EnumCoercionViolation)
case (enumT: EnumType[_], Some(value)) if um.isEnumNode(value) =>
val coerced = um.getScalarValue(value) match {
case node: ast.Value => enumT.coerceInput(node)
case other => enumT.coerceUserInput(other)
}

coerced match {
case Left(violation) => addViolation(violation)
case _ => Vector.empty
}

case (enumT: EnumType[_], Some(value)) =>
addViolation(EnumCoercionViolation)

case _ =>
Vector(GenericInvalidValueViolation(sourceMapper, Nil))
case _ =>
addViolation(GenericInvalidValueViolation(sourceMapper, Nil))
}

isValidValueRec(inputType, input)
}

def getVariableValue[In](
definition: ast.VariableDefinition,
tpe: InputType[_],
input: Option[In],
fromScalarMiddleware: Option[(Any, InputType[_]) => Option[Either[Violation, Any]]])(implicit
um: InputUnmarshaller[In]): Either[Vector[Violation], Option[VariableValue]] = {
val violations = isValidValue(tpe, input)
fromScalarMiddleware: Option[(Any, InputType[_]) => Option[Either[Violation, Any]]],
errorsLimit: Option[Int]
)(implicit um: InputUnmarshaller[In]): Either[Vector[Violation], Option[VariableValue]] = {
val violations = isValidValue(tpe, input, errorsLimit)

if (violations.isEmpty) {
val fieldPath = s"$$${definition.name}" :: Nil
Expand Down
Loading