diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 6368665f249ee..c96d7435a7ed4 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -982,15 +982,7 @@ class DAGScheduler(
     if (!jobIdToStageIds.contains(jobId)) {
       logDebug("Trying to cancel unregistered job " + jobId)
     } else {
-      val independentStages = removeJobAndIndependentStages(jobId)
-      independentStages.foreach(taskScheduler.cancelTasks)
-      val error = new SparkException("Job %d cancelled".format(jobId))
-      val job = jobIdToActiveJob(jobId)
-      job.listener.jobFailed(error)
-      jobIdToStageIds -= jobId
-      activeJobs -= job
-      jobIdToActiveJob -= jobId
-      listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error, job.finalStage.id)))
+      failJobAndIndependentStages(jobIdToActiveJob(jobId), s"Job $jobId cancelled")
     }
   }
 
@@ -1007,19 +999,39 @@ class DAGScheduler(
     stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis())
     for (resultStage <- dependentStages) {
       val job = resultStageToJob(resultStage)
-      val error = new SparkException("Job aborted: " + reason)
-      job.listener.jobFailed(error)
-      jobIdToStageIdsRemove(job.jobId)
-      jobIdToActiveJob -= resultStage.jobId
-      activeJobs -= job
-      resultStageToJob -= resultStage
-      listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error, failedStage.id)))
+      failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason")
     }
     if (dependentStages.isEmpty) {
       logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
     }
   }
 
+  /**
+   * Fails a job and all stages that are only used by that job, and cleans up relevant state.
+   */
+  private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) {
+    val error = new SparkException(failureReason)
+    job.listener.jobFailed(error)
+
+    // Cancel all tasks in independent stages.
+    val independentStages = removeJobAndIndependentStages(job.jobId)
+    independentStages.foreach(taskScheduler.cancelTasks)
+
+    // Clean up remaining state we store for the job.
+    jobIdToActiveJob -= job.jobId
+    activeJobs -= job
+    jobIdToStageIds -= job.jobId
+    val resultStagesForJob = resultStageToJob.keySet.filter(
+      stage => resultStageToJob(stage).jobId == job.jobId)
+    if (resultStagesForJob.size != 1) {
+      logWarning(
+        s"${resultStagesForJob.size} result stages for job ${job.jobId} (expect exactly 1)")
+    }
+    resultStageToJob --= resultStagesForJob
+
+    listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error, job.finalStage.id)))
+  }
+
   /**
    * Return true if one of stage's ancestors is target.
    */
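The new `failJobAndIndependentStages` helper centralizes cleanup that was previously duplicated across the cancellation and stage-failure paths. The subtle part is the first step, `removeJobAndIndependentStages`: only stages that no *other* active job still depends on may have their tasks cancelled. A minimal standalone sketch of that rule, using a simplified job/stage model with hypothetical names (`Job`, `independentStages`), not Spark's actual API:

```scala
// Sketch only: models the "independent stages" rule, not the DAGScheduler itself.
object IndependentStagesSketch {
  // A job and the ids of the stages it needs (hypothetical, simplified model).
  case class Job(id: Int, stageIds: Set[Int])

  // Stages of `failing` that no other active job depends on. Only these are
  // safe to cancel when `failing` fails; shared stages must keep running.
  def independentStages(failing: Job, activeJobs: Seq[Job]): Set[Int] =
    failing.stageIds -- activeJobs.filter(_.id != failing.id).flatMap(_.stageIds)

  def main(args: Array[String]): Unit = {
    val job1 = Job(1, Set(0, 2))    // shuffle stage 0 feeds result stage 2
    val job2 = Job(2, Set(0, 1, 3)) // shares shuffle stage 0 with job1
    // Failing job2 may cancel its private stages 1 and 3, never shared stage 0:
    println(independentStages(job2, Seq(job1, job2)))  // prints Set(1, 3)
  }
}
```

The `failure of stage used by two jobs` test added below exercises this rule from the other direction: after the shared map stage fails, `cancelledStages.contains(1)` verifies that shuffleMapRdd2's stage, needed only by the second job, was cancelled along with it.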
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index ce567b0cde85d..2e3026bffba2f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -18,7 +18,7 @@ package org.apache.spark.scheduler
 
 import scala.Tuple2
-import scala.collection.mutable.{HashMap, Map}
+import scala.collection.mutable.{HashSet, HashMap, Map}
 
 import org.scalatest.{BeforeAndAfter, FunSuite}
 
@@ -43,6 +43,10 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
   val conf = new SparkConf
   /** Set of TaskSets the DAGScheduler has requested executed. */
   val taskSets = scala.collection.mutable.Buffer[TaskSet]()
+
+  /** Stages for which the DAGScheduler has called TaskScheduler.cancelTasks(). */
+  val cancelledStages = new HashSet[Int]()
+
   val taskScheduler = new TaskScheduler() {
     override def rootPool: Pool = null
     override def schedulingMode: SchedulingMode = SchedulingMode.NONE
@@ -53,7 +57,9 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
       taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
       taskSets += taskSet
     }
-    override def cancelTasks(stageId: Int) {}
+    override def cancelTasks(stageId: Int) {
+      cancelledStages += stageId
+    }
     override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
     override def defaultParallelism() = 2
   }
@@ -91,6 +97,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
   before {
     sc = new SparkContext("local", "DAGSchedulerSuite")
     taskSets.clear()
+    cancelledStages.clear()
     cacheLocations.clear()
     results.clear()
     mapOutputTracker = new MapOutputTrackerMaster(conf)
@@ -174,15 +181,16 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
     }
   }
 
-  /** Sends the rdd to the scheduler for scheduling. */
+  /** Sends the rdd to the scheduler for scheduling and returns the job id. */
   private def submit(
       rdd: RDD[_],
       partitions: Array[Int],
       func: (TaskContext, Iterator[_]) => _ = jobComputeFunc,
       allowLocal: Boolean = false,
-      listener: JobListener = listener) {
+      listener: JobListener = listener): Int = {
     val jobId = scheduler.nextJobId.getAndIncrement()
     runEvent(JobSubmitted(jobId, rdd, func, partitions, allowLocal, null, listener))
+    return jobId
   }
 
   /** Sends TaskSetFailed to the scheduler. */
@@ -190,6 +198,11 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
     runEvent(TaskSetFailed(taskSet, message))
   }
 
+  /** Sends JobCancelled to the DAG scheduler. */
+  private def cancel(jobId: Int) {
+    runEvent(JobCancelled(jobId))
+  }
+
   test("zero split job") {
     val rdd = makeRdd(0, Nil)
     var numResults = 0
@@ -248,7 +261,15 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
   test("trivial job failure") {
     submit(makeRdd(1, Nil), Array(0))
     failed(taskSets(0), "some failure")
-    assert(failure.getMessage === "Job aborted: some failure")
+    assert(failure.getMessage === "Job aborted due to stage failure: some failure")
+    assertDataStructuresEmpty
+  }
+
+  test("trivial job cancellation") {
+    val rdd = makeRdd(1, Nil)
+    val jobId = submit(rdd, Array(0))
+    cancel(jobId)
+    assert(failure.getMessage === s"Job $jobId cancelled")
     assertDataStructuresEmpty
   }
 
@@ -323,6 +344,67 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
     assertDataStructuresEmpty
   }
 
+  test("run shuffle with map stage failure") {
+    val shuffleMapRdd = makeRdd(2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val reduceRdd = makeRdd(2, List(shuffleDep))
+    submit(reduceRdd, Array(0, 1))
+
+    // Fail the map stage. This should cause the entire job to fail.
+    val stageFailureMessage = "Exception failure in map stage"
+    failed(taskSets(0), stageFailureMessage)
+    assert(failure.getMessage === s"Job aborted due to stage failure: $stageFailureMessage")
+    assertDataStructuresEmpty
+  }
+
+  /**
+   * Makes sure that failures of a stage used by multiple jobs are correctly handled.
+   *
+   * This test creates the following dependency graph:
+   *
+   * shuffleMapRdd1     shuffleMapRDD2
+   *        |     \        |
+   *        |      \       |
+   *        |       \      |
+   *        |        \     |
+   *   reduceRdd1    reduceRdd2
+   *
+   * We start both shuffleMapRdds and then fail shuffleMapRdd1. As a result, the job listeners for
+   * reduceRdd1 and reduceRdd2 should both be informed that the job failed. shuffleMapRDD2 should
+   * also be cancelled, because it is only used by reduceRdd2 and reduceRdd2 cannot complete
+   * without shuffleMapRdd1.
+   */
+  test("failure of stage used by two jobs") {
+    val shuffleMapRdd1 = makeRdd(2, Nil)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, null)
+    val shuffleMapRdd2 = makeRdd(2, Nil)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, null)
+
+    val reduceRdd1 = makeRdd(2, List(shuffleDep1))
+    val reduceRdd2 = makeRdd(2, List(shuffleDep1, shuffleDep2))
+
+    // We need to make our own listeners for this test, since by default submit uses the same
+    // listener for all jobs, and here we want to capture the failure for each job separately.
+    class FailureRecordingJobListener() extends JobListener {
+      var failureMessage: String = _
+      override def taskSucceeded(index: Int, result: Any) {}
+      override def jobFailed(exception: Exception) = { failureMessage = exception.getMessage }
+    }
+    val listener1 = new FailureRecordingJobListener()
+    val listener2 = new FailureRecordingJobListener()
+
+    submit(reduceRdd1, Array(0, 1), listener = listener1)
+    submit(reduceRdd2, Array(0, 1), listener = listener2)
+
+    val stageFailureMessage = "Exception failure in map stage"
+    failed(taskSets(0), stageFailureMessage)
+
+    assert(cancelledStages.contains(1))
+    assert(listener1.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage")
+    assert(listener2.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage")
+    assertDataStructuresEmpty
+  }
+
   test("run trivial shuffle with out-of-band failure and retry") {
     val shuffleMapRdd = makeRdd(2, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)