Commit

Consolidate metric code. Make it clear why InterruptibleIterator is needed.

There is also some Scala style cleanup in this commit.
massie committed Jun 10, 2015
1 parent 5c30405 commit 4abb855
Showing 2 changed files with 24 additions and 32 deletions.
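
Before the diff itself, here is a minimal, self-contained sketch of the iterator layering this commit settles on: the raw record iterator is wrapped once to count records and flush read metrics on completion, and the task-cancellation wrapper (the reason an InterruptibleIterator is needed) goes on top of that. The CompletionIterator and InterruptibleIterator below are simplified stand-ins written for illustration, not Spark's org.apache.spark.util classes.

// Simplified stand-ins, for illustration only -- NOT the Spark implementations.
object IteratorLayeringSketch {

  // Runs `completion` exactly once, after the wrapped iterator is exhausted.
  class CompletionIterator[A](sub: Iterator[A], completion: => Unit) extends Iterator[A] {
    private var completed = false
    override def hasNext: Boolean = {
      val more = sub.hasNext
      if (!more && !completed) { completed = true; completion }
      more
    }
    override def next(): A = sub.next()
  }

  // Checks a cancellation flag before handing out elements; this is the property
  // task cancellation relies on, and why the wrapper is applied last.
  class InterruptibleIterator[A](isCancelled: () => Boolean, delegate: Iterator[A]) extends Iterator[A] {
    override def hasNext: Boolean = {
      if (isCancelled()) throw new InterruptedException("task cancelled")
      delegate.hasNext
    }
    override def next(): A = delegate.next()
  }

  def main(args: Array[String]): Unit = {
    var recordsRead = 0L
    val cancelled = new java.util.concurrent.atomic.AtomicBoolean(false)

    // Stand-in for the deserialized key/value record stream.
    val recordIter = Iterator("a" -> 1, "b" -> 2, "c" -> 3)

    // Layer 1: count each materialized record, flush metrics when the stream is exhausted.
    val metricIter = new CompletionIterator[(String, Int)](
      recordIter.map { record => recordsRead += 1; record },
      println(s"records read: $recordsRead"))

    // Layer 2: cancellation support goes on top, mirroring interruptibleIter in the diff below.
    val interruptibleIter = new InterruptibleIterator[(String, Int)](() => cancelled.get(), metricIter)

    interruptibleIter.foreach(println)
  }
}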
core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala

@@ -50,42 +50,39 @@ private[spark] class HashShuffleReader[K, C](
     val serializerInstance = ser.newInstance()
 
     // Create a key/value iterator for each stream
-    val recordIterator = wrappedStreams.flatMap { wrappedStream =>
+    val recordIter = wrappedStreams.flatMap { wrappedStream =>
       serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
     }
 
-    // Update read metrics for each record materialized
+    // Update the context task metrics for each record read.
     val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
-    val metricIter = new InterruptibleIterator[(Any, Any)](context, recordIterator) {
-      override def next(): (Any, Any) = {
+    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
+      recordIter.map(record => {
         readMetrics.incRecordsRead(1)
-        delegate.next()
-      }
-    }
+        record
+      }),
+      context.taskMetrics().updateShuffleReadMetrics())
 
-    val iter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](metricIter, {
-      context.taskMetrics().updateShuffleReadMetrics()
-    })
+    // An interruptible iterator must be used here in order to support task cancellation
+    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
 
     val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
       if (dep.mapSideCombine) {
         // We are reading values that are already combined
-        val combinedKeyValuesIterator = iter.asInstanceOf[Iterator[(K, C)]]
-        new InterruptibleIterator(context,
-          dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context))
+        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
+        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
       } else {
         // We don't know the value type, but also don't care -- the dependency *should*
         // have made sure its compatible w/ this aggregator, which will convert the value
         // type to the combined type C
-        val keyValuesIterator = iter.asInstanceOf[Iterator[(K, Nothing)]]
-        new InterruptibleIterator(context,
-          dep.aggregator.get.combineValuesByKey(keyValuesIterator, context))
+        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
+        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
       }
     } else {
       require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
 
       // Convert the Product2s to pairs since this is what downstream RDDs currently expect
-      iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
+      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
     }
 
     // Sort the output if there is a sort ordering defined.
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala

@@ -28,10 +28,10 @@ import org.mockito.Mockito._
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
 
+import org.apache.spark.{SparkFunSuite, TaskContextImpl}
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.shuffle.BlockFetchingListener
-import org.apache.spark.{SparkFunSuite, TaskContextImpl}
 
 
 class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
@@ -61,11 +61,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
   // Create a mock managed buffer for testing
   def createMockManagedBuffer(): ManagedBuffer = {
     val mockManagedBuffer = mock(classOf[ManagedBuffer])
-    when(mockManagedBuffer.createInputStream()).thenAnswer(new Answer[InputStream] {
-      override def answer(invocation: InvocationOnMock): InputStream = {
-        mock(classOf[InputStream])
-      }
-    })
+    when(mockManagedBuffer.createInputStream()).thenReturn(mock(classOf[InputStream]))
     mockManagedBuffer
   }

@@ -76,19 +72,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {

     // Make sure blockManager.getBlockData would return the blocks
     val localBlocks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]))
+      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer())
     localBlocks.foreach { case (blockId, buf) =>
       doReturn(buf).when(blockManager).getBlockData(meq(blockId))
     }
 
     // Make sure remote blocks would return
     val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
     val remoteBlocks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockId(0, 3, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 4, 0) -> mock(classOf[ManagedBuffer])
-    )
+      ShuffleBlockId(0, 3, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 4, 0) -> createMockManagedBuffer())
 
     val transfer = createMockTransfer(remoteBlocks)

@@ -109,13 +104,13 @@

     for (i <- 0 until 5) {
       assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements")
-      val (blockId, subIterator) = iterator.next()
-      assert(subIterator.isSuccess,
+      val (blockId, inputStream) = iterator.next()
+      assert(inputStream.isSuccess,
         s"iterator should have 5 elements defined but actually has $i elements")
 
       // Make sure we release buffers when a wrapped input stream is closed.
       val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
-      val wrappedInputStream = new BufferReleasingInputStream(mock(classOf[InputStream]), iterator)
+      val wrappedInputStream = new BufferReleasingInputStream(inputStream.get, iterator)
       verify(mockBuf, times(0)).release()
       wrappedInputStream.close()
       verify(mockBuf, times(1)).release()
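The createMockManagedBuffer change above swaps a Mockito Answer for a plain thenReturn. A minimal sketch of the same simplification, using a hypothetical StreamSource trait in place of Spark's ManagedBuffer (the names here are illustrative, not part of the commit): thenAnswer builds a fresh value on every invocation, while thenReturn hands back one pre-built instance, which is all the test needs.

import java.io.InputStream

import org.mockito.Mockito.{mock, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer

// Hypothetical stand-in for ManagedBuffer; only createInputStream() matters here.
trait StreamSource {
  def createInputStream(): InputStream
}

object MockitoStubSketch {
  def main(args: Array[String]): Unit = {
    // Verbose form: an Answer that produces the return value on every invocation.
    val answered = mock(classOf[StreamSource])
    when(answered.createInputStream()).thenAnswer(new Answer[InputStream] {
      override def answer(invocation: InvocationOnMock): InputStream = mock(classOf[InputStream])
    })

    // Simpler form used by the updated test: one pre-built mock returned on every call.
    val returned = mock(classOf[StreamSource])
    when(returned.createInputStream()).thenReturn(mock(classOf[InputStream]))

    println(answered.createInputStream() != null)
    println(returned.createInputStream() != null)
  }
}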
