From baf46e19829f9693ceb1a54457b4e1c3602ba560 Mon Sep 17 00:00:00 2001
From: Kay Ousterhout
Date: Tue, 7 Jul 2015 18:05:38 -0700
Subject: [PATCH] Index active task sets by stage Id rather than by task set id

---
 .../spark/scheduler/TaskSchedulerImpl.scala   | 49 +++++++------------
 .../CoarseGrainedSchedulerBackend.scala       |  4 +-
 .../scheduler/TaskSchedulerImplSuite.scala    |  4 +-
 3 files changed, 23 insertions(+), 34 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 4eebff8dbb516..0a89761108726 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -75,10 +75,9 @@ private[spark] class TaskSchedulerImpl(
 
   // TaskSetManagers are not thread safe, so any access to one should be synchronized
   // on this class.
-  val activeTaskSets = new HashMap[String, TaskSetManager]
-  val taskSetsByStage = new HashMap[Int, HashMap[Int, TaskSetManager]]
+  val stageIdToActiveTaskSet = new HashMap[Int, TaskSetManager]
 
-  val taskIdToTaskSetId = new HashMap[Long, String]
+  val taskIdToStageId = new HashMap[Long, Int]
   val taskIdToExecutorId = new HashMap[Long, String]
 
   @volatile private var hasReceivedTask = false
@@ -163,17 +162,13 @@ private[spark] class TaskSchedulerImpl(
     logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
     this.synchronized {
       val manager = createTaskSetManager(taskSet, maxTaskFailures)
-      activeTaskSets(taskSet.id) = manager
-      val stage = taskSet.stageId
-      val stageTaskSets = taskSetsByStage.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
-      stageTaskSets(taskSet.attempt) = manager
-      val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
-        ts.taskSet != taskSet && !ts.isZombie
-      }
-      if (conflictingTaskSet) {
-        throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
-          s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
+      val stageId = taskSet.stageId
+      // Only a live (non-zombie) manager already registered for this stage is a conflict.
+      stageIdToActiveTaskSet.get(stageId).filterNot(_.isZombie).foreach { activeTaskSet =>
+        throw new IllegalStateException(
+          s"Active taskSet with id already exists for stage $stageId: ${activeTaskSet.taskSet.id}")
       }
+      stageIdToActiveTaskSet(stageId) = manager
       schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
 
       if (!isLocal && !hasReceivedTask) {
@@ -203,7 +198,7 @@ private[spark] class TaskSchedulerImpl(
 
   override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
     logInfo("Cancelling stage " + stageId)
-    activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
+    stageIdToActiveTaskSet.get(stageId).foreach { tsm =>
       // There are two possible cases here:
       // 1. The task set manager has been created and some tasks have been scheduled.
       //    In this case, send a kill signal to the executors to kill the task and then abort
@@ -225,13 +220,7 @@ private[spark] class TaskSchedulerImpl(
    * cleaned up.
    */
   def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
-    activeTaskSets -= manager.taskSet.id
-    taskSetsByStage.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
-      taskSetsForStage -= manager.taskSet.attempt
-      if (taskSetsForStage.isEmpty) {
-        taskSetsByStage -= manager.taskSet.stageId
-      }
-    }
+    stageIdToActiveTaskSet -= manager.stageId
     manager.parent.removeSchedulable(manager)
     logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
       .format(manager.taskSet.id, manager.parent.name))
@@ -252,7 +241,7 @@ private[spark] class TaskSchedulerImpl(
           for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
             tasks(i) += task
             val tid = task.taskId
-            taskIdToTaskSetId(tid) = taskSet.taskSet.id
+            taskIdToStageId(tid) = taskSet.taskSet.stageId
             taskIdToExecutorId(tid) = execId
             executorsByHost(host) += execId
             availableCpus(i) -= CPUS_PER_TASK
@@ -336,13 +325,13 @@ private[spark] class TaskSchedulerImpl(
             failedExecutor = Some(execId)
           }
         }
-        taskIdToTaskSetId.get(tid) match {
-          case Some(taskSetId) =>
+        taskIdToStageId.get(tid) match {
+          case Some(stageId) =>
             if (TaskState.isFinished(state)) {
-              taskIdToTaskSetId.remove(tid)
+              taskIdToStageId.remove(tid)
               taskIdToExecutorId.remove(tid)
             }
-            activeTaskSets.get(taskSetId).foreach { taskSet =>
+            stageIdToActiveTaskSet.get(stageId).foreach { taskSet =>
               if (state == TaskState.FINISHED) {
                 taskSet.removeRunningTask(tid)
                 taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
@@ -380,8 +369,8 @@ private[spark] class TaskSchedulerImpl(
 
     val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
       taskMetrics.flatMap { case (id, metrics) =>
-        taskIdToTaskSetId.get(id)
-          .flatMap(activeTaskSets.get)
+        taskIdToStageId.get(id)
+          .flatMap(stageIdToActiveTaskSet.get)
           .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
       }
     }
@@ -414,9 +403,9 @@ private[spark] class TaskSchedulerImpl(
 
   def error(message: String) {
     synchronized {
-      if (activeTaskSets.nonEmpty) {
+      if (stageIdToActiveTaskSet.nonEmpty) {
         // Have each task set throw a SparkException with the error
-        for ((taskSetId, manager) <- activeTaskSets) {
+        for ((_, manager) <- stageIdToActiveTaskSet) {
           try {
             manager.abort(message)
           } catch {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 7c7f70d8a193b..f2bd76aaef8ee 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -191,8 +191,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
       for (task <- tasks.flatten) {
         val serializedTask = ser.serialize(task)
         if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
-          val taskSetId = scheduler.taskIdToTaskSetId(task.taskId)
-          scheduler.activeTaskSets.get(taskSetId).foreach { taskSet =>
+          val stageId = scheduler.taskIdToStageId(task.taskId)
+          scheduler.stageIdToActiveTaskSet.get(stageId).foreach { taskSet =>
             try {
               var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
                 "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 55be409afcf31..48eda6741b8d6 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -144,11 +144,11 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
     intercept[IllegalStateException] { taskScheduler.submitTasks(attempt2) }
 
     // OK to submit multiple if previous attempts are all zombie
-    taskScheduler.activeTaskSets(attempt1.id).isZombie = true
+    taskScheduler.stageIdToActiveTaskSet(attempt1.stageId).isZombie = true
     taskScheduler.submitTasks(attempt2)
     val attempt3 = new TaskSet(Array(new FakeTask(0)), 0, 2, 0, null)
     intercept[IllegalStateException] { taskScheduler.submitTasks(attempt3) }
-    taskScheduler.activeTaskSets(attempt2.id).isZombie = true
+    taskScheduler.stageIdToActiveTaskSet(attempt2.stageId).isZombie = true
     taskScheduler.submitTasks(attempt3)
   }
 