diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 524b0c4f6c3ac..af11a319083d5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1389,6 +1389,15 @@ private[spark] class DAGScheduler( event.reason match { case Success => + // An earlier attempt of a stage (which is zombie) may still have running tasks. If these + // tasks complete, they still count and we can mark the corresponding partitions as + // finished. Here we notify the task scheduler to skip running tasks for the same partition, + // to save resource. + if (task.stageAttemptId < stage.latestInfo.attemptNumber()) { + taskScheduler.notifyPartitionCompletion( + stageId, task.partitionId, event.taskInfo.duration) + } + task match { case rt: ResultTask[_, _] => // Cast to ResultStage here because it's part of the ResultTask diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index c6dedaaa9554a..09c4d9b5bce04 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -155,6 +155,16 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul } } + // This method calls `TaskSchedulerImpl.handlePartitionCompleted` asynchronously. We do not want + // DAGScheduler to call `TaskSchedulerImpl.handlePartitionCompleted` directly, as it's + // synchronized and may hurt the throughput of the scheduler. + def enqueuePartitionCompletionNotification( + stageId: Int, partitionId: Int, taskDuration: Long): Unit = { + getTaskResultExecutor.execute(() => Utils.logUncaughtExceptions { + scheduler.handlePartitionCompleted(stageId, partitionId, taskDuration) + }) + } + def stop() { getTaskResultExecutor.shutdownNow() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 94221eb0d5515..1862e16824277 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -68,6 +68,10 @@ private[spark] trait TaskScheduler { // Throw UnsupportedOperationException if the backend doesn't support kill tasks. def killAllTaskAttempts(stageId: Int, interruptThread: Boolean, reason: String): Unit + // Notify the corresponding `TaskSetManager`s of the stage, that a partition has already completed + // and they can skip running tasks for it. + def notifyPartitionCompletion(stageId: Int, partitionId: Int, taskDuration: Long) + // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called. def setDAGScheduler(dagScheduler: DAGScheduler): Unit diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index f461a6f0aae36..7e820c32fa78d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -301,6 +301,11 @@ private[spark] class TaskSchedulerImpl( } } + override def notifyPartitionCompletion( + stageId: Int, partitionId: Int, taskDuration: Long): Unit = { + taskResultGetter.enqueuePartitionCompletionNotification(stageId, partitionId, taskDuration) + } + /** * Called to indicate that all task attempts (including speculated tasks) associated with the * given TaskSetManager have completed, so state associated with the TaskSetManager should be @@ -637,6 +642,24 @@ private[spark] class TaskSchedulerImpl( } } + /** + * Marks the task has completed in the active TaskSetManager for the given stage. + * + * After stage failure and retry, there may be multiple TaskSetManagers for the stage. + * If an earlier zombie attempt of a stage completes a task, we can ask the later active attempt + * to skip submitting and running the task for the same partition, to save resource. That also + * means that a task completion from an earlier zombie attempt can lead to the entire stage + * getting marked as successful. + */ + private[scheduler] def handlePartitionCompleted( + stageId: Int, + partitionId: Int, + taskDuration: Long) = synchronized { + taskSetsByStageIdAndAttempt.get(stageId).foreach(_.values.filter(!_.isZombie).foreach { tsm => + tsm.markPartitionCompleted(partitionId, taskDuration) + }) + } + def error(message: String) { synchronized { if (taskSetsByStageIdAndAttempt.nonEmpty) { @@ -868,24 +891,6 @@ private[spark] class TaskSchedulerImpl( manager } } - - /** - * Marks the task has completed in all TaskSetManagers for the given stage. - * - * After stage failure and retry, there may be multiple TaskSetManagers for the stage. - * If an earlier attempt of a stage completes a task, we should ensure that the later attempts - * do not also submit those same tasks. That also means that a task completion from an earlier - * attempt can lead to the entire stage getting marked as successful. - */ - private[scheduler] def markPartitionCompletedInAllTaskSets( - stageId: Int, - partitionId: Int, - taskInfo: TaskInfo) = { - taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm => - tsm.markPartitionCompleted(partitionId, taskInfo) - } - } - } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index e1df1555b8141..b3aa814537500 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -806,9 +806,6 @@ private[spark] class TaskSetManager( logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id + " because task " + index + " has already completed successfully") } - // There may be multiple tasksets for this stage -- we let all of them know that the partition - // was completed. This may result in some of the tasksets getting completed. - sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId, info) // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not // "deserialize" the value when holding a lock to avoid blocking other threads. So we call @@ -819,11 +816,11 @@ private[spark] class TaskSetManager( maybeFinishTaskSet() } - private[scheduler] def markPartitionCompleted(partitionId: Int, taskInfo: TaskInfo): Unit = { + private[scheduler] def markPartitionCompleted(partitionId: Int, taskDuration: Long): Unit = { partitionToIndex.get(partitionId).foreach { index => if (!successful(index)) { if (speculationEnabled && !isZombie) { - successfulTaskDurations.insert(taskInfo.duration) + successfulTaskDurations.insert(taskDuration) } tasksSuccessful += 1 successful(index) = true diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 72c20a8173365..c8ae834e01e19 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -134,6 +134,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi /** Stages for which the DAGScheduler has called TaskScheduler.cancelTasks(). */ val cancelledStages = new HashSet[Int]() + val tasksMarkedAsCompleted = new ArrayBuffer[Task[_]]() + val taskScheduler = new TaskScheduler() { override def schedulingMode: SchedulingMode = SchedulingMode.FIFO override def rootPool: Pool = new Pool("", schedulingMode, 0, 0) @@ -156,6 +158,14 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi taskId: Long, interruptThread: Boolean, reason: String): Boolean = false override def killAllTaskAttempts( stageId: Int, interruptThread: Boolean, reason: String): Unit = {} + override def notifyPartitionCompletion( + stageId: Int, partitionId: Int, taskDuration: Long): Unit = { + taskSets.filter(_.stageId == stageId).lastOption.foreach { ts => + val tasks = ts.tasks.filter(_.partitionId == partitionId) + assert(tasks.length == 1) + tasksMarkedAsCompleted += tasks.head + } + } override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} @@ -246,6 +256,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi failure = null sc.addSparkListener(sparkListener) taskSets.clear() + tasksMarkedAsCompleted.clear() cancelledStages.clear() cacheLocations.clear() results.clear() @@ -658,6 +669,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi stageId: Int, interruptThread: Boolean, reason: String): Unit = { throw new UnsupportedOperationException } + override def notifyPartitionCompletion( + stageId: Int, partitionId: Int, taskDuration: Long): Unit = { + throw new UnsupportedOperationException + } override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} override def defaultParallelism(): Int = 2 override def executorHeartbeatReceived( @@ -2862,6 +2877,57 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(latch.await(10, TimeUnit.SECONDS)) } + test("Completions in zombie tasksets update status of non-zombie taskset") { + val parts = 4 + val shuffleMapRdd = new MyRDD(sc, parts, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts)) + val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker) + submit(reduceRdd, (0 until parts).toArray) + assert(taskSets.length == 1) + + // Finish the first task of the shuffle map stage. + runEvent(makeCompletionEvent( + taskSets(0).tasks(0), Success, makeMapStatus("hostA", 4), + Seq.empty, createFakeTaskInfoWithId(0))) + + // The second task of the shuffle map stage failed with FetchFailed. + runEvent(makeCompletionEvent( + taskSets(0).tasks(1), + FetchFailed(makeBlockManagerId("hostB"), shuffleDep.shuffleId, 0, 0, "ignored"), + null)) + + scheduler.resubmitFailedStages() + assert(taskSets.length == 2) + // The first partition has completed already, so the new attempt only need to run 3 tasks. + assert(taskSets(1).tasks.length == 3) + + // Finish the first task of the second attempt of the shuffle map stage. + runEvent(makeCompletionEvent( + taskSets(1).tasks(0), Success, makeMapStatus("hostA", 4), + Seq.empty, createFakeTaskInfoWithId(0))) + + // Finish the third task of the first attempt of the shuffle map stage. + runEvent(makeCompletionEvent( + taskSets(0).tasks(2), Success, makeMapStatus("hostA", 4), + Seq.empty, createFakeTaskInfoWithId(0))) + assert(tasksMarkedAsCompleted.length == 1) + assert(tasksMarkedAsCompleted.head.partitionId == 2) + + // Finish the forth task of the first attempt of the shuffle map stage. + runEvent(makeCompletionEvent( + taskSets(0).tasks(3), Success, makeMapStatus("hostA", 4), + Seq.empty, createFakeTaskInfoWithId(0))) + assert(tasksMarkedAsCompleted.length == 2) + assert(tasksMarkedAsCompleted.last.partitionId == 3) + + // Now the shuffle map stage is completed, and the next stage is submitted. + assert(taskSets.length == 3) + + // Finish + complete(taskSets(2), Seq((Success, 42), (Success, 42), (Success, 42), (Success, 42))) + assertDataStructuresEmpty() + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index 30d0966691a3c..347064dc9aadf 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -84,6 +84,8 @@ private class DummyTaskScheduler extends TaskScheduler { taskId: Long, interruptThread: Boolean, reason: String): Boolean = false override def killAllTaskAttempts( stageId: Int, interruptThread: Boolean, reason: String): Unit = {} + override def notifyPartitionCompletion( + stageId: Int, partitionId: Int, taskDuration: Long): Unit = {} override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} override def defaultParallelism(): Int = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 137ff2bd167ae..29614058485ab 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -1121,110 +1121,6 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B } } - test("Completions in zombie tasksets update status of non-zombie taskset") { - val taskScheduler = setupSchedulerWithMockTaskSetBlacklist() - val valueSer = SparkEnv.get.serializer.newInstance() - - def completeTaskSuccessfully(tsm: TaskSetManager, partition: Int): Unit = { - val indexInTsm = tsm.partitionToIndex(partition) - val matchingTaskInfo = tsm.taskAttempts.flatten.filter(_.index == indexInTsm).head - val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq()) - tsm.handleSuccessfulTask(matchingTaskInfo.taskId, result) - } - - // Submit a task set, have it fail with a fetch failed, and then re-submit the task attempt, - // two times, so we have three active task sets for one stage. (For this to really happen, - // you'd need the previous stage to also get restarted, and then succeed, in between each - // attempt, but that happens outside what we're mocking here.) - val zombieAttempts = (0 until 2).map { stageAttempt => - val attempt = FakeTask.createTaskSet(10, stageAttemptId = stageAttempt) - taskScheduler.submitTasks(attempt) - val tsm = taskScheduler.taskSetManagerForAttempt(0, stageAttempt).get - val offers = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) } - taskScheduler.resourceOffers(offers) - assert(tsm.runningTasks === 10) - // fail attempt - tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED, - FetchFailed(null, 0, 0, 0, "fetch failed")) - // the attempt is a zombie, but the tasks are still running (this could be true even if - // we actively killed those tasks, as killing is best-effort) - assert(tsm.isZombie) - assert(tsm.runningTasks === 9) - tsm - } - - // we've now got 2 zombie attempts, each with 9 tasks still active. Submit the 3rd attempt for - // the stage, but this time with insufficient resources so not all tasks are active. - - val finalAttempt = FakeTask.createTaskSet(10, stageAttemptId = 2) - taskScheduler.submitTasks(finalAttempt) - val finalTsm = taskScheduler.taskSetManagerForAttempt(0, 2).get - val offers = (0 until 5).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) } - val finalAttemptLaunchedPartitions = taskScheduler.resourceOffers(offers).flatten.map { task => - finalAttempt.tasks(task.index).partitionId - }.toSet - assert(finalTsm.runningTasks === 5) - assert(!finalTsm.isZombie) - - // We simulate late completions from our zombie tasksets, corresponding to all the pending - // partitions in our final attempt. This means we're only waiting on the tasks we've already - // launched. - val finalAttemptPendingPartitions = (0 until 10).toSet.diff(finalAttemptLaunchedPartitions) - finalAttemptPendingPartitions.foreach { partition => - completeTaskSuccessfully(zombieAttempts(0), partition) - } - - // If there is another resource offer, we shouldn't run anything. Though our final attempt - // used to have pending tasks, now those tasks have been completed by zombie attempts. The - // remaining tasks to compute are already active in the non-zombie attempt. - assert( - taskScheduler.resourceOffers(IndexedSeq(WorkerOffer("exec-1", "host-1", 1))).flatten.isEmpty) - - val remainingTasks = finalAttemptLaunchedPartitions.toIndexedSeq.sorted - - // finally, if we finish the remaining partitions from a mix of tasksets, all attempts should be - // marked as zombie. - // for each of the remaining tasks, find the tasksets with an active copy of the task, and - // finish the task. - remainingTasks.foreach { partition => - val tsm = if (partition == 0) { - // we failed this task on both zombie attempts, this one is only present in the latest - // taskset - finalTsm - } else { - // should be active in every taskset. We choose a zombie taskset just to make sure that - // we transition the active taskset correctly even if the final completion comes - // from a zombie. - zombieAttempts(partition % 2) - } - completeTaskSuccessfully(tsm, partition) - } - - assert(finalTsm.isZombie) - - // no taskset has completed all of its tasks, so no updates to the blacklist tracker yet - verify(blacklist, never).updateBlacklistForSuccessfulTaskSet(anyInt(), anyInt(), any()) - - // finally, lets complete all the tasks. We simulate failures in attempt 1, but everything - // else succeeds, to make sure we get the right updates to the blacklist in all cases. - (zombieAttempts ++ Seq(finalTsm)).foreach { tsm => - val stageAttempt = tsm.taskSet.stageAttemptId - tsm.runningTasksSet.foreach { index => - if (stageAttempt == 1) { - tsm.handleFailedTask(tsm.taskInfos(index).taskId, TaskState.FAILED, TaskResultLost) - } else { - val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq()) - tsm.handleSuccessfulTask(tsm.taskInfos(index).taskId, result) - } - } - - // we update the blacklist for the stage attempts with all successful tasks. Even though - // some tasksets had failures, we still consider them all successful from a blacklisting - // perspective, as the failures weren't from a problem w/ the tasks themselves. - verify(blacklist).updateBlacklistForSuccessfulTaskSet(meq(0), meq(stageAttempt), any()) - } - } - test("don't schedule for a barrier taskSet if available slots are less than pending tasks") { val taskCpus = 2 val taskScheduler = setupSchedulerWithMaster( diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 79160d05b3e60..0666bc335abac 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -1372,7 +1372,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(taskOption4.get.addedJars === addedJarsMidTaskSet) } - test("[SPARK-24677] Avoid NoSuchElementException from MedianHeap") { + test("SPARK-24677: Avoid NoSuchElementException from MedianHeap") { val conf = new SparkConf().set(config.SPECULATION_ENABLED, true) sc = new SparkContext("local", "test", conf) // Set the speculation multiplier to be 0 so speculative tasks are launched immediately @@ -1386,39 +1386,17 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val dagScheduler = new FakeDAGScheduler(sc, sched) sched.setDAGScheduler(dagScheduler) - val taskSet1 = FakeTask.createTaskSet(10) - val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet1.tasks.map { task => - task.metrics.internalAccums - } + val taskSet = FakeTask.createTaskSet(10) - sched.submitTasks(taskSet1) + sched.submitTasks(taskSet) sched.resourceOffers( - (0 until 10).map { idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }) - - val taskSetManager1 = sched.taskSetManagerForAttempt(0, 0).get - - // fail fetch - taskSetManager1.handleFailedTask( - taskSetManager1.taskAttempts.head.head.taskId, TaskState.FAILED, - FetchFailed(null, 0, 0, 0, "fetch failed")) - - assert(taskSetManager1.isZombie) - assert(taskSetManager1.runningTasks === 9) - - val taskSet2 = FakeTask.createTaskSet(10, stageAttemptId = 1) - sched.submitTasks(taskSet2) - sched.resourceOffers( - (11 until 20).map { idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }) - - // Complete the 2 tasks and leave 8 task in running - for (id <- Set(0, 1)) { - taskSetManager1.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id))) - assert(sched.endedTasks(id) === Success) - } + (0 until 8).map { idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }) - val taskSetManager2 = sched.taskSetManagerForAttempt(0, 1).get - assert(!taskSetManager2.successfulTaskDurations.isEmpty()) - taskSetManager2.checkSpeculatableTasks(0) + val taskSetManager = sched.taskSetManagerForAttempt(0, 0).get + assert(taskSetManager.runningTasks === 8) + taskSetManager.markPartitionCompleted(8, 0) + assert(!taskSetManager.successfulTaskDurations.isEmpty()) + taskSetManager.checkSpeculatableTasks(0) }