From 19135f298e215ae11f4c8fd3b8c51147fd8bcc46 Mon Sep 17 00:00:00 2001
From: Matt Massie
Date: Wed, 27 May 2015 16:16:19 -0700
Subject: [PATCH] [SPARK-7884] Allow Spark shuffle APIs to be more customizable

This commit updates the shuffle read path to give ShuffleReader
implementations more control over the deserialization process.

The BlockStoreShuffleFetcher.fetch() method has been renamed to
BlockStoreShuffleFetcher.fetchBlockStreams(). Previously, this method
returned a record iterator; now it returns an iterator of
(BlockId, InputStream) pairs. Deserialization of records is now handled
in the ShuffleReader.read() method.

This change creates a cleaner separation of concerns and gives
implementations of ShuffleReader more flexibility in how records are
deserialized.
---
 .../hash/BlockStoreShuffleFetcher.scala       | 35 +++-----
 .../shuffle/hash/HashShuffleReader.scala      | 27 +++++-
 .../storage/ShuffleBlockFetcherIterator.scala | 89 ++++++++++++-------
 .../ShuffleBlockFetcherIteratorSuite.scala    | 39 ++++----
 4 files changed, 119 insertions(+), 71 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index 597d46a3d2223..9a15f9bab834a 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -17,23 +17,22 @@
 package org.apache.spark.shuffle.hash
 
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
+import java.io.InputStream
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
 import scala.util.{Failure, Success, Try}
 
 import org.apache.spark._
-import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
 import org.apache.spark.util.CompletionIterator
 
 private[hash] object BlockStoreShuffleFetcher extends Logging {
-  def fetch[T](
+  def fetchBlockStreams(
       shuffleId: Int,
       reduceId: Int,
-      context: TaskContext,
-      serializer: Serializer)
-    : Iterator[T] =
+      context: TaskContext)
+    : Iterator[(BlockId, InputStream)] =
   {
     logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
     val blockManager = SparkEnv.get.blockManager
@@ -53,12 +52,12 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
       (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
     }
 
-    def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = {
+    def unpackBlock(blockPair: (BlockId, Try[InputStream])) : (BlockId, InputStream) = {
       val blockId = blockPair._1
       val blockOption = blockPair._2
       blockOption match {
-        case Success(block) => {
-          block.asInstanceOf[Iterator[T]]
+        case Success(inputStream) => {
+          (blockId, inputStream)
         }
         case Failure(e) => {
           blockId match {
@@ -78,21 +77,13 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
       SparkEnv.get.blockManager.shuffleClient,
       blockManager,
       blocksByAddress,
-      serializer,
       // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
       SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
 
-    val itr = blockFetcherItr.flatMap(unpackBlock)
-    val completionIter = CompletionIterator[T, Iterator[T]](itr, {
-      context.taskMetrics.updateShuffleReadMetrics()
-    })
+    val itr = blockFetcherItr.map(unpackBlock)
 
-    new InterruptibleIterator[T](context, completionIter) {
-      val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
-      override def next(): T = {
-        readMetrics.incRecordsRead(1)
-        delegate.next()
-      }
-    }
+    CompletionIterator[(BlockId, InputStream), Iterator[(BlockId, InputStream)]](itr, {
+      context.taskMetrics().updateShuffleReadMetrics()
+    })
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index 41bafabde05b9..0f315b85bfca6 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -17,10 +17,10 @@
 package org.apache.spark.shuffle.hash
 
-import org.apache.spark.{InterruptibleIterator, TaskContext}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
 import org.apache.spark.util.collection.ExternalSorter
+import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext}
 
 private[spark] class HashShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
@@ -33,11 +33,34 @@ private[spark] class HashShuffleReader[K, C](
     "Hash shuffle currently only supports fetching one partition")
 
   private val dep = handle.dependency
+  private val blockManager = SparkEnv.get.blockManager
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
+    val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
+      handle.shuffleId, startPartition, context)
+
+    // Wrap the streams for compression based on configuration
+    val wrappedStreams = blockStreams.map { case (blockId, inputStream) =>
+      blockManager.wrapForCompression(blockId, inputStream)
+    }
+
     val ser = Serializer.getSerializer(dep.serializer)
-    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
+    val serializerInstance = ser.newInstance()
+
+    // Create a key/value iterator for each stream
+    val recordIterator = wrappedStreams.flatMap { wrappedStream =>
+      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+    }
+
+    // Update read metrics for each record materialized
+    val iter = new InterruptibleIterator[Any](context, recordIterator) {
+      val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
+      override def next(): Any = {
+        readMetrics.incRecordsRead(1)
+        delegate.next()
+      }
+    }.asInstanceOf[Iterator[Nothing]]
 
     val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
       if (dep.mapSideCombine) {
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index d0faab62c9e9e..3758a758943d4 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -17,23 +17,24 @@
 package org.apache.spark.storage
 
+import java.io.InputStream
 import java.util.concurrent.LinkedBlockingQueue
 
-import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 import scala.util.{Failure, Try}
 
-import org.apache.spark.{Logging, TaskContext}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
 import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.serializer.{SerializerInstance, Serializer}
-import org.apache.spark.util.{CompletionIterator, Utils}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
+import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, TaskContext}
 
 /**
  * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
  * manager. For remote blocks, it fetches them using the provided BlockTransferService.
  *
- * This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a
- * pipelined fashion as they are received.
+ * This creates an iterator of (BlockId, Try[InputStream]) tuples so the caller can handle blocks
+ * in a pipelined fashion as they are received.
  *
  * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid
  * using too much memory.
@@ -44,7 +45,6 @@ import org.apache.spark.util.{CompletionIterator, Utils}
  * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
  *                        For each block we also require the size (in bytes as a long field) in
  *                        order to throttle the memory usage.
- * @param serializer serializer used to deserialize the data.
  * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
  */
 private[spark]
@@ -53,9 +53,8 @@ final class ShuffleBlockFetcherIterator(
     shuffleClient: ShuffleClient,
     blockManager: BlockManager,
     blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
-    serializer: Serializer,
     maxBytesInFlight: Long)
-  extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging {
+  extends Iterator[(BlockId, Try[InputStream])] with Logging {
 
   import ShuffleBlockFetcherIterator._
 
@@ -79,11 +78,11 @@ final class ShuffleBlockFetcherIterator(
   private[this] val localBlocks = new ArrayBuffer[BlockId]()
 
   /** Remote blocks to fetch, excluding zero-sized blocks. */
-  private[this] val remoteBlocks = new HashSet[BlockId]()
+  private[this] val remoteBlocks = new mutable.HashSet[BlockId]()
 
   /**
    * A queue to hold our results. This turns the asynchronous model provided by
-   * [[BlockTransferService]] into a synchronous model (iterator).
+   * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator).
    */
   private[this] val results = new LinkedBlockingQueue[FetchResult]
 
@@ -97,14 +96,12 @@ final class ShuffleBlockFetcherIterator(
    * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
    * the number of bytes in flight is limited to maxBytesInFlight.
    */
-  private[this] val fetchRequests = new Queue[FetchRequest]
+  private[this] val fetchRequests = new mutable.Queue[FetchRequest]
 
   /** Current bytes in flight from our requests */
   private[this] var bytesInFlight = 0L
 
-  private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
-
-  private[this] val serializerInstance: SerializerInstance = serializer.newInstance()
+  private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency()
 
   /**
    * Whether the iterator is still active. If isZombie is true, the callback interface will no
@@ -114,17 +111,23 @@ final class ShuffleBlockFetcherIterator(
 
   initialize()
 
-  /**
-   * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
-   */
-  private[this] def cleanup() {
-    isZombie = true
+  // Decrements the buffer reference count.
+  // The currentResult is set to null to prevent releasing the buffer again on cleanup()
+  private[storage] def releaseCurrentResultBuffer(): Unit = {
     // Release the current buffer if necessary
     currentResult match {
       case SuccessFetchResult(_, _, buf) => buf.release()
       case _ =>
     }
+    currentResult = null
+  }
+
+  /**
+   * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
+   */
+  private[this] def cleanup() {
+    isZombie = true
+    releaseCurrentResultBuffer()
     // Release buffers in the results queue
     val iter = results.iterator()
     while (iter.hasNext) {
@@ -272,7 +275,7 @@ final class ShuffleBlockFetcherIterator(
 
   override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
 
-  override def next(): (BlockId, Try[Iterator[Any]]) = {
+  override def next(): (BlockId, Try[InputStream]) = {
     numBlocksProcessed += 1
     val startFetchWait = System.currentTimeMillis()
     currentResult = results.take()
@@ -290,22 +293,15 @@ final class ShuffleBlockFetcherIterator(
       sendRequest(fetchRequests.dequeue())
     }
 
-    val iteratorTry: Try[Iterator[Any]] = result match {
+    val iteratorTry: Try[InputStream] = result match {
       case FailureFetchResult(_, e) => Failure(e)
       case SuccessFetchResult(blockId, _, buf) =>
         // There is a chance that createInputStream can fail (e.g. fetching a local file that does
         // not exist, SPARK-4085). In that case, we should propagate the right exception so
         // the scheduler gets a FetchFailedException.
-        Try(buf.createInputStream()).map { is0 =>
-          val is = blockManager.wrapForCompression(blockId, is0)
-          val iter = serializerInstance.deserializeStream(is).asKeyValueIterator
-          CompletionIterator[Any, Iterator[Any]](iter, {
-            // Once the iterator is exhausted, release the buffer and set currentResult to null
-            // so we don't release it again in cleanup.
-            currentResult = null
-            buf.release()
-          })
+        Try(buf.createInputStream()).map { inputStream =>
+          new WrappedInputStream(inputStream, this)
         }
     }
 
     (blockId, iteratorTry)
   }
 }
 
@@ -313,6 +309,35 @@ final class ShuffleBlockFetcherIterator(
+// Helper class that ensures a ManagedBuffer is released upon InputStream.close()
+private class WrappedInputStream(delegate: InputStream, iterator: ShuffleBlockFetcherIterator)
+  extends InputStream {
+  private var closed = false
+
+  override def read(): Int = delegate.read()
+
+  override def close(): Unit = {
+    if (!closed) {
+      delegate.close()
+      iterator.releaseCurrentResultBuffer()
+      closed = true
+    }
+  }
+
+  override def available(): Int = delegate.available()
+
+  override def mark(readlimit: Int): Unit = delegate.mark(readlimit)
+
+  override def skip(n: Long): Long = delegate.skip(n)
+
+  override def markSupported(): Boolean = delegate.markSupported()
+
+  override def read(b: Array[Byte]): Int = delegate.read(b)
+
+  override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len)
+
+  override def reset(): Unit = delegate.reset()
+}
 
 private[storage] object ShuffleBlockFetcherIterator {
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 2a7fe67ad8585..60e6840cb00bd 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -17,21 +17,22 @@
 package org.apache.spark.storage
 
+import java.io.InputStream
 import java.util.concurrent.Semaphore
 
-import scala.concurrent.future
 import scala.concurrent.ExecutionContext.Implicits.global
+import scala.concurrent.future
 
 import org.mockito.Matchers.{any, eq => meq}
 import org.mockito.Mockito._
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
 
-import org.apache.spark.{SparkConf, SparkFunSuite, TaskContextImpl}
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.shuffle.BlockFetchingListener
-import org.apache.spark.serializer.TestSerializer
+import org.apache.spark.{SparkFunSuite, TaskContextImpl}
+
 
 class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
   // Some of the tests are quite tricky because we are testing the cleanup behavior
@@ -57,7 +58,16 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
     transfer
   }
 
-  private val conf = new SparkConf
+  // Create a mock managed buffer for testing
+  def createMockManagedBuffer(): ManagedBuffer = {
+    val mockManagedBuffer = mock(classOf[ManagedBuffer])
+    when(mockManagedBuffer.createInputStream()).thenAnswer(new Answer[InputStream] {
+      override def answer(invocation: InvocationOnMock): InputStream = {
+        mock(classOf[InputStream])
+      }
+    })
+    mockManagedBuffer
+  }
 
   test("successful 3 local reads + 2 remote reads") {
     val blockManager = mock(classOf[BlockManager])
@@ -92,7 +102,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
       transfer,
       blockManager,
       blocksByAddress,
-      new TestSerializer,
       48 * 1024 * 1024)
 
     // 3 local blocks fetched in initialization
@@ -104,10 +113,13 @@
       assert(subIterator.isSuccess,
         s"iterator should have 5 elements defined but actually has $i elements")
 
-      // Make sure we release the buffer once the iterator is exhausted.
+      // Make sure we release buffers when a wrapped input stream is closed.
      val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
+      val wrappedInputStream = new WrappedInputStream(mock(classOf[InputStream]), iterator)
       verify(mockBuf, times(0)).release()
-      subIterator.get.foreach(_ => Unit) // exhaust the iterator
+      wrappedInputStream.close()
+      verify(mockBuf, times(1)).release()
+      wrappedInputStream.close() // close should be idempotent
       verify(mockBuf, times(1)).release()
     }
 
@@ -125,10 +137,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
     // Make sure remote blocks would return
     val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
     val blocks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])
-    )
+      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer())
 
     // Semaphore to coordinate event sequence in two different threads.
     val sem = new Semaphore(0)
@@ -159,11 +170,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
       transfer,
       blockManager,
       blocksByAddress,
-      new TestSerializer,
       48 * 1024 * 1024)
 
-    // Exhaust the first block, and then it should be released.
-    iterator.next()._2.get.foreach(_ => Unit)
+    verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release()
+    iterator.next()._2.get.close() // close() first block's input stream
     verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release()
 
     // Get the 2nd block but do not exhaust the iterator
@@ -222,7 +232,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
      transfer,
      blockManager,
      blocksByAddress,
-      new TestSerializer,
      48 * 1024 * 1024)

    // Continue only after the mock calls onBlockFetchFailure
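
Note on the resulting API: with deserialization moved into ShuffleReader.read(), an
implementation owns the whole stream-to-records pipeline and can choose its own
decompression and serialization per block. Below is a minimal sketch of such a reader,
assuming the fetchBlockStreams() signature introduced above; CustomShuffleReader is a
hypothetical name, and it is placed in org.apache.spark.shuffle.hash only so that it
can reach the private[hash] fetcher:

    package org.apache.spark.shuffle.hash

    import org.apache.spark.{SparkEnv, TaskContext}
    import org.apache.spark.serializer.Serializer
    import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}

    // Hypothetical reader: it alone decides how each block is decompressed
    // and deserialized, which the old record-iterator API decided for it.
    private[spark] class CustomShuffleReader[K, C](
        handle: BaseShuffleHandle[K, _, C],
        startPartition: Int,
        context: TaskContext)
      extends ShuffleReader[K, C] {

      private val dep = handle.dependency
      private val blockManager = SparkEnv.get.blockManager

      override def read(): Iterator[Product2[K, C]] = {
        val serializerInstance = Serializer.getSerializer(dep.serializer).newInstance()
        BlockStoreShuffleFetcher.fetchBlockStreams(handle.shuffleId, startPartition, context)
          .flatMap { case (blockId, inputStream) =>
            // Blocks stay pipelined: each stream is only decompressed and
            // deserialized as the iterator reaches it.
            val wrapped = blockManager.wrapForCompression(blockId, inputStream)
            serializerInstance.deserializeStream(wrapped).asKeyValueIterator
          }.asInstanceOf[Iterator[Product2[K, C]]]
      }
    }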
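
The stream-based contract also shifts buffer lifecycle to the caller: each returned
InputStream pins a ManagedBuffer whose reference count is decremented only when the
stream is closed, and close() is idempotent (the updated suite verifies both). A sketch
of defensive consumption under that contract; consumeBlockStreams and process are
hypothetical helpers, not part of this patch:

    import java.io.InputStream

    import org.apache.spark.storage.BlockId

    // Drains fetched blocks, guaranteeing each stream is closed (and therefore
    // each ManagedBuffer released) exactly once, even if processing throws.
    def consumeBlockStreams(
        streams: Iterator[(BlockId, InputStream)],
        process: (BlockId, InputStream) => Unit): Unit = {
      streams.foreach { case (blockId, in) =>
        try {
          process(blockId, in)
        } finally {
          in.close() // releases the underlying buffer; repeated calls are no-ops
        }
      }
    }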