SPARK-2565. Update ShuffleReadMetrics as blocks are fetched
Author: Sandy Ryza <[email protected]>

Closes #1507 from sryza/sandy-spark-2565 and squashes the following commits:

74dad41 [Sandy Ryza] SPARK-2565. Update ShuffleReadMetrics as blocks are fetched
sryza authored and pwendell committed Aug 8, 2014
1 parent 6906b69 commit 4c51098
Showing 10 changed files with 84 additions and 64 deletions.
@@ -374,6 +374,7 @@ private[spark] class Executor(
for (taskRunner <- runningTasks.values()) {
if (!taskRunner.attemptedTask.isEmpty) {
Option(taskRunner.task).flatMap(_.metrics).foreach { metrics =>
metrics.updateShuffleReadMetrics
tasksMetrics += ((taskRunner.taskId, metrics))
}
}
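
The hunk above is the reporting side of the change: before task metrics are sent to the driver, the heartbeat path first folds the per-dependency shuffle read metrics into the task's aggregate. The sketch below is a minimal, self-contained model of the idea in the commit title, with hypothetical names (ReadProgress, FetchingIterator) rather than the actual Spark classes: the fetcher updates a shared metrics object as each block arrives, so a concurrently running reporter can observe partial progress instead of waiting for the fetch to complete.

// Simplified stand-in for ShuffleReadMetrics: a mutable progress sink.
class ReadProgress {
  @volatile var blocksFetched: Int = 0
  @volatile var bytesRead: Long = 0L
}

// Updates the progress object per block, as it is consumed, rather than in
// one lump once the whole fetch has finished.
class FetchingIterator(blocks: Seq[(String, Long)], progress: ReadProgress)
    extends Iterator[String] {
  private val it = blocks.iterator
  def hasNext: Boolean = it.hasNext
  def next(): String = {
    val (blockId, size) = it.next()
    progress.blocksFetched += 1
    progress.bytesRead += size
    blockId
  }
}

object FetchProgressDemo extends App {
  val progress = new ReadProgress
  val blocks = Seq(("shuffle_0_0_0", 128L), ("shuffle_0_1_0", 256L))
  new FetchingIterator(blocks, progress).foreach { _ =>
    println(s"so far: ${progress.blocksFetched} blocks, ${progress.bytesRead} bytes")
  }
}
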
55 changes: 42 additions & 13 deletions core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -17,6 +17,8 @@

package org.apache.spark.executor

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.storage.{BlockId, BlockStatus}

@@ -81,12 +83,27 @@ class TaskMetrics extends Serializable {
var inputMetrics: Option[InputMetrics] = None

/**
* If this task reads from shuffle output, metrics on getting shuffle data will be collected here
* If this task reads from shuffle output, metrics on getting shuffle data will be collected here.
* This includes read metrics aggregated over all the task's shuffle dependencies.
*/
private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None

def shuffleReadMetrics = _shuffleReadMetrics

/**
* This should only be used when recreating TaskMetrics, not when updating read metrics in
* executors.
*/
private[spark] def setShuffleReadMetrics(shuffleReadMetrics: Option[ShuffleReadMetrics]) {
_shuffleReadMetrics = shuffleReadMetrics
}

/**
* ShuffleReadMetrics per dependency for collecting independently while task is in progress.
*/
@transient private lazy val depsShuffleReadMetrics: ArrayBuffer[ShuffleReadMetrics] =
new ArrayBuffer[ShuffleReadMetrics]()

/**
* If this task writes to shuffle output, metrics on the written shuffle data will be collected
* here
@@ -98,19 +115,31 @@ class TaskMetrics extends Serializable {
*/
var updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = None

/** Adds the given ShuffleReadMetrics to any existing shuffle metrics for this task. */
def updateShuffleReadMetrics(newMetrics: ShuffleReadMetrics) = synchronized {
_shuffleReadMetrics match {
case Some(existingMetrics) =>
existingMetrics.shuffleFinishTime = math.max(
existingMetrics.shuffleFinishTime, newMetrics.shuffleFinishTime)
existingMetrics.fetchWaitTime += newMetrics.fetchWaitTime
existingMetrics.localBlocksFetched += newMetrics.localBlocksFetched
existingMetrics.remoteBlocksFetched += newMetrics.remoteBlocksFetched
existingMetrics.remoteBytesRead += newMetrics.remoteBytesRead
case None =>
_shuffleReadMetrics = Some(newMetrics)
/**
* A task may have multiple shuffle readers for multiple dependencies. To avoid synchronization
* issues from readers in different threads, in-progress tasks use a ShuffleReadMetrics for each
* dependency, and merge these metrics before reporting them to the driver. This method returns
* a ShuffleReadMetrics for a dependency and registers it for merging later.
*/
private [spark] def createShuffleReadMetricsForDependency(): ShuffleReadMetrics = synchronized {
val readMetrics = new ShuffleReadMetrics()
depsShuffleReadMetrics += readMetrics
readMetrics
}

/**
* Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics.
*/
private[spark] def updateShuffleReadMetrics() = synchronized {
val merged = new ShuffleReadMetrics()
for (depMetrics <- depsShuffleReadMetrics) {
merged.fetchWaitTime += depMetrics.fetchWaitTime
merged.localBlocksFetched += depMetrics.localBlocksFetched
merged.remoteBlocksFetched += depMetrics.remoteBlocksFetched
merged.remoteBytesRead += depMetrics.remoteBytesRead
merged.shuffleFinishTime = math.max(merged.shuffleFinishTime, depMetrics.shuffleFinishTime)
}
_shuffleReadMetrics = Some(merged)
}
}
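
The two methods added above form a create-then-merge pattern: each shuffle dependency gets its own ShuffleReadMetrics that only its reader mutates, and updateShuffleReadMetrics() sums them into the single value reported to the driver. A minimal, self-contained sketch of that pattern, using simplified stand-ins (DepReadMetrics, SimpleTaskMetrics) rather than the real TaskMetrics API:

import scala.collection.mutable.ArrayBuffer

// Simplified stand-in for ShuffleReadMetrics.
class DepReadMetrics {
  var fetchWaitTime: Long = 0L
  var remoteBytesRead: Long = 0L
  var localBlocksFetched: Int = 0
  var remoteBlocksFetched: Int = 0
}

class SimpleTaskMetrics {
  private val perDependency = new ArrayBuffer[DepReadMetrics]()

  // Called once per shuffle dependency; the returned object is only ever
  // mutated by the reader for that dependency.
  def createForDependency(): DepReadMetrics = synchronized {
    val m = new DepReadMetrics()
    perDependency += m
    m
  }

  // Sums the per-dependency metrics into one aggregate, mirroring what
  // updateShuffleReadMetrics() does in this patch.
  def merged(): DepReadMetrics = synchronized {
    val total = new DepReadMetrics()
    for (m <- perDependency) {
      total.fetchWaitTime += m.fetchWaitTime
      total.remoteBytesRead += m.remoteBytesRead
      total.localBlocksFetched += m.localBlocksFetched
      total.remoteBlocksFetched += m.remoteBlocksFetched
    }
    total
  }
}

object MergeDemo extends App {
  val tm = new SimpleTaskMetrics()
  val dep1 = tm.createForDependency()
  val dep2 = tm.createForDependency()
  dep1.remoteBytesRead += 512L
  dep2.localBlocksFetched += 3
  println(tm.merged().remoteBytesRead)   // 512
}

Keeping the per-dependency buffer @transient and lazy, as the patch does, also means it is simply rebuilt empty if a TaskMetrics object is serialized and reconstituted elsewhere.
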

@@ -32,7 +32,8 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
shuffleId: Int,
reduceId: Int,
context: TaskContext,
serializer: Serializer)
serializer: Serializer,
shuffleMetrics: ShuffleReadMetrics)
: Iterator[T] =
{
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
@@ -73,17 +74,11 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
}
}

val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer, shuffleMetrics)
val itr = blockFetcherItr.flatMap(unpackBlock)

val completionIter = CompletionIterator[T, Iterator[T]](itr, {
val shuffleMetrics = new ShuffleReadMetrics
shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
context.taskMetrics.updateShuffleReadMetrics(shuffleMetrics)
context.taskMetrics.updateShuffleReadMetrics()
})

new InterruptibleIterator[T](context, completionIter)
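
With the per-block counters now written directly into the per-dependency ShuffleReadMetrics during the fetch, the completion callback no longer copies totals out of the fetcher; it only asks TaskMetrics to merge what has been accumulated. A rough sketch of the run-a-callback-when-exhausted idiom used here, with a hypothetical OnCompletionIterator rather than Spark's CompletionIterator:

// Wraps an iterator and fires a callback exactly once, when the underlying
// iterator is drained.
class OnCompletionIterator[A](sub: Iterator[A], onDone: () => Unit) extends Iterator[A] {
  private var completed = false
  def hasNext: Boolean = {
    val more = sub.hasNext
    if (!more && !completed) { completed = true; onDone() }
    more
  }
  def next(): A = sub.next()
}

object CompletionDemo extends App {
  val records = new OnCompletionIterator(Iterator(1, 2, 3),
    () => println("all blocks consumed: merge metrics here"))
  records.foreach(println)
}
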
@@ -36,8 +36,10 @@ private[spark] class HashShuffleReader[K, C](

/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
val ser = Serializer.getSerializer(dep.serializer)
val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser,
readMetrics)

val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
@@ -27,6 +27,7 @@ import scala.util.{Failure, Success}
import io.netty.buffer.ByteBuf

import org.apache.spark.{Logging, SparkException}
import org.apache.spark.executor.ShuffleReadMetrics
import org.apache.spark.network.BufferMessage
import org.apache.spark.network.ConnectionManagerId
import org.apache.spark.network.netty.ShuffleCopier
@@ -47,10 +48,6 @@ import org.apache.spark.util.Utils
private[storage]
trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging {
def initialize()
def numLocalBlocks: Int
def numRemoteBlocks: Int
def fetchWaitTime: Long
def remoteBytesRead: Long
}


@@ -72,14 +69,12 @@ object BlockFetcherIterator {
class BasicBlockFetcherIterator(
private val blockManager: BlockManager,
val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer)
serializer: Serializer,
readMetrics: ShuffleReadMetrics)
extends BlockFetcherIterator {

import blockManager._

private var _remoteBytesRead = 0L
private var _fetchWaitTime = 0L

if (blocksByAddress == null) {
throw new IllegalArgumentException("BlocksByAddress is null")
}
@@ -89,13 +84,9 @@

protected var startTime = System.currentTimeMillis

// This represents the number of local blocks, also counting zero-sized blocks
private var numLocal = 0
// BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks
protected val localBlocksToFetch = new ArrayBuffer[BlockId]()

// This represents the number of remote blocks, also counting zero-sized blocks
private var numRemote = 0
// BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks
protected val remoteBlocksToFetch = new HashSet[BlockId]()

@@ -132,7 +123,10 @@ object BlockFetcherIterator {
val networkSize = blockMessage.getData.limit()
results.put(new FetchResult(blockId, sizeMap(blockId),
() => dataDeserialize(blockId, blockMessage.getData, serializer)))
_remoteBytesRead += networkSize
// TODO: NettyBlockFetcherIterator has some race conditions where multiple threads can
// be incrementing bytes read at the same time (SPARK-2625).
readMetrics.remoteBytesRead += networkSize
readMetrics.remoteBlocksFetched += 1
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
}
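
The TODO above notes that NettyBlockFetcherIterator can have several threads incrementing these counters at once (SPARK-2625). That race is not addressed by this patch; one possible direction, sketched below with a hypothetical ThreadSafeReadMetrics rather than the real ShuffleReadMetrics, would be to back the counters with atomics so concurrent fetcher threads can update them safely.

import java.util.concurrent.atomic.AtomicLong

// Hypothetical, not part of this patch: atomic counters for concurrent updates.
class ThreadSafeReadMetrics {
  private val _remoteBytesRead = new AtomicLong(0L)
  private val _remoteBlocksFetched = new AtomicLong(0L)

  def addRemoteBytesRead(bytes: Long): Unit = { _remoteBytesRead.addAndGet(bytes) }
  def incRemoteBlocksFetched(): Unit = { _remoteBlocksFetched.incrementAndGet() }

  def remoteBytesRead: Long = _remoteBytesRead.get()
  def remoteBlocksFetched: Long = _remoteBlocksFetched.get()
}
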
@@ -155,14 +149,14 @@
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
val remoteRequests = new ArrayBuffer[FetchRequest]
var totalBlocks = 0
for ((address, blockInfos) <- blocksByAddress) {
totalBlocks += blockInfos.size
if (address == blockManagerId) {
numLocal = blockInfos.size
// Filter out zero-sized blocks
localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
_numBlocksToFetch += localBlocksToFetch.size
} else {
numRemote += blockInfos.size
val iterator = blockInfos.iterator
var curRequestSize = 0L
var curBlocks = new ArrayBuffer[(BlockId, Long)]
@@ -192,7 +186,7 @@
}
}
logInfo("Getting " + _numBlocksToFetch + " non-empty blocks out of " +
(numLocal + numRemote) + " blocks")
totalBlocks + " blocks")
remoteRequests
}

@@ -205,6 +199,7 @@
// getLocalFromDisk never return None but throws BlockException
val iter = getLocalFromDisk(id, serializer).get
// Pass 0 as size since it's not in flight
readMetrics.localBlocksFetched += 1
results.put(new FetchResult(id, 0, () => iter))
logDebug("Got local block " + id)
} catch {
@@ -238,12 +233,6 @@
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
}

override def numLocalBlocks: Int = numLocal
override def numRemoteBlocks: Int = numRemote
override def fetchWaitTime: Long = _fetchWaitTime
override def remoteBytesRead: Long = _remoteBytesRead


// Implementing the Iterator methods with an iterator that reads fetched blocks off the queue
// as they arrive.
@volatile protected var resultsGotten = 0
Expand All @@ -255,7 +244,7 @@ object BlockFetcherIterator {
val startFetchWait = System.currentTimeMillis()
val result = results.take()
val stopFetchWait = System.currentTimeMillis()
_fetchWaitTime += (stopFetchWait - startFetchWait)
readMetrics.fetchWaitTime += (stopFetchWait - startFetchWait)
if (! result.failed) bytesInFlight -= result.size
while (!fetchRequests.isEmpty &&
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
@@ -269,8 +258,9 @@
class NettyBlockFetcherIterator(
blockManager: BlockManager,
blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer)
extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) {
serializer: Serializer,
readMetrics: ShuffleReadMetrics)
extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer, readMetrics) {

import blockManager._

11 changes: 7 additions & 4 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props}
import sun.nio.ch.DirectBuffer

import org.apache.spark._
import org.apache.spark.executor.{DataReadMethod, InputMetrics, ShuffleWriteMetrics}
import org.apache.spark.executor._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
import org.apache.spark.serializer.Serializer
@@ -539,12 +539,15 @@ private[spark] class BlockManager(
*/
def getMultiple(
blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer): BlockFetcherIterator = {
serializer: Serializer,
readMetrics: ShuffleReadMetrics): BlockFetcherIterator = {
val iter =
if (conf.getBoolean("spark.shuffle.use.netty", false)) {
new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer)
new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer,
readMetrics)
} else {
new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer)
new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer,
readMetrics)
}
iter.initialize()
iter
5 changes: 2 additions & 3 deletions core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -560,9 +560,8 @@ private[spark] object JsonProtocol {
metrics.resultSerializationTime = (json \ "Result Serialization Time").extract[Long]
metrics.memoryBytesSpilled = (json \ "Memory Bytes Spilled").extract[Long]
metrics.diskBytesSpilled = (json \ "Disk Bytes Spilled").extract[Long]
Utils.jsonOption(json \ "Shuffle Read Metrics").map { shuffleReadMetrics =>
metrics.updateShuffleReadMetrics(shuffleReadMetricsFromJson(shuffleReadMetrics))
}
metrics.setShuffleReadMetrics(
Utils.jsonOption(json \ "Shuffle Read Metrics").map(shuffleReadMetricsFromJson))
metrics.shuffleWriteMetrics =
Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson)
metrics.inputMetrics =
@@ -33,6 +33,7 @@ import org.mockito.invocation.InvocationOnMock

import org.apache.spark.storage.BlockFetcherIterator._
import org.apache.spark.network.{ConnectionManager, Message}
import org.apache.spark.executor.ShuffleReadMetrics

class BlockFetcherIteratorSuite extends FunSuite with Matchers {

@@ -70,8 +71,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
(bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
)

val iterator = new BasicBlockFetcherIterator(blockManager,
blocksByAddress, null)
val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null,
new ShuffleReadMetrics())

iterator.initialize()

@@ -121,8 +122,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
(bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
)

val iterator = new BasicBlockFetcherIterator(blockManager,
blocksByAddress, null)
val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null,
new ShuffleReadMetrics())

iterator.initialize()

@@ -165,7 +166,7 @@
)

val iterator = new BasicBlockFetcherIterator(blockManager,
blocksByAddress, null)
blocksByAddress, null, new ShuffleReadMetrics())

iterator.initialize()
iterator.foreach{
@@ -219,7 +220,7 @@
)

val iterator = new BasicBlockFetcherIterator(blockManager,
blocksByAddress, null)
blocksByAddress, null, new ShuffleReadMetrics())
iterator.initialize()
iterator.foreach{
case (_, r) => {
@@ -65,7 +65,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc

// finish this task, should get updated shuffleRead
shuffleReadMetrics.remoteBytesRead = 1000
taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics)
taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics))
var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false)
taskInfo.finishTime = 1
var task = new ShuffleMapTask(0)
@@ -142,7 +142,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
val taskMetrics = new TaskMetrics()
val shuffleReadMetrics = new ShuffleReadMetrics()
val shuffleWriteMetrics = new ShuffleWriteMetrics()
taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics)
taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics))
taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics)
shuffleReadMetrics.remoteBytesRead = base + 1
shuffleReadMetrics.remoteBlocksFetched = base + 2
@@ -539,7 +539,7 @@ class JsonProtocolSuite extends FunSuite {
sr.localBlocksFetched = e
sr.fetchWaitTime = a + d
sr.remoteBlocksFetched = f
t.updateShuffleReadMetrics(sr)
t.setShuffleReadMetrics(Some(sr))
}
sw.shuffleBytesWritten = a + b + c
sw.shuffleWriteTime = b + c + d
