
Commit 80e2568

[SPARK-8103][core] DAGScheduler should not submit multiple concurrent attempts for a stage

https://issues.apache.org/jira/browse/SPARK-8103

cc kayousterhout (thanks for the extra test case)

Author: Imran Rashid <[email protected]>
Author: Kay Ousterhout <[email protected]>
Author: Imran Rashid <[email protected]>

Closes apache#6750 from squito/SPARK-8103 and squashes the following commits:

fb3acfc [Imran Rashid] fix log msg
e01b7aa [Imran Rashid] fix some comments, style
584acd4 [Imran Rashid] simplify going from taskId to taskSetMgr
e43ac25 [Imran Rashid] Merge branch 'master' into SPARK-8103
6bc23af [Imran Rashid] update log msg
4470fa1 [Imran Rashid] rename
c04707e [Imran Rashid] style
88b61cc [Imran Rashid] add tests to make sure that TaskSchedulerImpl schedules correctly with zombie attempts
d7f1ef2 [Imran Rashid] get rid of activeTaskSets
a21c8b5 [Imran Rashid] Merge branch 'master' into SPARK-8103
906d626 [Imran Rashid] fix merge
109900e [Imran Rashid] Merge branch 'master' into SPARK-8103
c0d4d90 [Imran Rashid] Revert "Index active task sets by stage Id rather than by task set id"
f025154 [Imran Rashid] Merge pull request #2 from kayousterhout/imran_SPARK-8103
baf46e1 [Kay Ousterhout] Index active task sets by stage Id rather than by task set id
19685bb [Imran Rashid] switch to using latestInfo.attemptId, and add comments
a5f7c8c [Imran Rashid] remove comment for reviewers
227b40d [Imran Rashid] style
517b6e5 [Imran Rashid] get rid of SparkIllegalStateException
b2faef5 [Imran Rashid] faster check for conflicting task sets
6542b42 [Imran Rashid] remove extra stageAttemptId
ada7726 [Imran Rashid] reviewer feedback
d8eb202 [Imran Rashid] Merge branch 'master' into SPARK-8103
46bc26a [Imran Rashid] more cleanup of debug garbage
cb245da [Imran Rashid] finally found the issue ... clean up debug stuff
8c29707 [Imran Rashid] Merge branch 'master' into SPARK-8103
89a59b6 [Imran Rashid] more printlns ...
9601b47 [Imran Rashid] more debug printlns
ecb4e7d [Imran Rashid] debugging printlns
b6bc248 [Imran Rashid] style
55f4a94 [Imran Rashid] get rid of more random test case since kays tests are clearer
7021d28 [Imran Rashid] update test since listenerBus.waitUntilEmpty now throws an exception instead of returning a boolean
883fe49 [Kay Ousterhout] Unit tests for concurrent stages issue
6e14683 [Imran Rashid] unit test just to make sure we fail fast on concurrent attempts
06a0af6 [Imran Rashid] ignore for jenkins
c443def [Imran Rashid] better fix and simpler test case
28d70aa [Imran Rashid] wip on getting a better test case ...
a9bf31f [Imran Rashid] wip
squito authored and kayousterhout committed Jul 20, 2015
1 parent c6fe9b4 commit 80e2568
Showing 13 changed files with 383 additions and 86 deletions.
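Before the per-file diffs, here is a minimal standalone sketch of the fail-fast guard this commit adds to TaskSchedulerImpl.submitTasks (the type names below are simplified stand-ins, not the actual Spark classes): task sets are now indexed by stage id and stage attempt id, and registering a second non-zombie task set for the same stage raises an IllegalStateException.

import scala.collection.mutable.HashMap

// Simplified stand-ins for the real TaskSet / TaskSetManager (illustration only).
case class SketchTaskSet(stageId: Int, stageAttemptId: Int) {
  val id: String = s"$stageId.$stageAttemptId"
}
class SketchTaskSetManager(val taskSet: SketchTaskSet, var isZombie: Boolean = false)

object SubmitGuardSketch {
  // Mirrors taskSetsByStageIdAndAttempt: stageId -> (stageAttemptId -> manager).
  private val taskSetsByStageIdAndAttempt =
    new HashMap[Int, HashMap[Int, SketchTaskSetManager]]

  def submitTasks(taskSet: SketchTaskSet): SketchTaskSetManager = synchronized {
    val manager = new SketchTaskSetManager(taskSet)
    val stageTaskSets = taskSetsByStageIdAndAttempt
      .getOrElseUpdate(taskSet.stageId, new HashMap[Int, SketchTaskSetManager])
    stageTaskSets(taskSet.stageAttemptId) = manager
    // More than one live (non-zombie) task set for one stage means the DAGScheduler
    // submitted concurrent attempts, which this change treats as a bug.
    val conflicting = stageTaskSets.exists { case (_, ts) =>
      ts.taskSet != taskSet && !ts.isZombie
    }
    if (conflicting) {
      throw new IllegalStateException(
        s"more than one active taskSet for stage ${taskSet.stageId}: " +
          stageTaskSets.values.map(_.taskSet.id).mkString(","))
    }
    manager
  }
}

In this sketch, a second submitTasks call for a stage succeeds only after the earlier attempt's manager has been marked isZombie = true, which is the condition the new conflictingTaskSet check in TaskSchedulerImpl enforces.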
78 changes: 44 additions & 34 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -857,7 +857,6 @@ class DAGScheduler(
// Get our pending tasks and remember them in our pendingTasks entry
stage.pendingTasks.clear()


// First figure out the indexes of partition ids to compute.
val partitionsToCompute: Seq[Int] = {
stage match {
@@ -918,7 +917,7 @@
partitionsToCompute.map { id =>
val locs = getPreferredLocs(stage.rdd, id)
val part = stage.rdd.partitions(id)
new ShuffleMapTask(stage.id, taskBinary, part, locs)
new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs)
}

case stage: ResultStage =>
@@ -927,7 +926,7 @@
val p: Int = job.partitions(id)
val part = stage.rdd.partitions(p)
val locs = getPreferredLocs(stage.rdd, p)
new ResultTask(stage.id, taskBinary, part, locs, id)
new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id)
}
}
} catch {
@@ -1069,10 +1068,11 @@
val execId = status.location.executorId
logDebug("ShuffleMapTask finished on " + execId)
if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
} else {
shuffleStage.addOutputLoc(smt.partitionId, status)
}

if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) {
markStageAsFinished(shuffleStage)
logInfo("looking for newly runnable stages")
@@ -1132,38 +1132,48 @@
val failedStage = stageIdToStage(task.stageId)
val mapStage = shuffleToMapStage(shuffleId)

// It is likely that we receive multiple FetchFailed for a single stage (because we have
// multiple tasks running concurrently on different executors). In that case, it is possible
// the fetch failure has already been handled by the scheduler.
if (runningStages.contains(failedStage)) {
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
s"due to a fetch failure from $mapStage (${mapStage.name})")
markStageAsFinished(failedStage, Some(failureMessage))
}
if (failedStage.latestInfo.attemptId != task.stageAttemptId) {
logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
s"(attempt ID ${failedStage.latestInfo.attemptId}) running")
} else {

if (disallowStageRetryForTest) {
abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
} else if (failedStages.isEmpty) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
// in that case the event will already have been scheduled.
// TODO: Cancel running tasks in the stage
logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
s"$failedStage (${failedStage.name}) due to fetch failure")
messageScheduler.schedule(new Runnable {
override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
}, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
}
failedStages += failedStage
failedStages += mapStage
// Mark the map whose fetch failed as broken in the map stage
if (mapId != -1) {
mapStage.removeOutputLoc(mapId, bmAddress)
mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
}
// It is likely that we receive multiple FetchFailed for a single stage (because we have
// multiple tasks running concurrently on different executors). In that case, it is
// possible the fetch failure has already been handled by the scheduler.
if (runningStages.contains(failedStage)) {
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
s"due to a fetch failure from $mapStage (${mapStage.name})")
markStageAsFinished(failedStage, Some(failureMessage))
} else {
logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " +
s"longer running")
}

if (disallowStageRetryForTest) {
abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
} else if (failedStages.isEmpty) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
// in that case the event will already have been scheduled.
// TODO: Cancel running tasks in the stage
logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
s"$failedStage (${failedStage.name}) due to fetch failure")
messageScheduler.schedule(new Runnable {
override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
}, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
}
failedStages += failedStage
failedStages += mapStage
// Mark the map whose fetch failed as broken in the map stage
if (mapId != -1) {
mapStage.removeOutputLoc(mapId, bmAddress)
mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
}

// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
}
}

case commitDenied: TaskCommitDenied =>
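To restate the new control flow above in isolation (a stripped-down sketch with hypothetical types, not the actual DAGScheduler code): a FetchFailed is acted on only when it comes from the stage's latest attempt; failures reported by older, superseded attempts are logged and dropped.

// Hypothetical, stripped-down shapes for illustration only.
case class StageInfo(attemptId: Int)
case class FailedStage(latestInfo: StageInfo)

// A FetchFailed reported with stage attempt `taskStageAttemptId` is acted on only
// if it matches the stage's most recent attempt; otherwise a newer attempt is
// already running and the failure is ignored.
def shouldActOnFetchFailure(failedStage: FailedStage, taskStageAttemptId: Int): Boolean =
  failedStage.latestInfo.attemptId == taskStageAttemptId

When the check passes, the diff above marks the stage as finished (if it is still running), queues both the failed stage and its map stage for resubmission, and unregisters the broken map output.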
core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD
*/
private[spark] class ResultTask[T, U](
stageId: Int,
stageAttemptId: Int,
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient locs: Seq[TaskLocation],
val outputId: Int)
extends Task[U](stageId, partition.index) with Serializable {
extends Task[U](stageId, stageAttemptId, partition.index) with Serializable {

@transient private[this] val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -40,14 +40,15 @@ import org.apache.spark.shuffle.ShuffleWriter
*/
private[spark] class ShuffleMapTask(
stageId: Int,
stageAttemptId: Int,
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient private var locs: Seq[TaskLocation])
extends Task[MapStatus](stageId, partition.index) with Logging {
extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging {

/** A constructor used only in test suites. This does not require passing in an RDD. */
def this(partitionId: Int) {
this(0, null, new Partition { override def index: Int = 0 }, null)
this(0, 0, null, new Partition { override def index: Int = 0 }, null)
}

@transient private val preferredLocs: Seq[TaskLocation] = {
5 changes: 4 additions & 1 deletion core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -43,7 +43,10 @@ import org.apache.spark.util.Utils
* @param stageId id of the stage this task belongs to
* @param partitionId index of the number in the RDD
*/
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
private[spark] abstract class Task[T](
val stageId: Int,
val stageAttemptId: Int,
var partitionId: Int) extends Serializable {

/**
* The key of the Map is the accumulator id and the value of the Map is the latest accumulator
core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -75,9 +75,9 @@ private[spark] class TaskSchedulerImpl(

// TaskSetManagers are not thread safe, so any access to one should be synchronized
// on this class.
val activeTaskSets = new HashMap[String, TaskSetManager]
private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]

val taskIdToTaskSetId = new HashMap[Long, String]
private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager]
val taskIdToExecutorId = new HashMap[Long, String]

@volatile private var hasReceivedTask = false
@@ -162,7 +162,17 @@
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = createTaskSetManager(taskSet, maxTaskFailures)
activeTaskSets(taskSet.id) = manager
val stage = taskSet.stageId
val stageTaskSets =
taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
stageTaskSets(taskSet.stageAttemptId) = manager
val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
ts.taskSet != taskSet && !ts.isZombie
}
if (conflictingTaskSet) {
throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
}
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)

if (!isLocal && !hasReceivedTask) {
@@ -192,19 +202,21 @@

override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
// There are two possible cases here:
// 1. The task set manager has been created and some tasks have been scheduled.
// In this case, send a kill signal to the executors to kill the task and then abort
// the stage.
// 2. The task set manager has been created but no tasks has been scheduled. In this case,
// simply abort the stage.
tsm.runningTasksSet.foreach { tid =>
val execId = taskIdToExecutorId(tid)
backend.killTask(tid, execId, interruptThread)
taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
attempts.foreach { case (_, tsm) =>
// There are two possible cases here:
// 1. The task set manager has been created and some tasks have been scheduled.
// In this case, send a kill signal to the executors to kill the task and then abort
// the stage.
// 2. The task set manager has been created but no tasks has been scheduled. In this case,
// simply abort the stage.
tsm.runningTasksSet.foreach { tid =>
val execId = taskIdToExecutorId(tid)
backend.killTask(tid, execId, interruptThread)
}
tsm.abort("Stage %s cancelled".format(stageId))
logInfo("Stage %d was cancelled".format(stageId))
}
tsm.abort("Stage %s cancelled".format(stageId))
logInfo("Stage %d was cancelled".format(stageId))
}
}

@@ -214,7 +226,12 @@
* cleaned up.
*/
def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
activeTaskSets -= manager.taskSet.id
taskSetsByStageIdAndAttempt.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
taskSetsForStage -= manager.taskSet.stageAttemptId
if (taskSetsForStage.isEmpty) {
taskSetsByStageIdAndAttempt -= manager.taskSet.stageId
}
}
manager.parent.removeSchedulable(manager)
logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
.format(manager.taskSet.id, manager.parent.name))
@@ -235,7 +252,7 @@
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
tasks(i) += task
val tid = task.taskId
taskIdToTaskSetId(tid) = taskSet.taskSet.id
taskIdToTaskSetManager(tid) = taskSet
taskIdToExecutorId(tid) = execId
executorsByHost(host) += execId
availableCpus(i) -= CPUS_PER_TASK
@@ -319,26 +336,24 @@
failedExecutor = Some(execId)
}
}
taskIdToTaskSetId.get(tid) match {
case Some(taskSetId) =>
taskIdToTaskSetManager.get(tid) match {
case Some(taskSet) =>
if (TaskState.isFinished(state)) {
taskIdToTaskSetId.remove(tid)
taskIdToTaskSetManager.remove(tid)
taskIdToExecutorId.remove(tid)
}
activeTaskSets.get(taskSetId).foreach { taskSet =>
if (state == TaskState.FINISHED) {
taskSet.removeRunningTask(tid)
taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
} else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
taskSet.removeRunningTask(tid)
taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
}
if (state == TaskState.FINISHED) {
taskSet.removeRunningTask(tid)
taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
} else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
taskSet.removeRunningTask(tid)
taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
}
case None =>
logError(
("Ignoring update with state %s for TID %s because its task set is gone (this is " +
"likely the result of receiving duplicate task finished status updates)")
.format(state, tid))
"likely the result of receiving duplicate task finished status updates)")
.format(state, tid))
}
} catch {
case e: Exception => logError("Exception in statusUpdate", e)
@@ -363,9 +378,9 @@

val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
taskMetrics.flatMap { case (id, metrics) =>
taskIdToTaskSetId.get(id)
.flatMap(activeTaskSets.get)
.map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
taskIdToTaskSetManager.get(id).map { taskSetMgr =>
(id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
}
}
}
dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
@@ -397,9 +412,12 @@

def error(message: String) {
synchronized {
if (activeTaskSets.nonEmpty) {
if (taskSetsByStageIdAndAttempt.nonEmpty) {
// Have each task set throw a SparkException with the error
for ((taskSetId, manager) <- activeTaskSets) {
for {
attempts <- taskSetsByStageIdAndAttempt.values
manager <- attempts.values
} {
try {
manager.abort(message)
} catch {
Expand Down Expand Up @@ -520,6 +538,17 @@ private[spark] class TaskSchedulerImpl(

override def applicationAttemptId(): Option[String] = backend.applicationAttemptId()

private[scheduler] def taskSetManagerForAttempt(
stageId: Int,
stageAttemptId: Int): Option[TaskSetManager] = {
for {
attempts <- taskSetsByStageIdAndAttempt.get(stageId)
manager <- attempts.get(stageAttemptId)
} yield {
manager
}
}

}


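For reference, the nested-map bookkeeping introduced above can be captured in a few lines (a standalone sketch with simplified names and types, not the actual TaskSchedulerImpl):

import scala.collection.mutable.HashMap

// A generic index keyed by stageId and then stageAttemptId, replacing the old
// flat activeTaskSets map keyed by task set id.
class AttemptIndex[M] {
  private val byStageAndAttempt = new HashMap[Int, HashMap[Int, M]]

  def put(stageId: Int, attemptId: Int, manager: M): Unit =
    byStageAndAttempt.getOrElseUpdate(stageId, new HashMap[Int, M])(attemptId) = manager

  // Mirrors taskSetManagerForAttempt: a for-comprehension over two Options.
  def get(stageId: Int, attemptId: Int): Option[M] =
    for {
      attempts <- byStageAndAttempt.get(stageId)
      manager <- attempts.get(attemptId)
    } yield manager

  // Mirrors taskSetFinished: drop the attempt, and drop the stage entry once it is empty.
  def remove(stageId: Int, attemptId: Int): Unit =
    byStageAndAttempt.get(stageId).foreach { attempts =>
      attempts -= attemptId
      if (attempts.isEmpty) byStageAndAttempt -= stageId
    }
}

Cancellation and the error path then iterate all attempts for a stage (or all stages) rather than scanning a flat map, as cancelTasks and error do in the diff above.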
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
@@ -26,10 +26,10 @@ import java.util.Properties
private[spark] class TaskSet(
val tasks: Array[Task[_]],
val stageId: Int,
val attempt: Int,
val stageAttemptId: Int,
val priority: Int,
val properties: Properties) {
val id: String = stageId + "." + attempt
val id: String = stageId + "." + stageAttemptId

override def toString: String = "TaskSet " + id
}
core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -191,15 +191,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
for (task <- tasks.flatten) {
val serializedTask = ser.serialize(task)
if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
val taskSetId = scheduler.taskIdToTaskSetId(task.taskId)
scheduler.activeTaskSets.get(taskSetId).foreach { taskSet =>
scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
try {
var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
"spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
"spark.akka.frameSize or using broadcast variables for large values."
msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize,
AkkaUtils.reservedSizeBytes)
taskSet.abort(msg)
taskSetMgr.abort(msg)
} catch {
case e: Exception => logError("Exception in error callback", e)
}