Update shuffle read metrics in ShuffleReader instead of BlockStoreShuffleFetcher.

This commit also includes Scala style cleanup.
massie committed Jun 9, 2015
1 parent 7e8e0fe · commit f93841e
Showing 4 changed files with 21 additions and 26 deletions.
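As the commit message says, the task-level shuffle read metrics update now happens in the shuffle reader rather than in BlockStoreShuffleFetcher: each record is counted as it is read, and the accumulated counts are flushed once the iterator is exhausted. A minimal self-contained sketch of that count-then-flush-on-completion pattern (MetricsIterator and onComplete are illustrative names, not Spark's API):

class MetricsIterator[A](delegate: Iterator[A], onComplete: Long => Unit) extends Iterator[A] {
  // Counts records handed out and fires the callback exactly once at exhaustion,
  // mirroring incRecordsRead(1) per record plus a final metrics flush.
  private var recordsRead = 0L
  private var completed = false

  override def hasNext: Boolean = {
    val more = delegate.hasNext
    if (!more && !completed) {
      completed = true
      onComplete(recordsRead) // flush the accumulated count
    }
    more
  }

  override def next(): A = {
    val record = delegate.next()
    recordsRead += 1 // per-record update
    record
  }
}

// Usage: prints "records read: 3" after the last element is consumed.
val records = new MetricsIterator(Iterator("a", "b", "c"), n => println(s"records read: $n"))
records.foreach(_ => ())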
BlockStoreShuffleFetcher.scala
@@ -25,7 +25,6 @@ import scala.util.{Failure, Success, Try}
import org.apache.spark._
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
import org.apache.spark.util.CompletionIterator

private[shuffle] object BlockStoreShuffleFetcher extends Logging {
def fetchBlockStreams(
@@ -80,10 +79,6 @@ private[shuffle] object BlockStoreShuffleFetcher extends Logging {
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)

val itr = blockFetcherItr.map(unpackBlock)

CompletionIterator[(BlockId, InputStream), Iterator[(BlockId, InputStream)]](itr, {
context.taskMetrics().updateShuffleReadMetrics()
})
blockFetcherItr.map(unpackBlock)
}
}
HashShuffleReader.scala
@@ -17,11 +17,11 @@

package org.apache.spark.shuffle.hash

import org.apache.spark.{SparkEnv, TaskContext, InterruptibleIterator}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
import org.apache.spark.shuffle.{ShuffleReader, BaseShuffleHandle}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext}

private[spark] class HashShuffleReader[K, C](
handle: BaseShuffleHandle[K, _, C],
@@ -51,24 +51,22 @@ private[spark] class HashShuffleReader[K, C](

// Create a key/value iterator for each stream
val recordIterator = wrappedStreams.flatMap { wrappedStream =>
val kvIter = serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
CompletionIterator[(Any, Any), Iterator[(Any, Any)]](kvIter, {
// Close the stream once all the records have been read from it to free underlying
// ManagedBuffer as soon as possible. Note that in case of task failure, the task's
// TaskCompletionListener will make sure this is released.
wrappedStream.close()
})
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}

val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
// Update read metrics for each record materialized
val iter = new InterruptibleIterator[(Any, Any)](context, recordIterator) {
val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
override def next(): (Any, Any) = {
readMetrics.incRecordsRead(1)
delegate.next()
}
val metricIter = new InterruptibleIterator[(Any, Any)](context, recordIterator) {
override def next(): (Any, Any) = {
readMetrics.incRecordsRead(1)
delegate.next()
}
}

val iter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](metricIter, {
context.taskMetrics().updateShuffleReadMetrics()
})

val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
// We are reading values that are already combined
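One pattern visible in the hunk above is closing a fetched stream as soon as its record iterator has been fully read, so the underlying buffer is freed without waiting for the end of the task. A rough self-contained sketch of that close-on-exhaustion idea (recordsFrom and the byte "records" are illustrative assumptions, not Spark's code):

import java.io.{ByteArrayInputStream, InputStream}

def recordsFrom(stream: InputStream): Iterator[Int] = {
  // Read "records" (here, single bytes) until EOF, then close the stream.
  val underlying = Iterator.continually(stream.read()).takeWhile(_ != -1)
  new Iterator[Int] {
    override def hasNext: Boolean = {
      val more = underlying.hasNext
      if (!more) stream.close() // free the stream the moment it is exhausted
      more
    }
    override def next(): Int = underlying.next()
  }
}

// Each stream is closed as soon as its own records run out.
val streams = Iterator(new ByteArrayInputStream(Array[Byte](1, 2)), new ByteArrayInputStream(Array[Byte](3)))
println(streams.flatMap(recordsFrom).toList) // List(1, 2, 3)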
ShuffleBlockFetcherIterator.scala
@@ -23,10 +23,10 @@ import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
import scala.util.{Failure, Try}

import org.apache.spark.{Logging, TaskContext}
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
import org.apache.spark.util.Utils
import org.apache.spark.{Logging, TaskContext}

/**
* An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -306,16 +306,18 @@ final class ShuffleBlockFetcherIterator(
// not exist, SPARK-4085). In that case, we should propagate the right exception so
// the scheduler gets a FetchFailedException.
Try(buf.createInputStream()).map { inputStream =>
new WrappedInputStream(inputStream, this)
new BufferReleasingInputStream(inputStream, this)
}
}

(result.blockId, iteratorTry)
}
}

// Helper class that ensures a ManagerBuffer is released upon InputStream.close()
private class WrappedInputStream(delegate: InputStream, iterator: ShuffleBlockFetcherIterator)
/** Helper class that ensures a ManagerBuffer is released upon InputStream.close() */
private class BufferReleasingInputStream(
delegate: InputStream,
iterator: ShuffleBlockFetcherIterator)
extends InputStream {
private var closed = false

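The wrapper renamed above exists so that the buffer behind a fetched block is released when the stream is closed. A minimal self-contained sketch of that wrap-and-release-on-close idea (ReleasingInputStream and the release callback are assumed names, not Spark's class):

import java.io.{ByteArrayInputStream, InputStream}

class ReleasingInputStream(delegate: InputStream, release: () => Unit) extends InputStream {
  // Delegates reads and runs the release callback exactly once, on the first close().
  private var closed = false

  override def read(): Int = delegate.read()

  override def close(): Unit = {
    if (!closed) {
      closed = true
      delegate.close()
      release() // e.g. hand the buffer back to the fetcher
    }
  }
}

// Usage: the callback fires on the first close() only.
val in = new ReleasingInputStream(new ByteArrayInputStream(Array[Byte](42)), () => println("released"))
in.close() // prints "released"
in.close() // no further output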
ShuffleBlockFetcherIteratorSuite.scala
@@ -115,7 +115,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {

// Make sure we release buffers when a wrapped input stream is closed.
val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
val wrappedInputStream = new WrappedInputStream(mock(classOf[InputStream]), iterator)
val wrappedInputStream = new BufferReleasingInputStream(mock(classOf[InputStream]), iterator)
verify(mockBuf, times(0)).release()
wrappedInputStream.close()
verify(mockBuf, times(1)).release()
