Merge pull request alteryx#90 from pwendell/master
SPARK-940: Do not directly pass Stage objects to SparkListener.

This patch updates the SparkListener interface to pass StageInfo objects rather than Spark's internal Stage objects directly. The reason for this change is explained in detail in SPARK-940.
(cherry picked from commit c404adb)

Signed-off-by: Patrick Wendell <[email protected]>

pwendell authored and markhamstra committed Oct 22, 2013
1 parent 71d8b99 commit dfee409
Showing 12 changed files with 166 additions and 118 deletions.
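Note on the interface change before diving into the diff: listeners now receive an immutable StageInfo snapshot instead of the scheduler-internal Stage. A minimal sketch of what that looks like to an implementer (StageTimingListener is a hypothetical listener written against the API in this diff, not part of the patch):

import org.apache.spark.scheduler.{SparkListener, SparkListenerStageSubmitted, StageCompleted}

class StageTimingListener extends SparkListener {
  // The event's stage field is a StageInfo: stageId and numTasks are plain
  // values, with no reference back into the DAGScheduler's mutable Stage.
  override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
    println("Submitted stage " + stageSubmitted.stage.stageId +
      " with " + stageSubmitted.stage.numTasks + " tasks")
  }

  override def onStageCompleted(stageCompleted: StageCompleted) {
    val info = stageCompleted.stage
    // Per the DAGScheduler changes below, both timestamps are set before this event is posted.
    for (s <- info.submissionTime; c <- info.completionTime) {
      println("Stage " + info.stageId + " (" + info.name + ") took " + (c - s) + " ms")
    }
  }
}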
26 changes: 13 additions & 13 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -184,7 +184,7 @@ class DAGScheduler(
   shuffleToMapStage.get(shuffleDep.shuffleId) match {
     case Some(stage) => stage
     case None =>
-      val stage = newStage(shuffleDep.rdd, Some(shuffleDep), jobId)
+      val stage = newStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, Some(shuffleDep), jobId)
       shuffleToMapStage(shuffleDep.shuffleId) = stage
       stage
   }
@@ -197,6 +197,7 @@ class DAGScheduler(
    */
   private def newStage(
       rdd: RDD[_],
+      numTasks: Int,
       shuffleDep: Option[ShuffleDependency[_,_]],
       jobId: Int,
       callSite: Option[String] = None)
@@ -209,9 +210,10 @@ class DAGScheduler(
       mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size)
     }
     val id = nextStageId.getAndIncrement()
-    val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
+    val stage =
+      new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
     stageIdToStage(id) = stage
-    stageToInfos(stage) = StageInfo(stage)
+    stageToInfos(stage) = new StageInfo(stage)
     val stageIdSet = jobIdToStageIds.getOrElseUpdate(jobId, new HashSet)
     stageIdSet += id
     stage
@@ -365,7 +367,7 @@ class DAGScheduler(
   private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
     event match {
       case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
-        val finalStage = newStage(rdd, None, jobId, Some(callSite))
+        val finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite))
         val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
         clearCacheLocs()
         logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +
@@ -588,7 +590,7 @@ class DAGScheduler(

     // must be run listener before possible NotSerializableException
     // should be "StageSubmitted" first and then "JobEnded"
-    listenerBus.post(SparkListenerStageSubmitted(stage, tasks.size, properties))
+    listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties))

     if (tasks.size > 0) {
       // Preemptively serialize a task to make sure it can be serialized. We are catching this
@@ -609,9 +611,7 @@ class DAGScheduler(
logDebug("New pending tasks: " + myPending)
taskSched.submitTasks(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
if (!stage.submissionTime.isDefined) {
stage.submissionTime = Some(System.currentTimeMillis())
}
stageToInfos(stage).submissionTime = Some(System.currentTimeMillis())
} else {
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
@@ -631,12 +631,12 @@ class DAGScheduler(
     })

     def markStageAsFinished(stage: Stage) = {
-      val serviceTime = stage.submissionTime match {
+      val serviceTime = stageToInfos(stage).submissionTime match {
         case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0)
-        case _ => "Unkown"
+        case _ => "Unknown"
       }
       logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
-      stage.completionTime = Some(System.currentTimeMillis)
+      stageToInfos(stage).completionTime = Some(System.currentTimeMillis())
       listenerBus.post(StageCompleted(stageToInfos(stage)))
       running -= stage
     }
@@ -810,10 +810,10 @@ class DAGScheduler(
     val dependentStages = if (failedStage.isDefined)
       resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage.get)).toSeq else Seq()
     failedStage.foreach {stage =>
-      stage.completionTime = Some(System.currentTimeMillis())
+      stageToInfos(stage).completionTime = Some(System.currentTimeMillis())
       for (resultStage <- dependentStages) {
         val job = resultStageToJob(resultStage)
-        val error = new SparkException("Job failed: " + reason)
+        val error = new SparkException("Job aborted: " + reason)
         job.listener.jobFailed(error)
         listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, failedStage)))
         idToActiveJob -= resultStage.jobId
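The DAGScheduler hunks above also change where stage service time is read from: markStageAsFinished now computes it from stageToInfos(stage).submissionTime. A self-contained sketch of the same formatting logic (ServiceTime is an invented helper that mirrors the code above, not part of the patch):

object ServiceTime {
  // Seconds with millisecond precision, or "Unknown" when the stage was
  // never marked as submitted -- the same formatting as markStageAsFinished.
  def describe(submissionTime: Option[Long], now: Long): String = submissionTime match {
    case Some(t) => "%.03f".format((now - t) / 1000.0)
    case None => "Unknown"
  }
}

For example, ServiceTime.describe(Some(now - 1500), now) yields "1.500".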
core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -207,13 +207,13 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
   }

   override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
-    stageLogInfo(stageSubmitted.stage.id, "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
-      stageSubmitted.stage.id, stageSubmitted.taskSize))
+    stageLogInfo(stageSubmitted.stage.stageId, "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
+      stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
   }

   override def onStageCompleted(stageCompleted: StageCompleted) {
-    stageLogInfo(stageCompleted.stageInfo.stage.id,
-      "STAGE_ID=%d STATUS=COMPLETED".format(stageCompleted.stageInfo.stage.id))
+    stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format(
+      stageCompleted.stage.stageId))
   }

   override def onTaskStart(taskStart: SparkListenerTaskStart) { }
core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -24,10 +24,10 @@ import org.apache.spark.executor.TaskMetrics

 sealed trait SparkListenerEvents

-case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int, properties: Properties)
+case class SparkListenerStageSubmitted(stage: StageInfo, properties: Properties)
      extends SparkListenerEvents

-case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents
+case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents

 case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents

@@ -80,7 +80,7 @@ class StatsReportListener extends SparkListener with Logging {
   override def onStageCompleted(stageCompleted: StageCompleted) {
     import org.apache.spark.scheduler.StatsReportListener._
     implicit val sc = stageCompleted
-    this.logInfo("Finished stage: " + stageCompleted.stageInfo)
+    this.logInfo("Finished stage: " + stageCompleted.stage)
     showMillisDistribution("task runtime:", (info, _) => Some(info.duration))

     //shuffle write
@@ -93,7 +93,7 @@ class StatsReportListener extends SparkListener with Logging {

     //runtime breakdown

-    val runtimePcts = stageCompleted.stageInfo.taskInfos.map{
+    val runtimePcts = stageCompleted.stage.taskInfos.map{
       case (info, metrics) => RuntimePercentage(info.duration, metrics)
     }
     showDistribution("executor (non-fetch) time pct: ", Distribution(runtimePcts.map{_.executorPct * 100}), "%2.0f %%")
@@ -111,7 +111,7 @@ object StatsReportListener extends Logging {
   val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%"

   def extractDoubleDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Double]): Option[Distribution] = {
-    Distribution(stage.stageInfo.taskInfos.flatMap{
+    Distribution(stage.stage.taskInfos.flatMap {
       case ((info,metric)) => getMetric(info, metric)})
   }

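With the events reshaped, the per-stage task count no longer travels as a separate taskSize argument on SparkListenerStageSubmitted; it rides on the StageInfo itself. A hedged sketch of consuming the new event shapes (ListenerEventSketch is invented; the case classes are the ones defined above):

import org.apache.spark.scheduler.{SparkListenerEvents, SparkListenerStageSubmitted, StageCompleted}

object ListenerEventSketch {
  def describe(event: SparkListenerEvents): String = event match {
    case SparkListenerStageSubmitted(info, _) =>
      // Previously: SparkListenerStageSubmitted(stage, taskSize, properties)
      "stage " + info.stageId + " submitted with " + info.numTasks + " tasks"
    case StageCompleted(info) =>
      "stage " + info.stageId + " completed"
    case _ => "other event"
  }
}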
6 changes: 1 addition & 5 deletions core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -39,6 +39,7 @@ import org.apache.spark.storage.BlockManagerId
 private[spark] class Stage(
     val id: Int,
     val rdd: RDD[_],
+    val numTasks: Int,
     val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage
     val parents: List[Stage],
     val jobId: Int,
@@ -49,11 +50,6 @@ private[spark] class Stage(
   val numPartitions = rdd.partitions.size
   val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
   var numAvailableOutputs = 0

-  /** When first task was submitted to scheduler. */
-  var submissionTime: Option[Long] = None
-  var completionTime: Option[Long] = None
-
   private var nextAttemptId = 0

   def isAvailable: Boolean = {
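Why thread numTasks through instead of reusing numPartitions: a result stage may compute only a subset of its RDD's partitions (the JobSubmitted hunk above passes partitions.size for the partitions the job actually requested), so the two counts can differ. A hedged illustration, assuming a live SparkContext sc as in the Spark shell:

val rdd = sc.parallelize(1 to 100, 10)  // an RDD with 10 partitions
rdd.first()  // runs a job over a single partition, so the result stage is
             // created with numTasks = 1 while numPartitions remains 10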
15 changes: 12 additions & 3 deletions core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -21,9 +21,18 @@ import scala.collection._

 import org.apache.spark.executor.TaskMetrics

-case class StageInfo(
-    val stage: Stage,
+class StageInfo(
+    stage: Stage,
     val taskInfos: mutable.Buffer[(TaskInfo, TaskMetrics)] = mutable.Buffer[(TaskInfo, TaskMetrics)]()
 ) {
-  override def toString = stage.rdd.toString
+  val stageId = stage.id
+  /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */
+  var submissionTime: Option[Long] = None
+  var completionTime: Option[Long] = None
+  val rddName = stage.rdd.name
+  val name = stage.name
+  val numPartitions = stage.numPartitions
+  val numTasks = stage.numTasks
+
+  override def toString = rddName
 }
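StageInfo now copies stageId, rddName, name, numPartitions, and numTasks out of the Stage when it is constructed, so listeners hold a plain snapshot plus two mutable timing fields rather than a live reference into the scheduler. The design choice in isolation (LiveStage and Snapshot are invented stand-ins in Scala REPL style; they are not Spark types):

class LiveStage(var name: String)  // stand-in for the mutable, scheduler-internal Stage
class Snapshot(stage: LiveStage) {
  val name = stage.name            // copied once, at construction time
}

val live = new LiveStage("map at <console>:1")
val snap = new Snapshot(live)
live.name = "renamed"                      // later mutation of the live object...
assert(snap.name == "map at <console>:1")  // ...does not leak into the published snapshot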
core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala
@@ -38,7 +38,7 @@ private[spark] class IndexPage(parent: JobProgressUI) {
     val now = System.currentTimeMillis()

     var activeTime = 0L
-    for (tasks <- listener.stageToTasksActive.values; t <- tasks) {
+    for (tasks <- listener.stageIdToTasksActive.values; t <- tasks) {
       activeTime += t.timeRunning(now)
     }

core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -36,52 +36,52 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkListener {
   val RETAINED_STAGES = System.getProperty("spark.ui.retained_stages", "1000").toInt
   val DEFAULT_POOL_NAME = "default"

-  val stageToPool = new HashMap[Stage, String]()
-  val stageToDescription = new HashMap[Stage, String]()
-  val poolToActiveStages = new HashMap[String, HashSet[Stage]]()
+  val stageIdToPool = new HashMap[Int, String]()
+  val stageIdToDescription = new HashMap[Int, String]()
+  val poolToActiveStages = new HashMap[String, HashSet[StageInfo]]()

-  val activeStages = HashSet[Stage]()
-  val completedStages = ListBuffer[Stage]()
-  val failedStages = ListBuffer[Stage]()
+  val activeStages = HashSet[StageInfo]()
+  val completedStages = ListBuffer[StageInfo]()
+  val failedStages = ListBuffer[StageInfo]()

   // Total metrics reflect metrics only for completed tasks
   var totalTime = 0L
   var totalShuffleRead = 0L
   var totalShuffleWrite = 0L

-  val stageToTime = HashMap[Int, Long]()
-  val stageToShuffleRead = HashMap[Int, Long]()
-  val stageToShuffleWrite = HashMap[Int, Long]()
-  val stageToTasksActive = HashMap[Int, HashSet[TaskInfo]]()
-  val stageToTasksComplete = HashMap[Int, Int]()
-  val stageToTasksFailed = HashMap[Int, Int]()
-  val stageToTaskInfos =
+  val stageIdToTime = HashMap[Int, Long]()
+  val stageIdToShuffleRead = HashMap[Int, Long]()
+  val stageIdToShuffleWrite = HashMap[Int, Long]()
+  val stageIdToTasksActive = HashMap[Int, HashSet[TaskInfo]]()
+  val stageIdToTasksComplete = HashMap[Int, Int]()
+  val stageIdToTasksFailed = HashMap[Int, Int]()
+  val stageIdToTaskInfos =
     HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]()

   override def onJobStart(jobStart: SparkListenerJobStart) {}

   override def onStageCompleted(stageCompleted: StageCompleted) = synchronized {
-    val stage = stageCompleted.stageInfo.stage
-    poolToActiveStages(stageToPool(stage)) -= stage
+    val stage = stageCompleted.stage
+    poolToActiveStages(stageIdToPool(stage.stageId)) -= stage
     activeStages -= stage
     completedStages += stage
     trimIfNecessary(completedStages)
   }

   /** If stages is too large, remove and garbage collect old stages */
-  def trimIfNecessary(stages: ListBuffer[Stage]) = synchronized {
+  def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized {
     if (stages.size > RETAINED_STAGES) {
       val toRemove = RETAINED_STAGES / 10
       stages.takeRight(toRemove).foreach( s => {
-        stageToTaskInfos.remove(s.id)
-        stageToTime.remove(s.id)
-        stageToShuffleRead.remove(s.id)
-        stageToShuffleWrite.remove(s.id)
-        stageToTasksActive.remove(s.id)
-        stageToTasksComplete.remove(s.id)
-        stageToTasksFailed.remove(s.id)
-        stageToPool.remove(s)
-        if (stageToDescription.contains(s)) {stageToDescription.remove(s)}
+        stageIdToTaskInfos.remove(s.stageId)
+        stageIdToTime.remove(s.stageId)
+        stageIdToShuffleRead.remove(s.stageId)
+        stageIdToShuffleWrite.remove(s.stageId)
+        stageIdToTasksActive.remove(s.stageId)
+        stageIdToTasksComplete.remove(s.stageId)
+        stageIdToTasksFailed.remove(s.stageId)
+        stageIdToPool.remove(s.stageId)
+        if (stageIdToDescription.contains(s.stageId)) {stageIdToDescription.remove(s.stageId)}
       })
       stages.trimEnd(toRemove)
     }
@@ -95,74 +95,79 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkListener {
     val poolName = Option(stageSubmitted.properties).map {
       p => p.getProperty("spark.scheduler.pool", DEFAULT_POOL_NAME)
     }.getOrElse(DEFAULT_POOL_NAME)
-    stageToPool(stage) = poolName
+    stageIdToPool(stage.stageId) = poolName

     val description = Option(stageSubmitted.properties).flatMap {
       p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION))
     }
-    description.map(d => stageToDescription(stage) = d)
+    description.map(d => stageIdToDescription(stage.stageId) = d)

-    val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[Stage]())
+    val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[StageInfo]())
     stages += stage
   }

   override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
     val sid = taskStart.task.stageId
-    val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
+    val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
     tasksActive += taskStart.taskInfo
-    val taskList = stageToTaskInfos.getOrElse(
+    val taskList = stageIdToTaskInfos.getOrElse(
       sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
     taskList += ((taskStart.taskInfo, None, None))
-    stageToTaskInfos(sid) = taskList
+    stageIdToTaskInfos(sid) = taskList
   }

   override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
     val sid = taskEnd.task.stageId
-    val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
+    val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
     tasksActive -= taskEnd.taskInfo
     val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) =
       taskEnd.reason match {
         case e: ExceptionFailure =>
-          stageToTasksFailed(sid) = stageToTasksFailed.getOrElse(sid, 0) + 1
+          stageIdToTasksFailed(sid) = stageIdToTasksFailed.getOrElse(sid, 0) + 1
           (Some(e), e.metrics)
         case _ =>
-          stageToTasksComplete(sid) = stageToTasksComplete.getOrElse(sid, 0) + 1
+          stageIdToTasksComplete(sid) = stageIdToTasksComplete.getOrElse(sid, 0) + 1
           (None, Option(taskEnd.taskMetrics))
       }

-    stageToTime.getOrElseUpdate(sid, 0L)
+    stageIdToTime.getOrElseUpdate(sid, 0L)
     val time = metrics.map(m => m.executorRunTime).getOrElse(0)
-    stageToTime(sid) += time
+    stageIdToTime(sid) += time
     totalTime += time

-    stageToShuffleRead.getOrElseUpdate(sid, 0L)
+    stageIdToShuffleRead.getOrElseUpdate(sid, 0L)
     val shuffleRead = metrics.flatMap(m => m.shuffleReadMetrics).map(s =>
       s.remoteBytesRead).getOrElse(0L)
-    stageToShuffleRead(sid) += shuffleRead
+    stageIdToShuffleRead(sid) += shuffleRead
     totalShuffleRead += shuffleRead

-    stageToShuffleWrite.getOrElseUpdate(sid, 0L)
+    stageIdToShuffleWrite.getOrElseUpdate(sid, 0L)
     val shuffleWrite = metrics.flatMap(m => m.shuffleWriteMetrics).map(s =>
       s.shuffleBytesWritten).getOrElse(0L)
-    stageToShuffleWrite(sid) += shuffleWrite
+    stageIdToShuffleWrite(sid) += shuffleWrite
     totalShuffleWrite += shuffleWrite

-    val taskList = stageToTaskInfos.getOrElse(
+    val taskList = stageIdToTaskInfos.getOrElse(
       sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
     taskList -= ((taskEnd.taskInfo, None, None))
     taskList += ((taskEnd.taskInfo, metrics, failureInfo))
-    stageToTaskInfos(sid) = taskList
+    stageIdToTaskInfos(sid) = taskList
   }

   override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized {
     jobEnd match {
       case end: SparkListenerJobEnd =>
         end.jobResult match {
           case JobFailed(ex, Some(stage)) =>
-            activeStages -= stage
-            poolToActiveStages(stageToPool(stage)) -= stage
-            failedStages += stage
-            trimIfNecessary(failedStages)
+            /* If two jobs share a stage we could get this failure message twice. So we first
+             * check whether we've already retired this stage. */
+            val stageInfo = activeStages.filter(s => s.stageId == stage.id).headOption
+            stageInfo.foreach {s =>
+              activeStages -= s
+              poolToActiveStages(stageIdToPool(stage.id)) -= s
+              failedStages += s
+              trimIfNecessary(failedStages)
+            }
           case _ =>
         }
       case _ =>
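The renames in this file are mechanical, but the choice they encode matters: keying the UI's bookkeeping by the Int stage ID instead of by Stage references means retired entries are dropped by plain key removal, and the UI never keeps scheduler objects alive. A minimal sketch of the pattern (hypothetical names, Scala REPL style):

import scala.collection.mutable.HashMap

val stageIdToTasksComplete = HashMap[Int, Int]()

// The same getOrElse-then-update idiom the listener uses on task completion.
def recordTaskComplete(stageId: Int) {
  stageIdToTasksComplete(stageId) = stageIdToTasksComplete.getOrElse(stageId, 0) + 1
}

// Cleanup is a plain key removal; no Stage reference is retained.
def retire(stageId: Int) { stageIdToTasksComplete.remove(stageId) }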