Skip to content

Commit

Permalink
Merge pull request #2 from kayousterhout/imran_SPARK-8103
Browse files: browse the repository at this point in the history
Index active task sets by stage Id rather than by task set id
  • Loading branch information
squito committed Jul 14, 2015
2 parents 19685bb + baf46e1 commit f025154
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,9 @@ private[spark] class TaskSchedulerImpl(

// TaskSetManagers are not thread-safe, so any access to one should be synchronized
// on this class.
val activeTaskSets = new HashMap[String, TaskSetManager]
val taskSetsByStage = new HashMap[Int, HashMap[Int, TaskSetManager]]
val stageIdToActiveTaskSet = new HashMap[Int, TaskSetManager]

val taskIdToTaskSetId = new HashMap[Long, String]
val taskIdToStageId = new HashMap[Long, Int]
val taskIdToExecutorId = new HashMap[Long, String]

@volatile private var hasReceivedTask = false
Expand Down Expand Up @@ -163,17 +162,13 @@ private[spark] class TaskSchedulerImpl(
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = createTaskSetManager(taskSet, maxTaskFailures)
activeTaskSets(taskSet.id) = manager
val stage = taskSet.stageId
val stageTaskSets = taskSetsByStage.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
stageTaskSets(taskSet.attempt) = manager
val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
ts.taskSet != taskSet && !ts.isZombie
}
if (conflictingTaskSet) {
throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
stageIdToActiveTaskSet(taskSet.stageId) = manager
val stageId = taskSet.stageId
stageIdToActiveTaskSet.get(stageId).map { activeTaskSet =>
throw new IllegalStateException(
s"Active taskSet with id already exists for stage $stageId: ${activeTaskSet.taskSet.id}")
}
stageIdToActiveTaskSet(stageId) = manager
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)

if (!isLocal && !hasReceivedTask) {
Expand Down Expand Up @@ -203,7 +198,7 @@ private[spark] class TaskSchedulerImpl(

override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
stageIdToActiveTaskSet.get(stageId).map {tsm =>
// There are two possible cases here:
// 1. The task set manager has been created and some tasks have been scheduled.
// In this case, send a kill signal to the executors to kill the task and then abort
Expand All @@ -225,13 +220,7 @@ private[spark] class TaskSchedulerImpl(
* cleaned up.
*/
def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
activeTaskSets -= manager.taskSet.id
taskSetsByStage.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
taskSetsForStage -= manager.taskSet.attempt
if (taskSetsForStage.isEmpty) {
taskSetsByStage -= manager.taskSet.stageId
}
}
stageIdToActiveTaskSet -= manager.stageId
manager.parent.removeSchedulable(manager)
logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
.format(manager.taskSet.id, manager.parent.name))
Expand All @@ -252,7 +241,7 @@ private[spark] class TaskSchedulerImpl(
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
tasks(i) += task
val tid = task.taskId
taskIdToTaskSetId(tid) = taskSet.taskSet.id
taskIdToStageId(tid) = taskSet.taskSet.stageId
taskIdToExecutorId(tid) = execId
executorsByHost(host) += execId
availableCpus(i) -= CPUS_PER_TASK
Expand Down Expand Up @@ -336,13 +325,13 @@ private[spark] class TaskSchedulerImpl(
failedExecutor = Some(execId)
}
}
taskIdToTaskSetId.get(tid) match {
case Some(taskSetId) =>
taskIdToStageId.get(tid) match {
case Some(stageId) =>
if (TaskState.isFinished(state)) {
taskIdToTaskSetId.remove(tid)
taskIdToStageId.remove(tid)
taskIdToExecutorId.remove(tid)
}
activeTaskSets.get(taskSetId).foreach { taskSet =>
stageIdToActiveTaskSet.get(stageId).foreach { taskSet =>
if (state == TaskState.FINISHED) {
taskSet.removeRunningTask(tid)
taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
Expand Down Expand Up @@ -380,8 +369,8 @@ private[spark] class TaskSchedulerImpl(

val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
taskMetrics.flatMap { case (id, metrics) =>
taskIdToTaskSetId.get(id)
.flatMap(activeTaskSets.get)
taskIdToStageId.get(id)
.flatMap(stageIdToActiveTaskSet.get)
.map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
}
}
Expand Down Expand Up @@ -414,9 +403,9 @@ private[spark] class TaskSchedulerImpl(

def error(message: String) {
synchronized {
if (activeTaskSets.nonEmpty) {
if (stageIdToActiveTaskSet.nonEmpty) {
// Have each task set throw a SparkException with the error
for ((taskSetId, manager) <- activeTaskSets) {
for ((_, manager) <- stageIdToActiveTaskSet) {
try {
manager.abort(message)
} catch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
for (task <- tasks.flatten) {
val serializedTask = ser.serialize(task)
if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
val taskSetId = scheduler.taskIdToTaskSetId(task.taskId)
scheduler.activeTaskSets.get(taskSetId).foreach { taskSet =>
val taskSetId = scheduler.taskIdToStageId(task.taskId)
scheduler.stageIdToActiveTaskSet.get(taskSetId).foreach { taskSet =>
try {
var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
"spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,11 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
intercept[IllegalStateException] { taskScheduler.submitTasks(attempt2) }

// OK to submit multiple if previous attempts are all zombie
taskScheduler.activeTaskSets(attempt1.id).isZombie = true
taskScheduler.stageIdToActiveTaskSet(attempt1.stageId).isZombie = true
taskScheduler.submitTasks(attempt2)
val attempt3 = new TaskSet(Array(new FakeTask(0)), 0, 2, 0, null)
intercept[IllegalStateException] { taskScheduler.submitTasks(attempt3) }
taskScheduler.activeTaskSets(attempt2.id).isZombie = true
taskScheduler.stageIdToActiveTaskSet(attempt2.stageId).isZombie = true
taskScheduler.submitTasks(attempt3)
}

Expand Down

0 comments on commit f025154

Please sign in to comment.