[SPARK-49474][SS] Classify Error class for FlatMapGroupsWithState user function error

### What changes were proposed in this pull request?

Add a new error classification for errors occurring in the user-provided function used in FlatMapGroupsWithState.

### Why are the changes needed?

The user-provided function can throw any type of error. Using the new error framework yields better error messages and classification.
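
For illustration, a minimal sketch — not code from this PR — of how a caller can now match on the error class rather than on an arbitrary user exception; it relies only on the public SparkThrowable API:

import org.apache.spark.SparkThrowable

// Sketch: detect a streaming query failure caused by user code in
// flatMapGroupsWithState by checking the new error class.
def isUserFuncError(t: Throwable): Boolean = t match {
  case st: SparkThrowable =>
    st.getErrorClass == "FLATMAPGROUPSWITHSTATE_USER_FUNCTION_ERROR"
  case _ => false
}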

### Does this PR introduce _any_ user-facing change?

Yes, a better error message with an error class for FlatMapGroupsWithState user function failures.

### How was this patch tested?

Updated existing tests and added new unit tests in FlatMapGroupsWithStateSuite.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #47940 from liviazhu-db/liviazhu-db/classify-flatmapgroupswithstate-error.

Authored-by: Livia Zhu <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
liviazhu-db authored and HeartSaVioR committed Sep 5, 2024
1 parent 37f2fa9 commit 9a7b6e5
Showing 3 changed files with 117 additions and 7 deletions.
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -1457,6 +1457,12 @@
],
"sqlState" : "42704"
},
"FLATMAPGROUPSWITHSTATE_USER_FUNCTION_ERROR" : {
"message" : [
"An error occurred in the user provided function in flatMapGroupsWithState. Reason: <reason>"
],
"sqlState" : "39000"
},
"FORBIDDEN_OPERATION" : {
"message" : [
"The operation <statement> is not allowed on the <objectType>: <objectName>."
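
A minimal sketch of how such an entry is used at runtime — it mirrors the FlatMapGroupsWithStateUserFuncException constructor introduced in the next file, with a hypothetical cause; SQLSTATE 39000 is the class for external-routine errors:

import org.apache.spark.SparkException

// Sketch: the <reason> placeholder in the message template is filled from
// messageParameters, and the SQLSTATE comes from the JSON entry above.
val cause = new NullPointerException("user state was null") // hypothetical
val wrapped = new SparkException(
  errorClass = "FLATMAPGROUPSWITHSTATE_USER_FUNCTION_ERROR",
  messageParameters = Map("reason" -> Option(cause.getMessage).getOrElse("")),
  cause = cause)
assert(wrapped.getSqlState == "39000")
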
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -18,8 +18,11 @@ package org.apache.spark.sql.execution.streaming

import java.util.concurrent.TimeUnit.NANOSECONDS

import scala.util.control.NonFatal

import org.apache.hadoop.conf.Configuration

import org.apache.spark.{SparkException, SparkThrowable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -447,10 +450,33 @@ case class FlatMapGroupsWithStateExec(
hasTimedOut,
watermarkPresent)

// Call function, get the returned objects and convert them to rows
val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj =>
numOutputRows += 1
getOutputRow(obj)
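// Reclassify any non-fatal error thrown by user code that is not already a
// SparkThrowable, wrapping it as FlatMapGroupsWithStateUserFuncException so it
// carries the FLATMAPGROUPSWITHSTATE_USER_FUNCTION_ERROR error class.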
def withUserFuncExceptionHandling[T](func: => T): T = {
try {
func
} catch {
case NonFatal(e) if !e.isInstanceOf[SparkThrowable] =>
throw FlatMapGroupsWithStateUserFuncException(e)
case f: Throwable =>
throw f
}
}

val mappedIterator = withUserFuncExceptionHandling {
func(keyObj, valueObjIter, groupState).map { obj =>
numOutputRows += 1
getOutputRow(obj)
}
}

// Wrap the user-provided iterator with error handling; user code can also throw lazily from hasNext/next
val wrappedMappedIterator = new Iterator[InternalRow] {
override def hasNext: Boolean = {
withUserFuncExceptionHandling(mappedIterator.hasNext)
}

override def next(): InternalRow = {
withUserFuncExceptionHandling(mappedIterator.next())
}
}

// When the iterator is consumed, then write changes to state
@@ -472,7 +498,9 @@
}

// Return an iterator of rows such that, when fully consumed, the updated state value will be saved
CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion)
CompletionIterator[InternalRow, Iterator[InternalRow]](
wrappedMappedIterator, onIteratorCompletion
)
}
}
}
@@ -544,3 +572,13 @@ object FlatMapGroupsWithStateExec {
}
}
}


/**
* Exception that wraps the exception thrown in the user-provided function in flatMapGroupsWithState.
*/
private[sql] case class FlatMapGroupsWithStateUserFuncException(cause: Throwable)
extends SparkException(
errorClass = "FLATMAPGROUPSWITHSTATE_USER_FUNCTION_ERROR",
messageParameters = Map("reason" -> Option(cause.getMessage).getOrElse("")),
cause = cause)
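
The wrapping is applied twice because user code runs both eagerly (the call to func) and lazily (each hasNext/next on the iterator it returns). A generic sketch of the same technique — the names here are illustrative, not Spark APIs:

import scala.util.control.NonFatal

// Sketch: reclassify non-fatal errors from user code, whether they are thrown
// when the function is invoked or later while its result iterator is consumed.
final case class UserCodeException(cause: Throwable) extends RuntimeException(cause)

def guard[T](body: => T): T =
  try body
  catch { case NonFatal(e) => throw UserCodeException(e) }

def guarded[A](inner: Iterator[A]): Iterator[A] = new Iterator[A] {
  override def hasNext: Boolean = guard(inner.hasNext) // user code may run here
  override def next(): A = guard(inner.next())         // ...and here
}
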
sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -23,7 +23,6 @@ import java.sql.Timestamp
import org.apache.commons.io.FileUtils
import org.scalatest.exceptions.TestFailedException

import org.apache.spark.SparkException
import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction
import org.apache.spark.sql.{DataFrame, Encoder}
import org.apache.spark.sql.catalyst.InternalRow
@@ -635,6 +634,72 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
)
}

testWithAllStateVersions("[SPARK-49474] flatMapGroupsWithState - user NPE is classified") {
// Throws NPE
val stateFunc = (_: String, _: Iterator[String], _: GroupState[RunningCount]) => {
throw new NullPointerException()
// Need to return an iterator for compilation to get types
Iterator(1)
}

val inputData = MemoryStream[String]
val result =
inputData.toDS()
.groupByKey(x => x)
.flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc)

testStream(result, Update)(
AddData(inputData, "a"),
ExpectFailure[FlatMapGroupsWithStateUserFuncException]()
)
}

testWithAllStateVersions(
"[SPARK-49474] flatMapGroupsWithState - null user iterator error is classified") {
// Returns null, will throw NPE when method is called on it
val stateFunc = (_: String, _: Iterator[String], _: GroupState[RunningCount]) => {
null.asInstanceOf[Iterator[Int]]
}

val inputData = MemoryStream[String]
val result =
inputData.toDS()
.groupByKey(x => x)
.flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc)

testStream(result, Update)(
AddData(inputData, "a"),
ExpectFailure[FlatMapGroupsWithStateUserFuncException]()
)
}

testWithAllStateVersions(
"[SPARK-49474] flatMapGroupsWithState - NPE from user iterator is classified") {
// Returns iterator that throws NPE when next is called
val stateFunc = (_: String, _: Iterator[String], _: GroupState[RunningCount]) => {
new Iterator[Int] {
override def hasNext: Boolean = {
true
}

override def next(): Int = {
throw new NullPointerException()
}
}
}

val inputData = MemoryStream[String]
val result =
inputData.toDS()
.groupByKey(x => x)
.flatMapGroupsWithState(Update, GroupStateTimeout.NoTimeout)(stateFunc)

testStream(result, Update)(
AddData(inputData, "a"),
ExpectFailure[FlatMapGroupsWithStateUserFuncException]()
)
}

test("mapGroupsWithState - streaming") {
// Function to maintain running count up to 2, and then remove the count
// Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
@@ -816,7 +881,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
CheckNewAnswer(("a", 2L)),
setFailInTask(true),
AddData(inputData, "a"),
ExpectFailure[SparkException](), // task should fail but should not increment count
// task should fail but should not increment count
ExpectFailure[FlatMapGroupsWithStateUserFuncException](),
setFailInTask(false),
StartStream(),
CheckNewAnswer(("a", 3L)) // task should not fail, and should show correct count
