[SNAP-1190] Reduce partition message overhead from driver to executor (#31)

- DAGScheduler:
  - For small enough common task data (RDD + closure), send it inline with the Task instead of as a broadcast
  - Transiently store task binary data in Stage to re-use if possible
  - Compress the common task bytes to save on network cost
- Task: new TaskData class to encapsulate the compressed task bytes from above, the uncompressed length,
  and a reference index used when the TaskData is read from a separate list (see the following items
  and the sketch after this list)
- CoarseGrainedClusterMessage: added a new LaunchTasks message to encapsulate multiple
  Task messages to the same executor
- CoarseGrainedSchedulerBackend:
  - Create LaunchTasks by grouping messages in ExecutorTaskGroup per executor
  - Actual TaskData is sent as part of TaskDescription rather than the Task, to easily
    separate out the common portions into a separate list
  - Send the common TaskData as a separate ArrayBuffer of data with the index into this
    list set in the original task's TaskData
- CoarseGrainedExecutorBackend: handle LaunchTasks by splitting it into individual task launches
- CompressionCodec: added compress/decompress methods operating directly on byte arrays for more efficient byte-array compression
- Executor:
  - Set the common decompressed task data back into the Task object.
  - Avoid an additional serialization of the TaskResult just to determine the serialization time;
    instead, the time is now calculated inline during the serialization write/writeExternal methods
- TaskMetrics: more generic handling for DoubleAccumulator case
- Task: handle TaskData during serialization, sending a flag to indicate whether the
  data is inlined or will be received via broadcast
- ResultTask, ShuffleMapTask: delegate handling of TaskData to parent Task class
- SparkEnv: encapsulate codec creation as a zero-arg function to avoid repeated conf lookups
- SparkContext.clean: avoid checking serializability when a non-default closure serializer is in use
- Test updates for above
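
As a rough illustration of the TaskData idea above (a sketch only — field names and
signatures here are inferred from the commit message and the hunks below, not taken
verbatim from the patch), the holder for the compressed common task bytes could look like:

import org.apache.spark.SparkEnv

// Sketch: compressed common task bytes (RDD + closure), the uncompressed length
// needed to size the output buffer, and an optional index into a shared list
// when several tasks in one LaunchTasks message reuse the same data.
class TaskData(val compressedBytes: Array[Byte],
    val uncompressedLen: Int,
    val reference: Int = -1) extends Serializable {

  // Decompress with the codec cached in SparkEnv (see createCompressionCodec and
  // CompressionCodec.decompress in the hunks below); empty data yields no bytes.
  def decompress(env: SparkEnv): Array[Byte] = {
    if (uncompressedLen > 0) {
      env.createCompressionCodec.decompress(compressedBytes, 0,
        compressedBytes.length, uncompressedLen)
    } else {
      Array.emptyByteArray
    }
  }
}

object TaskData {
  val EMPTY = new TaskData(Array.emptyByteArray, 0)
}

The real class additionally carries the inlined-versus-broadcast flag during serialization,
as described in the Task item above.
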
Conflicts:
	core/src/main/scala/org/apache/spark/SparkEnv.scala
	core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
	core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
	core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
	core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
	core/src/main/scala/org/apache/spark/scheduler/Task.scala
	core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
	core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
	core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
Sumedh Wale committed Jul 8, 2017
1 parent 29a4205 commit 22141bd
Showing 25 changed files with 535 additions and 150 deletions.
6 changes: 5 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -72,6 +72,7 @@ import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, StandaloneSchedulerBackend}
import org.apache.spark.scheduler.local.LocalSchedulerBackend
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage._
import org.apache.spark.storage.BlockManagerMessages.TriggerThreadDump
import org.apache.spark.ui.{ConsoleProgressBar, SparkUI}
@@ -233,6 +234,7 @@ class SparkContext(config: SparkConf) extends Logging {
private var _jars: Seq[String] = _
private var _files: Seq[String] = _
private var _shutdownHookRef: AnyRef = _
private var _isDefaultClosureSerializer: Boolean = true

/* ------------------------------------------------------------------------------------- *
| Accessors and public fields. These provide access to the internal state of the |
@@ -450,6 +452,8 @@ class SparkContext(config: SparkConf) extends Logging {
_env = createSparkEnv(_conf, isLocal, listenerBus)
SparkEnv.set(_env)

_isDefaultClosureSerializer = _env.closureSerializer.isInstanceOf[JavaSerializer]

// If running the REPL, register the repl's output dir with the file server.
_conf.getOption("spark.repl.class.outputDir").foreach { path =>
val replUri = _env.rpcEnv.fileServer.addDirectory("/classes", new File(path))
@@ -2109,7 +2113,7 @@ class SparkContext(config: SparkConf) extends Logging {
* serializable
*/
private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = {
ClosureCleaner.clean(f, checkSerializable)
ClosureCleaner.clean(f, checkSerializable && _isDefaultClosureSerializer)
f
}

6 changes: 6 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -48,6 +48,7 @@ import org.apache.spark.api.python.PythonWorkerFactory
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.memory.{MemoryManager, StaticMemoryManager, UnifiedMemoryManager}
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.netty.NettyBlockTransferService
@@ -96,6 +97,11 @@ class SparkEnv (

private[spark] var driverTmpDir: Option[String] = None

private val codecCreator = CompressionCodec.codecCreator(conf,
CompressionCodec.getCodecName(conf))

def createCompressionCodec: CompressionCodec = codecCreator()

private[spark] def stop() {

if (!isStopped) {

core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -33,7 +33,6 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rpc._
import org.apache.spark.scheduler.{ExecutorLossReason, TaskDescription}
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util.{ThreadUtils, Utils}

private[spark] class CoarseGrainedExecutorBackend(
@@ -50,10 +49,6 @@ private[spark] class CoarseGrainedExecutorBackend(
var executor: Executor = null
@volatile var driver: Option[RpcEndpointRef] = None

// If this CoarseGrainedExecutorBackend is changed to support multiple threads, then this may need
// to be changed so that we don't share the serializer instance across threads
private[this] val ser: SerializerInstance = env.closureSerializer.newInstance()

override def onStart() {
logInfo("Connecting to driver: " + driverUrl)
rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref =>
@@ -91,14 +86,28 @@ private[spark] class CoarseGrainedExecutorBackend(
case RegisterExecutorFailed(message) =>
exitExecutor(1, "Slave registration failed: " + message)

case LaunchTask(data) =>
case LaunchTask(taskDesc) =>
if (executor == null) {
exitExecutor(1, "Received LaunchTask command but executor was null")
} else {
val taskDesc = ser.deserialize[TaskDescription](data.value)
logInfo("Got assigned task " + taskDesc.taskId)
executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,
taskDesc.name, taskDesc.serializedTask)
taskDesc.name, taskDesc.serializedTask, taskDesc.taskData.decompress(env))
}

case LaunchTasks(tasks, taskDataList) =>
if (executor ne null) {
logDebug("Got assigned tasks " + tasks.map(_.taskId).mkString(","))
for (task <- tasks) {
logInfo("Got assigned task " + task.taskId)
val ref = task.taskData.reference
val taskData = if (ref >= 0) taskDataList(ref) else task.taskData
executor.launchTask(this, taskId = task.taskId,
attemptNumber = task.attemptNumber, task.name, task.serializedTask,
taskData.decompress(env))
}
} else {
exitExecutor(1, "Received LaunchTasks command but executor was null")
}

case KillTask(taskId, _, interruptThread) =>
26 changes: 12 additions & 14 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -140,9 +140,10 @@ private[spark] class Executor(
taskId: Long,
attemptNumber: Int,
taskName: String,
serializedTask: ByteBuffer): Unit = {
val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
serializedTask)
serializedTask: ByteBuffer,
taskDataBytes: Array[Byte]): Unit = {
val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber,
taskName, serializedTask, taskDataBytes)
runningTasks.put(taskId, tr)
threadPool.execute(tr)
}
@@ -189,7 +190,8 @@ private[spark] class Executor(
val taskId: Long,
val attemptNumber: Int,
taskName: String,
serializedTask: ByteBuffer)
serializedTask: ByteBuffer,
taskDataBytes: Array[Byte])
extends Runnable {

/** Whether this task has been killed. */
@@ -256,6 +258,7 @@ private[spark] class Executor(

updateDependencies(taskFiles, taskJars)
task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
task.taskDataBytes = taskDataBytes
task.localProperties = taskProps
task.setTaskMemoryManager(taskMemoryManager)

@@ -319,11 +322,6 @@ private[spark] class Executor(
throw new TaskKilledException
}

val resultSer = env.serializer.newInstance()
val beforeSerialization = System.nanoTime()
val valueBytes = resultSer.serialize(value)
val afterSerialization = System.nanoTime()

// Deserialization happens in two parts: first, we deserialize a Task object, which
// includes the Partition. Second, Task.run() deserializes the RDD and function to be run.
task.metrics.setExecutorDeserializeTime(math.max(
@@ -336,13 +334,13 @@ private[spark] class Executor(
task.metrics.setExecutorCpuTime(
(taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime)
task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
task.metrics.setResultSerializationTime(math.max(
afterSerialization - beforeSerialization, 0L) / 1000000.0)

// Note: accumulator updates must be collected after TaskMetrics is updated
// Now resultSerializationTime is evaluated directly inside the
// serialization write methods and added to final serialized bytes
// to avoid double serialization of Task (for timing then TaskResult).
val accumUpdates = task.collectAccumulatorUpdates()
// TODO: do not serialize value twice
val directResult = new DirectTaskResult(valueBytes, accumUpdates)
val directResult = new DirectTaskResult(value, accumUpdates,
Some(task.metrics.resultSerializationTimeMetric))
val serializedDirectResult = ser.serialize(directResult)
val resultSize = serializedDirectResult.limit

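
The DirectTaskResult change above removes the extra serialization pass that existed only to
measure result serialization time. A minimal, self-contained sketch of that idea (hypothetical
names; the actual patch records the time into the resultSerializationTime metric passed to
DirectTaskResult):

import java.io.{Externalizable, ObjectInput, ObjectOutput}

// Sketch: serialize the value once and record how long it took as part of the
// same pass, instead of serializing it a second time just to measure the cost.
class TimedResult(private var value: AnyRef) extends Externalizable {
  private var serializationTimeMillis: Double = 0.0

  def this() = this(null) // no-arg constructor required by Externalizable

  override def writeExternal(out: ObjectOutput): Unit = {
    val start = System.nanoTime()
    out.writeObject(value)
    // measured inline and written so the reader can recover it with the result
    serializationTimeMillis = (System.nanoTime() - start) / 1e6
    out.writeDouble(serializationTimeMillis)
  }

  override def readExternal(in: ObjectInput): Unit = {
    value = in.readObject()
    serializationTimeMillis = in.readDouble()
  }
}
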
core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -133,6 +133,7 @@ class TaskMetrics private[spark] () extends Serializable with KryoSerializable {
private[spark] def setJvmGCTime(v: Long): Unit = _jvmGCTime.setValue(v)
private[spark] def setResultSerializationTime(v: Double): Unit =
_resultSerializationTime.setValue(v)
private[spark] def resultSerializationTimeMetric = _resultSerializationTime
private[spark] def incMemoryBytesSpilled(v: Long): Unit = _memoryBytesSpilled.add(v)
private[spark] def incDiskBytesSpilled(v: Long): Unit = _diskBytesSpilled.add(v)
private[spark] def incPeakExecutionMemory(v: Long): Unit = _peakExecutionMemory.add(v)
@@ -334,7 +335,12 @@ private[spark] object TaskMetrics extends Logging {
} else {
tm.nameToAccums.get(name).foreach {
case l: LongAccumulator => l.setValue(value.asInstanceOf[Long])
case d => d.asInstanceOf[DoubleAccumulator].setValue(value.asInstanceOf[Double])
case d: DoubleAccumulator => value match {
case v: Double => d.setValue(v)
case _ => d.setValue(value.asInstanceOf[Long])
}
case o => throw new UnsupportedOperationException(
s"Unexpected accumulator $o for TaskMetrics")
}
}
}
69 changes: 61 additions & 8 deletions core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
@@ -19,8 +19,8 @@ package org.apache.spark.io

import java.io._

import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
import net.jpountz.lz4.LZ4BlockOutputStream
import com.ning.compress.lzf.{LZFDecoder, LZFEncoder, LZFInputStream, LZFOutputStream}
import net.jpountz.lz4.{LZ4BlockOutputStream, LZ4Factory}
import org.xerial.snappy.{Snappy, SnappyInputStream, SnappyOutputStream}

import org.apache.spark.SparkConf
@@ -41,6 +41,11 @@ trait CompressionCodec {
def compressedOutputStream(s: OutputStream): OutputStream

def compressedInputStream(s: InputStream): InputStream

def compress(input: Array[Byte], inputLen: Int): Array[Byte]

def decompress(input: Array[Byte], inputOffset: Int, inputLen: Int,
outputLen: Int): Array[Byte]
}

private[spark] object CompressionCodec {
@@ -66,16 +71,32 @@
}

def createCodec(conf: SparkConf, codecName: String): CompressionCodec = {
codecCreator(conf, codecName)()
}

def codecCreator(conf: SparkConf, codecName: String): () => CompressionCodec = {
if (codecName == DEFAULT_COMPRESSION_CODEC) {
return () => new LZ4CompressionCodec(conf)
}
val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName)
val codec = try {
try {
val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf])
Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec])
() => {
try {
ctor.newInstance(conf).asInstanceOf[CompressionCodec]
} catch {
case e: IllegalArgumentException => throw fail(codecName)
}
}
} catch {
case e: ClassNotFoundException => None
case e: IllegalArgumentException => None
case e: ClassNotFoundException => throw fail(codecName)
case e: NoSuchMethodException => throw fail(codecName)
}
codec.getOrElse(throw new IllegalArgumentException(s"Codec [$codecName] is not available. " +
s"Consider setting $configKey=$FALLBACK_COMPRESSION_CODEC"))
}

private def fail(codecName: String): IllegalArgumentException = {
new IllegalArgumentException(s"Codec [$codecName] is not available. " +
s"Consider setting $configKey=$FALLBACK_COMPRESSION_CODEC")
}

/**
@@ -115,6 +136,16 @@ class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec {
}

override def compressedInputStream(s: InputStream): InputStream = new LZ4BlockInputStream(s)

override def compress(input: Array[Byte], inputLen: Int): Array[Byte] = {
LZ4Factory.fastestInstance().fastCompressor().compress(input, 0, inputLen)
}

override def decompress(input: Array[Byte], inputOffset: Int, inputLen: Int,
outputLen: Int): Array[Byte] = {
LZ4Factory.fastestInstance().fastDecompressor().decompress(input,
inputOffset, outputLen)
}
}


@@ -134,6 +165,17 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec {
}

override def compressedInputStream(s: InputStream): InputStream = new LZFInputStream(s)

override def compress(input: Array[Byte], inputLen: Int): Array[Byte] = {
LZFEncoder.encode(input, 0, inputLen)
}

override def decompress(input: Array[Byte], inputOffset: Int, inputLen: Int,
outputLen: Int): Array[Byte] = {
val output = new Array[Byte](outputLen)
LZFDecoder.decode(input, inputOffset, inputLen, output)
output
}
}


@@ -156,6 +198,17 @@ class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec {
}

override def compressedInputStream(s: InputStream): InputStream = new SnappyInputStream(s)

override def compress(input: Array[Byte], inputLen: Int): Array[Byte] = {
Snappy.rawCompress(input, inputLen)
}

override def decompress(input: Array[Byte], inputOffset: Int,
inputLen: Int, outputLen: Int): Array[Byte] = {
val output = new Array[Byte](outputLen)
Snappy.uncompress(input, inputOffset, inputLen, output, 0)
output
}
}

/**
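
A quick round trip through the new byte-array methods might look like the following usage
sketch (codec name and buffer size are illustrative; the caller must remember the
uncompressed length in order to decompress):

import org.apache.spark.SparkConf
import org.apache.spark.io.CompressionCodec

val conf = new SparkConf()
val codec = CompressionCodec.createCodec(conf, "lz4")

val input: Array[Byte] = Array.tabulate(4096)(i => (i % 64).toByte)
val compressed = codec.compress(input, input.length)
// decompress needs the original length to size the output buffer
val restored = codec.decompress(compressed, 0, compressed.length, input.length)
assert(java.util.Arrays.equals(input, restored))
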
39 changes: 32 additions & 7 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -981,19 +981,36 @@ class DAGScheduler(
// task gets a different copy of the RDD. This provides stronger isolation between tasks that
// might modify state of objects referenced in their closures. This is necessary in Hadoop
// where the JobConf/Configuration object is not thread-safe.
var taskBinary: Broadcast[Array[Byte]] = null
var taskBinary: Option[Broadcast[Array[Byte]]] = None
var taskData: TaskData = TaskData.EMPTY
try {
// For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).
// For ResultTask, serialize and broadcast (rdd, func).
val taskBinaryBytes: Array[Byte] = stage match {
val bytes = stage.taskBinaryBytes
val taskBinaryBytes: Array[Byte] = if (bytes != null) bytes else stage match {
case stage: ShuffleMapStage =>
JavaUtils.bufferToArray(
closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
case stage: ResultStage =>
JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
}

taskBinary = sc.broadcast(taskBinaryBytes)
if (bytes == null) stage.taskBinaryBytes = taskBinaryBytes

// use direct byte shipping for small size or if number of partitions is small
val taskBytesLen = taskBinaryBytes.length
if (taskBytesLen <= DAGScheduler.TASK_INLINE_LIMIT ||
partitionsToCompute.length <= DAGScheduler.TASK_INLINE_PARTITION_LIMIT) {
if (stage.taskData.uncompressedLen > 0) {
taskData = stage.taskData
} else {
// compress inline task data (broadcast compresses as per conf)
taskData = new TaskData(env.createCompressionCodec.compress(
taskBinaryBytes, taskBytesLen), taskBytesLen)
stage.taskData = taskData
}
} else {
taskBinary = Some(sc.broadcast(taskBinaryBytes))
}
} catch {
// In the case of a failure during serialization, abort the stage.
case e: NotSerializableException =>
@@ -1014,7 +1031,7 @@
partitionsToCompute.map { id =>
val locs = taskIdToLocations(id)
val part = stage.rdd.partitions(id)
new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskData,
taskBinary, part, locs, stage.latestInfo.taskMetrics, properties, Option(jobId),
Option(sc.applicationId), sc.applicationAttemptId)
}
Expand All @@ -1024,7 +1041,7 @@ class DAGScheduler(
val p: Int = stage.partitions(id)
val part = stage.rdd.partitions(p)
val locs = taskIdToLocations(id)
new ResultTask(stage.id, stage.latestInfo.attemptId,
new ResultTask(stage.id, stage.latestInfo.attemptId, taskData,
taskBinary, part, locs, id, properties, stage.latestInfo.taskMetrics,
Option(jobId), Option(sc.applicationId), sc.applicationAttemptId)
}
@@ -1381,7 +1398,7 @@ class DAGScheduler(
* Marks a stage as finished and removes it from the list of running stages.
*/
private def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = {
val serviceTime = stage.latestInfo.submissionTime match {
val serviceTime = if (!log.isInfoEnabled) 0L else stage.latestInfo.submissionTime match {
case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0)
case _ => "Unknown"
}
@@ -1674,4 +1691,12 @@ private[spark] object DAGScheduler {
// this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
// as more failure events come in
val RESUBMIT_TIMEOUT = 200

// The maximum size of uncompressed common task bytes (rdd, closure)
// that will be shipped with the task else will be broadcast separately.
val TASK_INLINE_LIMIT = 100 * 1024

// The maximum number of partitions below which common task bytes will be
// shipped with the task else will be broadcast separately.
val TASK_INLINE_PARTITION_LIMIT = 8
}
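
The remaining changed files (Task, TaskResult, ResultTask, ShuffleMapTask,
CoarseGrainedSchedulerBackend, and others) are not loaded in this view. As a minimal,
self-contained sketch of the per-executor grouping described in the commit message — the
types below are simplified stand-ins, not the actual patch — the scheduler-side bookkeeping
could look like:

import scala.collection.mutable.ArrayBuffer

// Simplified stand-in types for illustration only.
case class TaskData(compressedBytes: Array[Byte], uncompressedLen: Int,
    reference: Int = -1)
case class TaskDescription(taskId: Long, name: String, taskData: TaskData)
case class LaunchTasks(tasks: Seq[TaskDescription], taskDataList: Seq[TaskData])

// Collect the tasks bound for one executor and de-duplicate their common
// TaskData so it is shipped once per LaunchTasks message.
class ExecutorTaskGroup {
  private val tasks = new ArrayBuffer[TaskDescription]()
  private val taskDataList = new ArrayBuffer[TaskData]()

  def addTask(desc: TaskDescription): Unit = {
    val data = desc.taskData
    if (data.uncompressedLen > 0) {
      // ship identical inline data only once; later tasks carry just an index
      var index = taskDataList.indexWhere(_ eq data)
      if (index < 0) {
        index = taskDataList.length
        taskDataList += data
      }
      tasks += desc.copy(taskData = TaskData(Array.emptyByteArray, 0, index))
    } else {
      tasks += desc
    }
  }

  def toMessage: LaunchTasks = LaunchTasks(tasks.toList, taskDataList.toList)
}

On the executor side, a task whose TaskData carries reference >= 0 resolves it against
taskDataList, as in the LaunchTasks handler shown in the CoarseGrainedExecutorBackend hunk above.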
