[SPARK-40106] Task failure should always trigger task failure listeners
### What changes were proposed in this pull request?

Clean up the semantics of Spark task listeners. Today, if a task body succeeds but a completion listener fails, failure listeners are not called, even though the task has indeed failed at that point. The fix is to invoke failure listeners if a completion listener fails, before running the remaining completion listeners.
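
To illustrate the new behavior, here is a minimal sketch (not part of the patch) assuming a local-mode `SparkContext` named `sc`:

```scala
import org.apache.spark.{SparkException, TaskContext}

// The task body succeeds, but its completion listener throws. With this
// change, the failure listener now also fires before the task is reported
// as failed.
try {
  sc.parallelize(Seq(1), numSlices = 1).foreach { _ =>
    val ctx = TaskContext.get()
    ctx.addTaskFailureListener { (_, error) =>
      // Previously skipped when only a completion listener failed.
      System.err.println(s"failure listener saw: ${error.getMessage}")
    }
    ctx.addTaskCompletionListener[Unit] { _ =>
      throw new RuntimeException("completion listener failed")
    }
  }
} catch {
  // The job still fails, wrapping a TaskCompletionListenerException.
  case e: SparkException => println(s"job failed as expected: ${e.getMessage}")
}
```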

### Why are the changes needed?

Today, failure listeners are not reliably invoked when the task failure is caused by a failed completion listener. This limits the utility of task listeners, especially ones intended to assist with task cleanup.

### Does this PR introduce _any_ user-facing change?

No public method signatures change, but failure listeners now run when a completion listener fails; previously they did not.

### How was this patch tested?

New unit tests exercise various combinations of failed listeners, with a task body that did (or did not) throw.
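
As a sketch of one such case (ScalaTest style, not the actual suite; it assumes local mode so the singleton flag below is shared with the executor thread):

```scala
import java.util.concurrent.atomic.AtomicBoolean
import org.apache.spark.{SparkException, TaskContext}

object ListenerFlags { val failureListenerRan = new AtomicBoolean(false) }

// Task body succeeds, completion listener throws: the job must fail, and the
// failure listener must have run anyway.
val e = intercept[SparkException] {
  sc.parallelize(1 to 1, 1).foreach { _ =>
    TaskContext.get().addTaskFailureListener((_, _) =>
      ListenerFlags.failureListenerRan.set(true))
    TaskContext.get().addTaskCompletionListener[Unit](_ => sys.error("listener boom"))
  }
}
assert(ListenerFlags.failureListenerRan.get)
```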

Closes #37531 from ryan-johnson-databricks/task-failure-listeners.

Authored-by: Ryan Johnson <[email protected]>
Signed-off-by: Josh Rosen <[email protected]>
ryan-johnson-databricks authored and JoshRosen committed Aug 18, 2022
1 parent 50c1635 commit 6cd9d88
Showing 4 changed files with 295 additions and 44 deletions.
41 changes: 37 additions & 4 deletions core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -25,6 +25,7 @@
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.source.Source
 import org.apache.spark.resource.ResourceInformation
+import org.apache.spark.scheduler.Task
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener}

@@ -133,21 +134,53 @@ abstract class TaskContext extends Serializable {
   }
 
   /**
-   * Adds a listener to be executed on task failure. Adding a listener to an already failed task
-   * will result in that listener being called immediately.
+   * Adds a listener to be executed on task failure (which includes completion listener failure, if
+   * the task body did not already fail). Adding a listener to an already failed task will result in
+   * that listener being called immediately.
+   *
+   * Note: Prior to Spark 3.4.0, failure listeners were only invoked if the main task body failed.
    */
   def addTaskFailureListener(listener: TaskFailureListener): TaskContext
 
   /**
-   * Adds a listener to be executed on task failure. Adding a listener to an already failed task
-   * will result in that listener being called immediately.
+   * Adds a listener to be executed on task failure (which includes completion listener failure, if
+   * the task body did not already fail). Adding a listener to an already failed task will result in
+   * that listener being called immediately.
+   *
+   * Note: Prior to Spark 3.4.0, failure listeners were only invoked if the main task body failed.
    */
   def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext = {
     addTaskFailureListener(new TaskFailureListener {
       override def onTaskFailure(context: TaskContext, error: Throwable): Unit = f(context, error)
     })
   }
 
+  /** Runs a task with this context, ensuring failure and completion listeners get triggered. */
+  private[spark] def runTaskWithListeners[T](task: Task[T]): T = {
+    try {
+      task.runTask(this)
+    } catch {
+      case e: Throwable =>
+        // Catch all errors; run task failure and completion callbacks, and rethrow the exception.
+        try {
+          markTaskFailed(e)
+        } catch {
+          case t: Throwable =>
+            e.addSuppressed(t)
+        }
+        try {
+          markTaskCompleted(Some(e))
+        } catch {
+          case t: Throwable =>
+            e.addSuppressed(t)
+        }
+        throw e
+    } finally {
+      // Call the task completion callbacks. No-op if "markTaskCompleted" was already called.
+      markTaskCompleted(None)
+    }
+  }
+
   /**
    * The ID of the stage that this task belong to.
    */
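
The error-handling shape of `runTaskWithListeners` above is worth calling out: the task body's exception stays primary, listener failures are attached via `addSuppressed`, and the completion step is idempotent so the `finally` block can call it unconditionally. A standalone sketch of that pattern (plain Scala with hypothetical names, not Spark code):

```scala
import java.util.concurrent.atomic.AtomicBoolean

// Idempotent completion step: a second run() is a no-op, mirroring how a
// repeated markTaskCompleted call does nothing.
final class Completion(action: => Unit) {
  private val done = new AtomicBoolean(false)
  def run(): Unit = if (done.compareAndSet(false, true)) action
}

def runBodyWithListeners[T](body: => T)(onFailure: Throwable => Unit, completion: Completion): T = {
  try {
    body
  } catch {
    case e: Throwable =>
      // Run failure handling first, then completion; never mask the body's error.
      try onFailure(e) catch { case t: Throwable => e.addSuppressed(t) }
      try completion.run() catch { case t: Throwable => e.addSuppressed(t) }
      throw e
  } finally {
    completion.run() // no-op if the catch block already ran it
  }
}
```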
59 changes: 54 additions & 5 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -76,7 +76,7 @@ private[spark] class TaskContextImpl(
    *
    * `invokeListeners()` uses this to ensure listeners are called sequentially.
    */
-  @transient private var listenerInvocationThread: Option[Thread] = None
+  @transient @volatile private var listenerInvocationThread: Option[Thread] = None
 
   // If defined, the corresponding task has been killed and this option contains the reason.
   @volatile private var reasonIfKilled: Option[String] = None
@@ -191,20 +191,69 @@ private[spark] class TaskContextImpl(
       }
     }
 
-    val errorMsgs = new ArrayBuffer[String](2)
+    val listenerExceptions = new ArrayBuffer[Throwable](2)
     var listenerOption: Option[T] = None
     while ({listenerOption = getNextListenerOrDeregisterThread(); listenerOption.nonEmpty}) {
       val listener = listenerOption.get
       try {
         callback(listener)
       } catch {
         case e: Throwable =>
-          errorMsgs += e.getMessage
+          // A listener failed. Temporarily clear the listenerInvocationThread and markTaskFailed.
+          //
+          // One of the following cases applies (#3 being the interesting one):
+          //
+          // 1. [[Task.doRunTask]] is currently calling [[markTaskFailed]] because the task body
+          //    failed, and now a failure listener has failed here (not necessarily the first to
+          //    fail). Then calling [[markTaskFailed]] again here is a no-op, and we simply resume
+          //    running the remaining failure listeners. [[Task.doRunTask]] will then call
+          //    [[markTaskCompleted]] after this method returns.
+          //
+          // 2. The task body failed, [[Task.doRunTask]] already called [[markTaskFailed]],
+          //    [[Task.doRunTask]] is currently calling [[markTaskCompleted]], and now a completion
+          //    listener has failed here (not necessarily the first one to fail). Then calling
+          //    [[markTaskFailed]] again here is a no-op, and we simply resume running the
+          //    remaining completion listeners.
+          //
+          // 3. [[Task.doRunTask]] is currently calling [[markTaskCompleted]] because the task body
+          //    succeeded, and now a completion listener has failed here (the first one to
+          //    fail). Then our call to [[markTaskFailed]] here will run all failure listeners
+          //    before returning, after which we will resume running the remaining completion
+          //    listeners.
+          //
+          // 4. [[Task.doRunTask]] is currently calling [[markTaskCompleted]] because the task body
+          //    succeeded, but [[markTaskFailed]] is currently running because a completion listener
+          //    has failed, and now a failure listener has failed (not necessarily the first one to
+          //    fail). Then calling [[markTaskFailed]] again here will have no effect, and we simply
+          //    resume running the remaining failure listeners; we will resume running the remaining
+          //    completion listeners after this call returns.
+          //
+          // 5. [[Task.doRunTask]] is currently calling [[markTaskCompleted]] because the task body
+          //    succeeded, [[markTaskFailed]] already ran because a completion listener previously
+          //    failed, and now another completion listener has failed. Then our call to
+          //    [[markTaskFailed]] here will have no effect and we simply resume running the
+          //    remaining completion handlers.
+          try {
+            listenerInvocationThread = None
+            markTaskFailed(e)
+          } catch {
+            case t: Throwable => e.addSuppressed(t)
+          } finally {
+            synchronized {
+              if (listenerInvocationThread.isEmpty) {
+                listenerInvocationThread = Some(Thread.currentThread())
+              }
+            }
+          }
+          listenerExceptions += e
           logError(s"Error in $name", e)
       }
     }
-    if (errorMsgs.nonEmpty) {
-      throw new TaskCompletionListenerException(errorMsgs.toSeq, error)
+    if (listenerExceptions.nonEmpty) {
+      val exception = new TaskCompletionListenerException(
+        listenerExceptions.map(_.getMessage).toSeq, error)
+      listenerExceptions.foreach(exception.addSuppressed)
+      throw exception
     }
   }
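
To make case 3 concrete, here is a hypothetical listener registration and the resulting order of events (assuming, as `TaskContextImpl` does, that listeners run in reverse order of registration):

```scala
// Inside a task body; ctx is the running TaskContext.
ctx.addTaskCompletionListener[Unit](_ => println("completion A"))        // registered first, runs last
ctx.addTaskCompletionListener[Unit](_ => sys.error("completion B boom")) // runs first
ctx.addTaskFailureListener((_, e) => println(s"failure listener: ${e.getMessage}"))

// Expected sequence once the (successful) task body returns:
//   1. markTaskCompleted invokes "completion B", which throws
//   2. markTaskFailed runs the failure listener (case 3)
//   3. the remaining completion listener "completion A" still runs
//   4. the task fails with a TaskCompletionListenerException naming "completion B boom"
```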
53 changes: 18 additions & 35 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -136,44 +136,27 @@ private[spark] abstract class Task[T](
     plugins.foreach(_.onTaskStart())
 
     try {
-      runTask(context)
-    } catch {
-      case e: Throwable =>
-        // Catch all errors; run task failure callbacks, and rethrow the exception.
-        try {
-          context.markTaskFailed(e)
-        } catch {
-          case t: Throwable =>
-            e.addSuppressed(t)
-        }
-        context.markTaskCompleted(Some(e))
-        throw e
+      context.runTaskWithListeners(this)
     } finally {
       try {
-        // Call the task completion callbacks. If "markTaskCompleted" is called twice, the second
-        // one is no-op.
-        context.markTaskCompleted(None)
-      } finally {
-        try {
-          Utils.tryLogNonFatalError {
-            // Release memory used by this thread for unrolling blocks
-            SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
-            SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(
-              MemoryMode.OFF_HEAP)
-            // Notify any tasks waiting for execution memory to be freed to wake up and try to
-            // acquire memory again. This makes impossible the scenario where a task sleeps forever
-            // because there are no other tasks left to notify it. Since this is safe to do but may
-            // not be strictly necessary, we should revisit whether we can remove this in the
-            // future.
-            val memoryManager = SparkEnv.get.memoryManager
-            memoryManager.synchronized { memoryManager.notifyAll() }
-          }
-        } finally {
-          // Though we unset the ThreadLocal here, the context member variable itself is still
-          // queried directly in the TaskRunner to check for FetchFailedExceptions.
-          TaskContext.unset()
-          InputFileBlockHolder.unset()
-        }
+        Utils.tryLogNonFatalError {
+          // Release memory used by this thread for unrolling blocks
+          SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
+          SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(
+            MemoryMode.OFF_HEAP)
+          // Notify any tasks waiting for execution memory to be freed to wake up and try to
+          // acquire memory again. This makes impossible the scenario where a task sleeps forever
+          // because there are no other tasks left to notify it. Since this is safe to do but may
+          // not be strictly necessary, we should revisit whether we can remove this in the
+          // future.
+          val memoryManager = SparkEnv.get.memoryManager
+          memoryManager.synchronized { memoryManager.notifyAll() }
+        }
+      } finally {
+        // Though we unset the ThreadLocal here, the context member variable itself is still
+        // queried directly in the TaskRunner to check for FetchFailedExceptions.
+        TaskContext.unset()
+        InputFileBlockHolder.unset()
       }
     }
   }
(The diff for the fourth changed file, which contains the new unit tests, did not render on this page.)
