diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index aefb2f5685537..0635b98742096 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -26,8 +26,7 @@ import org.apache.spark._
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
 
-private[hash] class BlockStoreShuffleFetcher extends Logging {
-
+private[hash] object BlockStoreShuffleFetcher extends Logging {
   def fetchBlockStreams(
       shuffleId: Int,
       reduceId: Int,
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index b868f32f5cce1..ca6eddf8d5c12 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.shuffle.hash
 
-import org.apache.spark.storage.BlockManager
 import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
@@ -28,19 +27,18 @@ private[spark] class HashShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
     startPartition: Int,
     endPartition: Int,
-    context: TaskContext,
-    blockManager: BlockManager = SparkEnv.get.blockManager,
-    blockStoreShuffleFetcher: BlockStoreShuffleFetcher = new BlockStoreShuffleFetcher)
+    context: TaskContext)
   extends ShuffleReader[K, C] {
 
   require(endPartition == startPartition + 1,
     "Hash shuffle currently only supports fetching one partition")
 
   private val dep = handle.dependency
+  private val blockManager = SparkEnv.get.blockManager
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
-    val blockStreams = blockStoreShuffleFetcher.fetchBlockStreams(
+    val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
       handle.shuffleId, startPartition, context)
 
     // Wrap the streams for compression based on configuration
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
index 53b2b89a5e641..491dc3659e184 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
@@ -17,22 +17,16 @@
 
 package org.apache.spark.shuffle.hash
 
-import java.io._
-import java.nio.ByteBuffer
+import java.io.{File, FileWriter}
 
 import scala.language.reflectiveCalls
 
-import org.mockito.Matchers.any
-import org.mockito.Mockito._
-import org.mockito.invocation.InvocationOnMock
-import org.mockito.stubbing.Answer
-
-import org.apache.spark._
-import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics, ShuffleWriteMetrics}
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite}
+import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.serializer._
-import org.apache.spark.shuffle.{BaseShuffleHandle, FileShuffleBlockResolver}
-import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockId, FileSegment}
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.shuffle.FileShuffleBlockResolver
+import org.apache.spark.storage.{ShuffleBlockId, FileSegment}
 
 class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext {
   private val testConf = new SparkConf(false)
@@ -113,100 +107,4 @@ class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext {
     for (i <- 0 until numBytes) writer.write(i)
     writer.close()
   }
-
-  test("HashShuffleReader.read() releases resources and tracks metrics") {
-    val shuffleId = 1
-    val numMaps = 2
-    val numKeyValuePairs = 10
-
-    val mockContext = mock(classOf[TaskContext])
-
-    val mockTaskMetrics = mock(classOf[TaskMetrics])
-    val mockReadMetrics = mock(classOf[ShuffleReadMetrics])
-    when(mockTaskMetrics.createShuffleReadMetricsForDependency()).thenReturn(mockReadMetrics)
-    when(mockContext.taskMetrics()).thenReturn(mockTaskMetrics)
-
-    val mockShuffleFetcher = mock(classOf[BlockStoreShuffleFetcher])
-
-    val mockDep = mock(classOf[ShuffleDependency[_, _, _]])
-    when(mockDep.keyOrdering).thenReturn(None)
-    when(mockDep.aggregator).thenReturn(None)
-    when(mockDep.serializer).thenReturn(Some(new Serializer {
-      override def newInstance(): SerializerInstance = new SerializerInstance {
-
-        override def deserializeStream(s: InputStream): DeserializationStream =
-          new DeserializationStream {
-            override def readObject[T: ClassManifest](): T = null.asInstanceOf[T]
-
-            override def close(): Unit = s.close()
-
-            private val values = {
-              for (i <- 0 to numKeyValuePairs * 2) yield i
-            }.iterator
-
-            private def getValueOrEOF(): Int = {
-              if (values.hasNext) {
-                values.next()
-              } else {
-                throw new EOFException("End of the file: mock deserializeStream")
-              }
-            }
-
-            // NOTE: the readKey and readValue methods are called by asKeyValueIterator()
-            // which is wrapped in a NextIterator
-            override def readKey[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T]
-
-            override def readValue[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T]
-          }
-
-        override def deserialize[T: ClassManifest](bytes: ByteBuffer, loader: ClassLoader): T =
-          null.asInstanceOf[T]
-
-        override def serialize[T: ClassManifest](t: T): ByteBuffer = ByteBuffer.allocate(0)
-
-        override def serializeStream(s: OutputStream): SerializationStream =
-          null.asInstanceOf[SerializationStream]
-
-        override def deserialize[T: ClassManifest](bytes: ByteBuffer): T = null.asInstanceOf[T]
-      }
-    }))
-
-    val mockBlockManager = {
-      // Create a block manager that isn't configured for compression, just returns input stream
-      val blockManager = mock(classOf[BlockManager])
-      when(blockManager.wrapForCompression(any[BlockId](), any[InputStream]()))
-        .thenAnswer(new Answer[InputStream] {
-          override def answer(invocation: InvocationOnMock): InputStream = {
-            val blockId = invocation.getArguments()(0).asInstanceOf[BlockId]
-            val inputStream = invocation.getArguments()(1).asInstanceOf[InputStream]
-            inputStream
-          }
-        })
-      blockManager
-    }
-
-    val mockInputStream = mock(classOf[InputStream])
-    when(mockShuffleFetcher.fetchBlockStreams(any[Int](), any[Int](), any[TaskContext]()))
-      .thenReturn(Iterator.single((mock(classOf[BlockId]), mockInputStream)))
-
-    val shuffleHandle = new BaseShuffleHandle(shuffleId, numMaps, mockDep)
-
-    val reader = new HashShuffleReader(shuffleHandle, 0, 1,
-      mockContext, mockBlockManager, mockShuffleFetcher)
-
-    val values = reader.read()
-    // Verify that we're reading the correct values
-    var numValuesRead = 0
-    for (((key: Int, value: Int), i) <- values.zipWithIndex) {
-      assert(key == i * 2)
-      assert(value == i * 2 + 1)
-      numValuesRead += 1
-    }
-    // Verify that we read the correct number of values
-    assert(numKeyValuePairs == numValuesRead)
-    // Verify that our input stream was closed
-    verify(mockInputStream, times(1)).close()
-    // Verify that we collected metrics for each key/value pair
-    verify(mockReadMetrics, times(numKeyValuePairs)).incRecordsRead(1)
-  }
 }
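
The patch applies one refactoring in two places: `BlockStoreShuffleFetcher` becomes a singleton `object` rather than a class, so `HashShuffleReader` no longer takes it (or a `BlockManager`) through defaulted constructor parameters, and the mock-heavy reader test that relied on those injection points is removed. A minimal, self-contained sketch of that class-to-object pattern follows; the names (`BlockFetcher`, `Env`, `Reader`) are illustrative stand-ins, not Spark's real API.

// Illustrative sketch only; not Spark code.
object BlockFetcher {
  // Stateless helper: as a singleton object it needs no instance to be
  // constructed by callers or injected for tests.
  def fetchBlocks(shuffleId: Int, reduceId: Int): Iterator[String] =
    Iterator(s"shuffle_${shuffleId}_0_$reduceId")
}

object Env {
  // Stand-in for a process-wide service registry like SparkEnv.
  val blockManager: String => String = block => s"stream($block)"
}

class Reader(shuffleId: Int, reduceId: Int) {
  // Resolved internally, mirroring `private val blockManager =
  // SparkEnv.get.blockManager` in the diff, instead of arriving as a
  // constructor parameter with a default value.
  private val blockManager = Env.blockManager

  def read(): Iterator[String] =
    BlockFetcher.fetchBlocks(shuffleId, reduceId).map(blockManager)
}

object ReaderDemo extends App {
  // Prints: stream(shuffle_1_0_0)
  new Reader(shuffleId = 1, reduceId = 0).read().foreach(println)
}

The trade-off the sketch illustrates: referencing singletons directly keeps the production constructor minimal, at the cost of the mocking seams the removed test depended on.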