[SPARK-45022][SQL] Provide context for dataset API errors
### What changes were proposed in this pull request?
This PR captures the Dataset APIs invoked by user code, together with their call sites in that code, and uses them to provide better error messages.

E.g. consider the following Spark app `SimpleApp.scala`:
```scala
   1  import org.apache.spark.sql.SparkSession
   2  import org.apache.spark.sql.functions._
   3
   4  object SimpleApp {
   5    def main(args: Array[String]) {
   6      val spark = SparkSession.builder.appName("Simple Application").config("spark.sql.ansi.enabled", true).getOrCreate()
   7      import spark.implicits._
   8
   9      val c = col("a") / col("b")
  10
  11      Seq((1, 0)).toDF("a", "b").select(c).show()
  12
  13      spark.stop()
  14    }
  15  }
```

After this PR, the error message contains the error context (which Spark Dataset API was called, and from where in the user code) in the following form:
```
Exception in thread "main" org.apache.spark.SparkArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
== Dataset ==
"div" was called from SimpleApp$.main(SimpleApp.scala:9)

	at org.apache.spark.sql.errors.QueryExecutionErrors$.divideByZeroError(QueryExecutionErrors.scala:201)
	at org.apache.spark.sql.catalyst.expressions.DivModLike.eval(arithmetic.scala:672)
...
```
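Because the context is attached to the thrown `SparkThrowable` itself (not just rendered into its message), user code can also inspect it programmatically. The snippet below is a minimal, hypothetical sketch, not part of this PR; it reuses the `spark` session and imports from `SimpleApp` above:
```scala
import org.apache.spark.{QueryContextType, SparkArithmeticException}

try {
  Seq((1, 0)).toDF("a", "b").select(col("a") / col("b")).show()
} catch {
  case e: SparkArithmeticException =>
    // Each QueryContext entry now reports whether it originates from SQL text
    // or from a Dataset (DataFrame) API call.
    e.getQueryContext.foreach { ctx =>
      ctx.contextType() match {
        case QueryContextType.DataFrame =>
          println(s"""API "${ctx.fragment()}" was called from ${ctx.callSite()}""")
        case QueryContextType.SQL =>
          println(s"SQL fragment: ${ctx.fragment()}")
      }
    }
}
```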
which is similar to the context already provided for SQL queries:
```
org.apache.spark.SparkArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
== SQL(line 1, position 1) ==
a / b
^^^^^

	at org.apache.spark.sql.errors.QueryExecutionErrors$.divideByZeroError(QueryExecutionErrors.scala:201)
	at org.apache.spark.sql.errors.QueryExecutionErrors.divideByZeroError(QueryExecutionErrors.scala)
...
```

Please note that the stack trace in `spark-shell` doesn't contain meaningful elements:
```
scala> Thread.currentThread().getStackTrace.foreach(println)
java.base/java.lang.Thread.getStackTrace(Thread.java:1602)
$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:23)
$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:27)
$line15.$read$$iw$$iw$$iw$$iw$$iw$$iw.<init>(<console>:29)
$line15.$read$$iw$$iw$$iw$$iw$$iw.<init>(<console>:31)
$line15.$read$$iw$$iw$$iw$$iw.<init>(<console>:33)
$line15.$read$$iw$$iw$$iw.<init>(<console>:35)
$line15.$read$$iw$$iw.<init>(<console>:37)
$line15.$read$$iw.<init>(<console>:39)
$line15.$read.<init>(<console>:41)
$line15.$read$.<init>(<console>:45)
$line15.$read$.<clinit>(<console>)
$line15.$eval$.$print$lzycompute(<console>:7)
$line15.$eval$.$print(<console>:6)
$line15.$eval.$print(<console>)
java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
...
```
so this change doesn't help with that use case.

### Why are the changes needed?
To provide more user-friendly error messages.

### Does this PR introduce _any_ user-facing change?
Yes. Error messages raised from Dataset API calls now include the new error context.

### How was this patch tested?
Added new UTs to `QueryExecutionAnsiErrorsSuite`.
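For illustration, such a test can use the `ExpectedContext` helper extended in `SparkFunSuite` (see the diff below) to assert on the new DataFrame context. The following is a rough, hypothetical sketch rather than a copy of the added tests; it assumes a suite with a shared `SparkSession`, ANSI mode enabled, and the usual imports (`spark.implicits._`, `org.apache.spark.sql.functions.col`, `org.apache.spark.SparkArithmeticException`), and the fragment, parameters and call-site pattern are illustrative:
```scala
val df = Seq((1, 0)).toDF("a", "b")
checkError(
  exception = intercept[SparkArithmeticException] {
    df.select(col("a") / col("b")).collect()
  },
  errorClass = "DIVIDE_BY_ZERO",
  parameters = Map("config" -> "\"spark.sql.ansi.enabled\""),
  // `fragment` is the Dataset API name; `callSitePattern` is matched against the call site.
  context = ExpectedContext(
    fragment = "div",
    callSitePattern = ".*QueryExecutionAnsiErrorsSuite.*"))
```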

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#43334 from MaxGekk/context-for-dataset-api-errors.

Lead-authored-by: Max Gekk <[email protected]>
Co-authored-by: Peter Toth <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
MaxGekk and peter-toth committed Nov 1, 2023
1 parent e1bc48b commit feea99e
Showing 71 changed files with 1,163 additions and 471 deletions.
9 changes: 9 additions & 0 deletions common/utils/src/main/java/org/apache/spark/QueryContext.java
@@ -27,6 +27,9 @@
*/
@Evolving
public interface QueryContext {
// The type of this query context.
QueryContextType contextType();

// The object type of the query which throws the exception.
// If the exception is directly from the main query, it should be an empty string.
// Otherwise, it should be the exact object type in upper case. For example, a "VIEW".
@@ -45,4 +48,10 @@ public interface QueryContext {

// The corresponding fragment of the query which throws the exception.
String fragment();

// The user code (call site of the API) that caused throwing the exception.
String callSite();

// Summary of the exception cause.
String summary();
}
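The two context flavors populate different subsets of these fields: a SQL context locates a fragment inside the query text via `startIndex`/`stopIndex`, while a DataFrame context carries the API name in `fragment` and the user-code location in `callSite`. As a hypothetical illustration (not part of this PR; Spark's own formatting lives in `SparkThrowableHelper`, shown further down), a consumer might branch on `contextType()` like this:
```scala
import org.apache.spark.{QueryContext, QueryContextType}

// Hypothetical helper: renders a context in the style of the error messages above.
def renderContext(ctx: QueryContext): String = ctx.contextType() match {
  case QueryContextType.DataFrame =>
    // `fragment` holds the Dataset API name, `callSite` the user-code location.
    "== Dataset ==\n" + s""""${ctx.fragment()}" was called from ${ctx.callSite()}"""
  case QueryContextType.SQL =>
    // `fragment` is the offending SQL text, located by `startIndex`/`stopIndex`.
    "== SQL ==\n" + s"${ctx.fragment()} (chars ${ctx.startIndex()}..${ctx.stopIndex()})"
}
```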
31 changes: 31 additions & 0 deletions common/utils/src/main/java/org/apache/spark/QueryContextType.java
@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark;

import org.apache.spark.annotation.Evolving;

/**
* The type of {@link QueryContext}.
*
* @since 4.0.0
*/
@Evolving
public enum QueryContextType {
SQL,
DataFrame
}
@@ -114,13 +114,19 @@ private[spark] object SparkThrowableHelper {
g.writeArrayFieldStart("queryContext")
e.getQueryContext.foreach { c =>
g.writeStartObject()
g.writeStringField("objectType", c.objectType())
g.writeStringField("objectName", c.objectName())
val startIndex = c.startIndex() + 1
if (startIndex > 0) g.writeNumberField("startIndex", startIndex)
val stopIndex = c.stopIndex() + 1
if (stopIndex > 0) g.writeNumberField("stopIndex", stopIndex)
g.writeStringField("fragment", c.fragment())
c.contextType() match {
case QueryContextType.SQL =>
g.writeStringField("objectType", c.objectType())
g.writeStringField("objectName", c.objectName())
val startIndex = c.startIndex() + 1
if (startIndex > 0) g.writeNumberField("startIndex", startIndex)
val stopIndex = c.stopIndex() + 1
if (stopIndex > 0) g.writeNumberField("stopIndex", stopIndex)
g.writeStringField("fragment", c.fragment())
case QueryContextType.DataFrame =>
g.writeStringField("fragment", c.fragment())
g.writeStringField("callSite", c.callSite())
}
g.writeEndObject()
}
g.writeEndArray()
@@ -823,6 +823,13 @@ message FetchErrorDetailsResponse {
// QueryContext defines the schema for the query context of a SparkThrowable.
// It helps users understand where the error occurs while executing queries.
message QueryContext {
// The type of this query context.
enum ContextType {
SQL = 0;
DATAFRAME = 1;
}
ContextType context_type = 10;

// The object type of the query which throws the exception.
// If the exception is directly from the main query, it should be an empty string.
// Otherwise, it should be the exact object type in upper case. For example, a "VIEW".
@@ -841,6 +848,12 @@

// The corresponding fragment of the query which throws the exception.
string fragment = 5;

// The user code (call site of the API) that caused throwing the exception.
string callSite = 6;

// Summary of the exception cause.
string summary = 7;
}

// SparkThrowable defines the schema for SparkThrowable exceptions.
@@ -28,7 +28,7 @@ import io.grpc.protobuf.StatusProto
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods

import org.apache.spark.{QueryContext, SparkArithmeticException, SparkArrayIndexOutOfBoundsException, SparkDateTimeException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException}
import org.apache.spark.{QueryContext, QueryContextType, SparkArithmeticException, SparkArrayIndexOutOfBoundsException, SparkDateTimeException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException}
import org.apache.spark.connect.proto.{FetchErrorDetailsRequest, FetchErrorDetailsResponse, UserContext}
import org.apache.spark.connect.proto.SparkConnectServiceGrpc.SparkConnectServiceBlockingStub
import org.apache.spark.internal.Logging
@@ -324,15 +324,18 @@

val queryContext = error.getSparkThrowable.getQueryContextsList.asScala.map { queryCtx =>
new QueryContext {
override def contextType(): QueryContextType = queryCtx.getContextType match {
case FetchErrorDetailsResponse.QueryContext.ContextType.DATAFRAME =>
QueryContextType.DataFrame
case _ => QueryContextType.SQL
}
override def objectType(): String = queryCtx.getObjectType

override def objectName(): String = queryCtx.getObjectName

override def startIndex(): Int = queryCtx.getStartIndex

override def stopIndex(): Int = queryCtx.getStopIndex

override def fragment(): String = queryCtx.getFragment
override def callSite(): String = queryCtx.getCallSite
override def summary(): String = queryCtx.getSummary
}
}.toArray

60 changes: 43 additions & 17 deletions core/src/test/scala/org/apache/spark/SparkFunSuite.scala
@@ -342,7 +342,7 @@ abstract class SparkFunSuite
sqlState: Option[String] = None,
parameters: Map[String, String] = Map.empty,
matchPVals: Boolean = false,
queryContext: Array[QueryContext] = Array.empty): Unit = {
queryContext: Array[ExpectedContext] = Array.empty): Unit = {
assert(exception.getErrorClass === errorClass)
sqlState.foreach(state => assert(exception.getSqlState === state))
val expectedParameters = exception.getMessageParameters.asScala
@@ -364,16 +364,25 @@
val actualQueryContext = exception.getQueryContext()
assert(actualQueryContext.length === queryContext.length, "Invalid length of the query context")
actualQueryContext.zip(queryContext).foreach { case (actual, expected) =>
assert(actual.objectType() === expected.objectType(),
"Invalid objectType of a query context Actual:" + actual.toString)
assert(actual.objectName() === expected.objectName(),
"Invalid objectName of a query context. Actual:" + actual.toString)
assert(actual.startIndex() === expected.startIndex(),
"Invalid startIndex of a query context. Actual:" + actual.toString)
assert(actual.stopIndex() === expected.stopIndex(),
"Invalid stopIndex of a query context. Actual:" + actual.toString)
assert(actual.fragment() === expected.fragment(),
"Invalid fragment of a query context. Actual:" + actual.toString)
assert(actual.contextType() === expected.contextType,
"Invalid contextType of a query context Actual:" + actual.toString)
if (actual.contextType() == QueryContextType.SQL) {
assert(actual.objectType() === expected.objectType,
"Invalid objectType of a query context Actual:" + actual.toString)
assert(actual.objectName() === expected.objectName,
"Invalid objectName of a query context. Actual:" + actual.toString)
assert(actual.startIndex() === expected.startIndex,
"Invalid startIndex of a query context. Actual:" + actual.toString)
assert(actual.stopIndex() === expected.stopIndex,
"Invalid stopIndex of a query context. Actual:" + actual.toString)
assert(actual.fragment() === expected.fragment,
"Invalid fragment of a query context. Actual:" + actual.toString)
} else if (actual.contextType() == QueryContextType.DataFrame) {
assert(actual.fragment() === expected.fragment,
"Invalid code fragment of a query context. Actual:" + actual.toString)
assert(actual.callSite().matches(expected.callSitePattern),
"Invalid callSite of a query context. Actual:" + actual.toString)
}
}
}

@@ -389,29 +398,29 @@
errorClass: String,
sqlState: String,
parameters: Map[String, String],
context: QueryContext): Unit =
context: ExpectedContext): Unit =
checkError(exception, errorClass, Some(sqlState), parameters, false, Array(context))

protected def checkError(
exception: SparkThrowable,
errorClass: String,
parameters: Map[String, String],
context: QueryContext): Unit =
context: ExpectedContext): Unit =
checkError(exception, errorClass, None, parameters, false, Array(context))

protected def checkError(
exception: SparkThrowable,
errorClass: String,
sqlState: String,
context: QueryContext): Unit =
context: ExpectedContext): Unit =
checkError(exception, errorClass, None, Map.empty, false, Array(context))

protected def checkError(
exception: SparkThrowable,
errorClass: String,
sqlState: Option[String],
parameters: Map[String, String],
context: QueryContext): Unit =
context: ExpectedContext): Unit =
checkError(exception, errorClass, sqlState, parameters,
false, Array(context))

@@ -426,7 +435,7 @@
errorClass: String,
sqlState: Option[String],
parameters: Map[String, String],
context: QueryContext): Unit =
context: ExpectedContext): Unit =
checkError(exception, errorClass, sqlState, parameters,
matchPVals = true, Array(context))

@@ -453,16 +462,33 @@
parameters = Map("relationName" -> tableName))

case class ExpectedContext(
contextType: QueryContextType,
objectType: String,
objectName: String,
startIndex: Int,
stopIndex: Int,
fragment: String) extends QueryContext
fragment: String,
callSitePattern: String
)

object ExpectedContext {
def apply(fragment: String, start: Int, stop: Int): ExpectedContext = {
ExpectedContext("", "", start, stop, fragment)
}

def apply(
objectType: String,
objectName: String,
startIndex: Int,
stopIndex: Int,
fragment: String): ExpectedContext = {
new ExpectedContext(QueryContextType.SQL, objectType, objectName, startIndex, stopIndex,
fragment, "")
}

def apply(fragment: String, callSitePattern: String): ExpectedContext = {
new ExpectedContext(QueryContextType.DataFrame, "", "", -1, -1, fragment, callSitePattern)
}
}

class LogAppender(msg: String = "", maxEvents: Int = 1000)
51 changes: 51 additions & 0 deletions core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala
@@ -503,11 +503,14 @@ class SparkThrowableSuite extends SparkFunSuite {
test("Get message in the specified format") {
import ErrorMessageFormat._
class TestQueryContext extends QueryContext {
override val contextType = QueryContextType.SQL
override val objectName = "v1"
override val objectType = "VIEW"
override val startIndex = 2
override val stopIndex = -1
override val fragment = "1 / 0"
override def callSite: String = throw new UnsupportedOperationException
override val summary = ""
}
val e = new SparkArithmeticException(
errorClass = "DIVIDE_BY_ZERO",
@@ -577,6 +580,54 @@
| "message" : "Test message"
| }
|}""".stripMargin)

class TestQueryContext2 extends QueryContext {
override val contextType = QueryContextType.DataFrame
override def objectName: String = throw new UnsupportedOperationException
override def objectType: String = throw new UnsupportedOperationException
override def startIndex: Int = throw new UnsupportedOperationException
override def stopIndex: Int = throw new UnsupportedOperationException
override val fragment: String = "div"
override val callSite: String = "SimpleApp$.main(SimpleApp.scala:9)"
override val summary = ""
}
val e4 = new SparkArithmeticException(
errorClass = "DIVIDE_BY_ZERO",
messageParameters = Map("config" -> "CONFIG"),
context = Array(new TestQueryContext2),
summary = "Query summary")

assert(SparkThrowableHelper.getMessage(e4, PRETTY) ===
"[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 " +
"and return NULL instead. If necessary set CONFIG to \"false\" to bypass this error." +
" SQLSTATE: 22012\nQuery summary")
// scalastyle:off line.size.limit
assert(SparkThrowableHelper.getMessage(e4, MINIMAL) ===
"""{
| "errorClass" : "DIVIDE_BY_ZERO",
| "sqlState" : "22012",
| "messageParameters" : {
| "config" : "CONFIG"
| },
| "queryContext" : [ {
| "fragment" : "div",
| "callSite" : "SimpleApp$.main(SimpleApp.scala:9)"
| } ]
|}""".stripMargin)
assert(SparkThrowableHelper.getMessage(e4, STANDARD) ===
"""{
| "errorClass" : "DIVIDE_BY_ZERO",
| "messageTemplate" : "Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set <config> to \"false\" to bypass this error.",
| "sqlState" : "22012",
| "messageParameters" : {
| "config" : "CONFIG"
| },
| "queryContext" : [ {
| "fragment" : "div",
| "callSite" : "SimpleApp$.main(SimpleApp.scala:9)"
| } ]
|}""".stripMargin)
// scalastyle:on line.size.limit
}

test("overwrite error classes") {
9 changes: 8 additions & 1 deletion project/MimaExcludes.scala
@@ -45,7 +45,14 @@ object MimaExcludes {
// [SPARK-45427][CORE] Add RPC SSL settings to SSLOptions and SparkTransportConf
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.SparkTransportConf.fromSparkConf"),
// [SPARK-45136][CONNECT] Enhance ClosureCleaner with Ammonite support
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.MethodIdentifier$")
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.MethodIdentifier$"),
// [SPARK-45022][SQL] Provide context for dataset API errors
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.QueryContext.contextType"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.QueryContext.code"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.QueryContext.callSite"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.QueryContext.summary"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.types.Decimal.fromStringANSI$default$3"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.types.Decimal.fromStringANSI")
)

// Default exclude rules
24 changes: 13 additions & 11 deletions python/pyspark/sql/connect/proto/base_pb2.py

Large diffs are not rendered by default.
