From 9a7b6e5c31bfcca4283ed6bc22df10b743e9a470 Mon Sep 17 00:00:00 2001 From: Livia Zhu Date: Thu, 5 Sep 2024 16:24:01 +0900 Subject: [PATCH] [SPARK-49474][SS] Classify Error class for FlatMapGroupsWithState user function error ### What changes were proposed in this pull request? Add new error classification for errors occurring in the user function that is used in FlatMapGroupsWithState. ### Why are the changes needed? The user provided function can throw any type of error. Using the new error framework for better error messages and classification. ### Does this PR introduce _any_ user-facing change? Yes, better error message with error class for FlatMapGroupsWithState user function failures. ### How was this patch tested? Updated existing tests and added new unit test in FlatMapGroupsWithStateSuite. ### Was this patch authored or co-authored using generative AI tooling? No Closes #47940 from liviazhu-db/liviazhu-db/classify-flatmapgroupswithstate-error. Authored-by: Livia Zhu Signed-off-by: Jungtaek Lim --- .../resources/error/error-conditions.json | 6 ++ .../FlatMapGroupsWithStateExec.scala | 48 +++++++++++-- .../FlatMapGroupsWithStateSuite.scala | 70 ++++++++++++++++++- 3 files changed, 117 insertions(+), 7 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index e2725a98a63bd..96105c9672254 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -1457,6 +1457,12 @@ ], "sqlState" : "42704" }, + "FLATMAPGROUPSWITHSTATE_USER_FUNCTION_ERROR" : { + "message" : [ + "An error occurred in the user provided function in flatMapGroupsWithState. Reason: <reason>" + ], + "sqlState" : "39000" + }, "FORBIDDEN_OPERATION" : { "message" : [ "The operation is not allowed on the : ." 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index d56dfebd61ba1..766caaab2285e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -18,8 +18,11 @@ package org.apache.spark.sql.execution.streaming import java.util.concurrent.TimeUnit.NANOSECONDS +import scala.util.control.NonFatal + import org.apache.hadoop.conf.Configuration +import org.apache.spark.{SparkException, SparkThrowable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -447,10 +450,33 @@ case class FlatMapGroupsWithStateExec( hasTimedOut, watermarkPresent) - // Call function, get the returned objects and convert them to rows - val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj => - numOutputRows += 1 - getOutputRow(obj) + def withUserFuncExceptionHandling[T](func: => T): T = { + try { + func + } catch { + case NonFatal(e) if !e.isInstanceOf[SparkThrowable] => + throw FlatMapGroupsWithStateUserFuncException(e) + case f: Throwable => + throw f + } + } + + val mappedIterator = withUserFuncExceptionHandling { + func(keyObj, valueObjIter, groupState).map { obj => + numOutputRows += 1 + getOutputRow(obj) + } + } + + // Wrap user-provided fns with error handling + val wrappedMappedIterator = new Iterator[InternalRow] { + override def hasNext: Boolean = { + withUserFuncExceptionHandling(mappedIterator.hasNext) + } + + override def next(): InternalRow = { + withUserFuncExceptionHandling(mappedIterator.next()) + } } // When the iterator is consumed, then write changes to state @@ -472,7 +498,9 @@ case class FlatMapGroupsWithStateExec( } // Return an iterator of rows such 
that fully consumed, the updated state value will be saved - CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion) + CompletionIterator[InternalRow, Iterator[InternalRow]]( + wrappedMappedIterator, onIteratorCompletion + ) } } } @@ -544,3 +572,13 @@ object FlatMapGroupsWithStateExec { } } } + + +/** + * Exception that wraps the exception thrown in the user provided function in flatMapGroupsWithState. + */ +private[sql] case class FlatMapGroupsWithStateUserFuncException(cause: Throwable) + extends SparkException( + errorClass = "FLATMAPGROUPSWITHSTATE_USER_FUNCTION_ERROR", + messageParameters = Map("reason" -> Option(cause.getMessage).getOrElse("")), + cause = cause) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 45a80a210fcee..f3ef73c6af5fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -23,7 +23,6 @@ import java.sql.Timestamp import org.apache.commons.io.FileUtils import org.scalatest.exceptions.TestFailedException -import org.apache.spark.SparkException import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction import org.apache.spark.sql.{DataFrame, Encoder} import org.apache.spark.sql.catalyst.InternalRow @@ -635,6 +634,72 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { ) } + testWithAllStateVersions("[SPARK-49474] flatMapGroupsWithState - user NPE is classified") { + // Throws NPE + val stateFunc = (_: String, _: Iterator[String], _: GroupState[RunningCount]) => { + throw new NullPointerException() + // Need to return an iterator for compilation to get types + Iterator(1) + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + 
.flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) + + testStream(result, Update)( + AddData(inputData, "a"), + ExpectFailure[FlatMapGroupsWithStateUserFuncException]() + ) + } + + testWithAllStateVersions( + "[SPARK-49474] flatMapGroupsWithState - null user iterator error is classified") { + // Returns null, will throw NPE when method is called on it + val stateFunc = (_: String, _: Iterator[String], _: GroupState[RunningCount]) => { + null.asInstanceOf[Iterator[Int]] + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) + + testStream(result, Update)( + AddData(inputData, "a"), + ExpectFailure[FlatMapGroupsWithStateUserFuncException]() + ) + } + + testWithAllStateVersions( + "[SPARK-49474] flatMapGroupsWithState - NPE from user iterator is classified") { + // Returns iterator that throws NPE when next is called + val stateFunc = (_: String, _: Iterator[String], _: GroupState[RunningCount]) => { + new Iterator[Int] { + override def hasNext: Boolean = { + true + } + + override def next(): Int = { + throw new NullPointerException() + } + } + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc) + + testStream(result, Update)( + AddData(inputData, "a"), + ExpectFailure[FlatMapGroupsWithStateUserFuncException]() + ) + } + test("mapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) @@ -816,7 +881,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest { CheckNewAnswer(("a", 2L)), setFailInTask(true), AddData(inputData, "a"), - ExpectFailure[SparkException](), // task should fail but should not increment count + // task should fail but should not increment 
count + ExpectFailure[FlatMapGroupsWithStateUserFuncException](), setFailInTask(false), StartStream(), CheckNewAnswer(("a", 3L)) // task should not fail, and should show correct count