diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index c2b9c660ddaec..eac1f2326a29d 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -374,6 +374,7 @@ private[spark] class Executor( for (taskRunner <- runningTasks.values()) { if (!taskRunner.attemptedTask.isEmpty) { Option(taskRunner.task).flatMap(_.metrics).foreach { metrics => + metrics.updateShuffleReadMetrics tasksMetrics += ((taskRunner.taskId, metrics)) } } diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 11a6e10243211..99a88c13456df 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,6 +17,8 @@ package org.apache.spark.executor +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.storage.{BlockId, BlockStatus} @@ -81,12 +83,27 @@ class TaskMetrics extends Serializable { var inputMetrics: Option[InputMetrics] = None /** - * If this task reads from shuffle output, metrics on getting shuffle data will be collected here + * If this task reads from shuffle output, metrics on getting shuffle data will be collected here. + * This includes read metrics aggregated over all the task's shuffle dependencies. */ private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None def shuffleReadMetrics = _shuffleReadMetrics + /** + * This should only be used when recreating TaskMetrics, not when updating read metrics in + * executors. + */ + private[spark] def setShuffleReadMetrics(shuffleReadMetrics: Option[ShuffleReadMetrics]) { + _shuffleReadMetrics = shuffleReadMetrics + } + + /** + * ShuffleReadMetrics per dependency for collecting independently while task is in progress. + */ + @transient private lazy val depsShuffleReadMetrics: ArrayBuffer[ShuffleReadMetrics] = + new ArrayBuffer[ShuffleReadMetrics]() + /** * If this task writes to shuffle output, metrics on the written shuffle data will be collected * here @@ -98,19 +115,31 @@ class TaskMetrics extends Serializable { */ var updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = None - /** Adds the given ShuffleReadMetrics to any existing shuffle metrics for this task. */ - def updateShuffleReadMetrics(newMetrics: ShuffleReadMetrics) = synchronized { - _shuffleReadMetrics match { - case Some(existingMetrics) => - existingMetrics.shuffleFinishTime = math.max( - existingMetrics.shuffleFinishTime, newMetrics.shuffleFinishTime) - existingMetrics.fetchWaitTime += newMetrics.fetchWaitTime - existingMetrics.localBlocksFetched += newMetrics.localBlocksFetched - existingMetrics.remoteBlocksFetched += newMetrics.remoteBlocksFetched - existingMetrics.remoteBytesRead += newMetrics.remoteBytesRead - case None => - _shuffleReadMetrics = Some(newMetrics) + /** + * A task may have multiple shuffle readers for multiple dependencies. To avoid synchronization + * issues from readers in different threads, in-progress tasks use a ShuffleReadMetrics for each + * dependency, and merge these metrics before reporting them to the driver. This method returns + * a ShuffleReadMetrics for a dependency and registers it for merging later. 
+ */ + private [spark] def createShuffleReadMetricsForDependency(): ShuffleReadMetrics = synchronized { + val readMetrics = new ShuffleReadMetrics() + depsShuffleReadMetrics += readMetrics + readMetrics + } + + /** + * Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics. + */ + private[spark] def updateShuffleReadMetrics() = synchronized { + val merged = new ShuffleReadMetrics() + for (depMetrics <- depsShuffleReadMetrics) { + merged.fetchWaitTime += depMetrics.fetchWaitTime + merged.localBlocksFetched += depMetrics.localBlocksFetched + merged.remoteBlocksFetched += depMetrics.remoteBlocksFetched + merged.remoteBytesRead += depMetrics.remoteBytesRead + merged.shuffleFinishTime = math.max(merged.shuffleFinishTime, depMetrics.shuffleFinishTime) } + _shuffleReadMetrics = Some(merged) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 99788828981c7..12b475658e29d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -32,7 +32,8 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { shuffleId: Int, reduceId: Int, context: TaskContext, - serializer: Serializer) + serializer: Serializer, + shuffleMetrics: ShuffleReadMetrics) : Iterator[T] = { logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) @@ -73,17 +74,11 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { } } - val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer) + val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer, shuffleMetrics) val itr = blockFetcherItr.flatMap(unpackBlock) val completionIter = CompletionIterator[T, Iterator[T]](itr, { - val shuffleMetrics = new ShuffleReadMetrics - shuffleMetrics.shuffleFinishTime = System.currentTimeMillis - shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime - shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead - shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks - shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks - context.taskMetrics.updateShuffleReadMetrics(shuffleMetrics) + context.taskMetrics.updateShuffleReadMetrics() }) new InterruptibleIterator[T](context, completionIter) diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 88a5f1e5ddf58..7bed97a63f0f6 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -36,8 +36,10 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { + val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() val ser = Serializer.getSerializer(dep.serializer) - val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser) + val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser, + readMetrics) val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala 
b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 938af6f5b923a..5f44f5f3197fd 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -27,6 +27,7 @@ import scala.util.{Failure, Success} import io.netty.buffer.ByteBuf import org.apache.spark.{Logging, SparkException} +import org.apache.spark.executor.ShuffleReadMetrics import org.apache.spark.network.BufferMessage import org.apache.spark.network.ConnectionManagerId import org.apache.spark.network.netty.ShuffleCopier @@ -47,10 +48,6 @@ import org.apache.spark.util.Utils private[storage] trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging { def initialize() - def numLocalBlocks: Int - def numRemoteBlocks: Int - def fetchWaitTime: Long - def remoteBytesRead: Long } @@ -72,14 +69,12 @@ object BlockFetcherIterator { class BasicBlockFetcherIterator( private val blockManager: BlockManager, val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer) + serializer: Serializer, + readMetrics: ShuffleReadMetrics) extends BlockFetcherIterator { import blockManager._ - private var _remoteBytesRead = 0L - private var _fetchWaitTime = 0L - if (blocksByAddress == null) { throw new IllegalArgumentException("BlocksByAddress is null") } @@ -89,13 +84,9 @@ object BlockFetcherIterator { protected var startTime = System.currentTimeMillis - // This represents the number of local blocks, also counting zero-sized blocks - private var numLocal = 0 // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks protected val localBlocksToFetch = new ArrayBuffer[BlockId]() - // This represents the number of remote blocks, also counting zero-sized blocks - private var numRemote = 0 // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks protected val remoteBlocksToFetch = new HashSet[BlockId]() @@ -132,7 +123,10 @@ object BlockFetcherIterator { val networkSize = blockMessage.getData.limit() results.put(new FetchResult(blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData, serializer))) - _remoteBytesRead += networkSize + // TODO: NettyBlockFetcherIterator has some race conditions where multiple threads can + // be incrementing bytes read at the same time (SPARK-2625). + readMetrics.remoteBytesRead += networkSize + readMetrics.remoteBlocksFetched += 1 logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } } @@ -155,14 +149,14 @@ object BlockFetcherIterator { // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. 
val remoteRequests = new ArrayBuffer[FetchRequest] + var totalBlocks = 0 for ((address, blockInfos) <- blocksByAddress) { + totalBlocks += blockInfos.size if (address == blockManagerId) { - numLocal = blockInfos.size // Filter out zero-sized blocks localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1) _numBlocksToFetch += localBlocksToFetch.size } else { - numRemote += blockInfos.size val iterator = blockInfos.iterator var curRequestSize = 0L var curBlocks = new ArrayBuffer[(BlockId, Long)] @@ -192,7 +186,7 @@ object BlockFetcherIterator { } } logInfo("Getting " + _numBlocksToFetch + " non-empty blocks out of " + - (numLocal + numRemote) + " blocks") + totalBlocks + " blocks") remoteRequests } @@ -205,6 +199,7 @@ object BlockFetcherIterator { // getLocalFromDisk never return None but throws BlockException val iter = getLocalFromDisk(id, serializer).get // Pass 0 as size since it's not in flight + readMetrics.localBlocksFetched += 1 results.put(new FetchResult(id, 0, () => iter)) logDebug("Got local block " + id) } catch { @@ -238,12 +233,6 @@ object BlockFetcherIterator { logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") } - override def numLocalBlocks: Int = numLocal - override def numRemoteBlocks: Int = numRemote - override def fetchWaitTime: Long = _fetchWaitTime - override def remoteBytesRead: Long = _remoteBytesRead - - // Implementing the Iterator methods with an iterator that reads fetched blocks off the queue // as they arrive. @volatile protected var resultsGotten = 0 @@ -255,7 +244,7 @@ object BlockFetcherIterator { val startFetchWait = System.currentTimeMillis() val result = results.take() val stopFetchWait = System.currentTimeMillis() - _fetchWaitTime += (stopFetchWait - startFetchWait) + readMetrics.fetchWaitTime += (stopFetchWait - startFetchWait) if (! 
result.failed) bytesInFlight -= result.size while (!fetchRequests.isEmpty && (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { @@ -269,8 +258,9 @@ object BlockFetcherIterator { class NettyBlockFetcherIterator( blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer) - extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) { + serializer: Serializer, + readMetrics: ShuffleReadMetrics) + extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer, readMetrics) { import blockManager._ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 8d21b02b747ff..e8bbd298c631a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props} import sun.nio.ch.DirectBuffer import org.apache.spark._ -import org.apache.spark.executor.{DataReadMethod, InputMetrics, ShuffleWriteMetrics} +import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer @@ -539,12 +539,15 @@ private[spark] class BlockManager( */ def getMultiple( blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer): BlockFetcherIterator = { + serializer: Serializer, + readMetrics: ShuffleReadMetrics): BlockFetcherIterator = { val iter = if (conf.getBoolean("spark.shuffle.use.netty", false)) { - new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer) + new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer, + readMetrics) } else { - new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer) + new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer, + readMetrics) } iter.initialize() iter diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index b112b359368cd..6f8eb1ee12634 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -560,9 +560,8 @@ private[spark] object JsonProtocol { metrics.resultSerializationTime = (json \ "Result Serialization Time").extract[Long] metrics.memoryBytesSpilled = (json \ "Memory Bytes Spilled").extract[Long] metrics.diskBytesSpilled = (json \ "Disk Bytes Spilled").extract[Long] - Utils.jsonOption(json \ "Shuffle Read Metrics").map { shuffleReadMetrics => - metrics.updateShuffleReadMetrics(shuffleReadMetricsFromJson(shuffleReadMetrics)) - } + metrics.setShuffleReadMetrics( + Utils.jsonOption(json \ "Shuffle Read Metrics").map(shuffleReadMetricsFromJson)) metrics.shuffleWriteMetrics = Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson) metrics.inputMetrics = diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala index 1538995a6b404..bcbfe8baf36ad 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala @@ -33,6 +33,7 @@ import org.mockito.invocation.InvocationOnMock import 
org.apache.spark.storage.BlockFetcherIterator._ import org.apache.spark.network.{ConnectionManager, Message} +import org.apache.spark.executor.ShuffleReadMetrics class BlockFetcherIteratorSuite extends FunSuite with Matchers { @@ -70,8 +71,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) ) - val iterator = new BasicBlockFetcherIterator(blockManager, - blocksByAddress, null) + val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null, + new ShuffleReadMetrics()) iterator.initialize() @@ -121,8 +122,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) ) - val iterator = new BasicBlockFetcherIterator(blockManager, - blocksByAddress, null) + val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null, + new ShuffleReadMetrics()) iterator.initialize() @@ -165,7 +166,7 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { ) val iterator = new BasicBlockFetcherIterator(blockManager, - blocksByAddress, null) + blocksByAddress, null, new ShuffleReadMetrics()) iterator.initialize() iterator.foreach{ @@ -219,7 +220,7 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers { ) val iterator = new BasicBlockFetcherIterator(blockManager, - blocksByAddress, null) + blocksByAddress, null, new ShuffleReadMetrics()) iterator.initialize() iterator.foreach{ case (_, r) => { diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index cb8252515238e..f5ba31c309277 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -65,7 +65,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc // finish this task, should get updated shuffleRead shuffleReadMetrics.remoteBytesRead = 1000 - taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics) + taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics)) var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 var task = new ShuffleMapTask(0) @@ -142,7 +142,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val taskMetrics = new TaskMetrics() val shuffleReadMetrics = new ShuffleReadMetrics() val shuffleWriteMetrics = new ShuffleWriteMetrics() - taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics) + taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics)) taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics) shuffleReadMetrics.remoteBytesRead = base + 1 shuffleReadMetrics.remoteBlocksFetched = base + 2 diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 2002a817d9168..97ffb07662482 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -539,7 +539,7 @@ class JsonProtocolSuite extends FunSuite { sr.localBlocksFetched = e sr.fetchWaitTime = a + d sr.remoteBlocksFetched = f - t.updateShuffleReadMetrics(sr) + t.setShuffleReadMetrics(Some(sr)) } sw.shuffleBytesWritten = a + b + c sw.shuffleWriteTime = b + c + d
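The core of this change is the per-dependency ShuffleReadMetrics buffer in TaskMetrics: each shuffle reader obtains its own mutable metrics object via createShuffleReadMetricsForDependency(), and updateShuffleReadMetrics() merges them into the single aggregate that is reported to the driver. Below is a minimal, self-contained sketch of that pattern; ShuffleReadMetricsSketch, TaskMetricsSketch, and PerDependencyMetricsDemo are illustrative stand-ins for the real Spark classes, not the actual implementation.

```scala
import scala.collection.mutable.ArrayBuffer

class ShuffleReadMetricsSketch {
  var shuffleFinishTime: Long = 0L
  var fetchWaitTime: Long = 0L
  var localBlocksFetched: Int = 0
  var remoteBlocksFetched: Int = 0
  var remoteBytesRead: Long = 0L
}

class TaskMetricsSketch {
  // One metrics object per shuffle dependency. Each reader mutates only its own
  // instance while fetching, so readers in different threads do not contend on
  // a shared metrics object on the hot path.
  @transient private lazy val depsShuffleReadMetrics =
    new ArrayBuffer[ShuffleReadMetricsSketch]()

  private var _shuffleReadMetrics: Option[ShuffleReadMetricsSketch] = None
  def shuffleReadMetrics: Option[ShuffleReadMetricsSketch] = _shuffleReadMetrics

  // Called once per shuffle dependency when its reader is created.
  def createShuffleReadMetricsForDependency(): ShuffleReadMetricsSketch = synchronized {
    val readMetrics = new ShuffleReadMetricsSketch
    depsShuffleReadMetrics += readMetrics
    readMetrics
  }

  // Folds the per-dependency metrics into one aggregate before they are reported
  // to the driver (from the executor heartbeat or when a fetch completes).
  def updateShuffleReadMetrics(): Unit = synchronized {
    val merged = new ShuffleReadMetricsSketch
    for (dep <- depsShuffleReadMetrics) {
      merged.shuffleFinishTime = math.max(merged.shuffleFinishTime, dep.shuffleFinishTime)
      merged.fetchWaitTime += dep.fetchWaitTime
      merged.localBlocksFetched += dep.localBlocksFetched
      merged.remoteBlocksFetched += dep.remoteBlocksFetched
      merged.remoteBytesRead += dep.remoteBytesRead
    }
    _shuffleReadMetrics = Some(merged)
  }
}

object PerDependencyMetricsDemo extends App {
  val tm = new TaskMetricsSketch
  // Two shuffle dependencies (e.g. the two sides of a join), each read independently.
  val dep1 = tm.createShuffleReadMetricsForDependency()
  val dep2 = tm.createShuffleReadMetricsForDependency()
  dep1.remoteBytesRead += 1000
  dep1.remoteBlocksFetched += 2
  dep2.remoteBytesRead += 500
  dep2.localBlocksFetched += 1
  tm.updateShuffleReadMetrics()
  tm.shuffleReadMetrics.foreach { m =>
    println(s"remoteBytesRead=${m.remoteBytesRead}, remoteBlocksFetched=${m.remoteBlocksFetched}, " +
      s"localBlocksFetched=${m.localBlocksFetched}")
  }
  // remoteBytesRead=1500, remoteBlocksFetched=2, localBlocksFetched=1
}
```

As the TODO added in BlockFetcherIterator notes (SPARK-2625), the Netty fetch path can still have multiple threads incrementing a single dependency's counters, so the per-dependency split narrows but does not fully eliminate the races.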
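A related consequence of making the per-dependency buffer @transient shows up in the JsonProtocol and test changes: a TaskMetrics recreated from the event log has an empty buffer, so calling updateShuffleReadMetrics() there would overwrite the restored values with zeros, which is why those call sites move to the new setShuffleReadMetrics. The sketch below demonstrates the difference; Metrics, RecreatedTaskMetrics, and SetterVsUpdateDemo are hypothetical stand-in names, not the real Spark types.

```scala
object SetterVsUpdateDemo extends App {
  class Metrics { var remoteBytesRead: Long = 0L }

  class RecreatedTaskMetrics {
    private var _shuffleReadMetrics: Option[Metrics] = None
    def shuffleReadMetrics: Option[Metrics] = _shuffleReadMetrics

    // Empty after deserialization: the per-dependency buffer is @transient and is never
    // shipped to the driver or written to the event log.
    private val depsShuffleReadMetrics = Seq.empty[Metrics]

    // Intended only for recreating TaskMetrics, not for updating metrics on executors.
    def setShuffleReadMetrics(m: Option[Metrics]): Unit = { _shuffleReadMetrics = m }

    def updateShuffleReadMetrics(): Unit = {
      val merged = new Metrics
      depsShuffleReadMetrics.foreach(d => merged.remoteBytesRead += d.remoteBytesRead)
      _shuffleReadMetrics = Some(merged)
    }
  }

  val restored = new RecreatedTaskMetrics
  val fromJson = new Metrics
  fromJson.remoteBytesRead = 4096

  restored.setShuffleReadMetrics(Some(fromJson))              // correct: install the aggregate directly
  println(restored.shuffleReadMetrics.map(_.remoteBytesRead)) // Some(4096)

  restored.updateShuffleReadMetrics()                         // wrong here: merges an empty buffer
  println(restored.shuffleReadMetrics.map(_.remoteBytesRead)) // Some(0)
}
```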