diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index e93d0ba950266..16d206a6f8043 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -25,7 +25,6 @@ import scala.util.{Failure, Success, Try}
 import org.apache.spark._
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
-import org.apache.spark.util.CompletionIterator
 
 private[shuffle] object BlockStoreShuffleFetcher extends Logging {
   def fetchBlockStreams(
@@ -80,10 +79,6 @@ private[shuffle] object BlockStoreShuffleFetcher extends Logging {
       // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
       SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
 
-    val itr = blockFetcherItr.map(unpackBlock)
-
-    CompletionIterator[(BlockId, InputStream), Iterator[(BlockId, InputStream)]](itr, {
-      context.taskMetrics().updateShuffleReadMetrics()
-    })
+    blockFetcherItr.map(unpackBlock)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index 4efa9b0fc1871..40e54ca0a3ab2 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -17,11 +17,11 @@
 
 package org.apache.spark.shuffle.hash
 
+import org.apache.spark.{SparkEnv, TaskContext, InterruptibleIterator}
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
+import org.apache.spark.shuffle.{ShuffleReader, BaseShuffleHandle}
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
-import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext}
 
 private[spark] class HashShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
@@ -51,24 +51,22 @@ private[spark] class HashShuffleReader[K, C](
 
     // Create a key/value iterator for each stream
     val recordIterator = wrappedStreams.flatMap { wrappedStream =>
-      val kvIter = serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
-      CompletionIterator[(Any, Any), Iterator[(Any, Any)]](kvIter, {
-        // Close the stream once all the records have been read from it to free underlying
-        // ManagedBuffer as soon as possible. Note that in case of task failure, the task's
-        // TaskCompletionListener will make sure this is released.
-        wrappedStream.close()
-      })
+      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
     }
 
+    val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
     // Update read metrics for each record materialized
-    val iter = new InterruptibleIterator[(Any, Any)](context, recordIterator) {
-      val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
-      override def next(): (Any, Any) = {
-        readMetrics.incRecordsRead(1)
-        delegate.next()
-      }
+    val metricIter = new InterruptibleIterator[(Any, Any)](context, recordIterator) {
+      override def next(): (Any, Any) = {
+        readMetrics.incRecordsRead(1)
+        delegate.next()
+      }
     }
 
+    val iter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](metricIter, {
+      context.taskMetrics().updateShuffleReadMetrics()
+    })
+
     val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
       if (dep.mapSideCombine) {
         // We are reading values that are already combined
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 68f6b47fffc38..a1376c8f4e484 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -23,10 +23,10 @@ import java.util.concurrent.LinkedBlockingQueue
 import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
 import scala.util.{Failure, Try}
 
+import org.apache.spark.{Logging, TaskContext}
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
 import org.apache.spark.util.Utils
-import org.apache.spark.{Logging, TaskContext}
 
 /**
  * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -306,7 +306,7 @@
         // not exist, SPARK-4085). In that case, we should propagate the right exception so
         // the scheduler gets a FetchFailedException.
         Try(buf.createInputStream()).map { inputStream =>
-          new WrappedInputStream(inputStream, this)
+          new BufferReleasingInputStream(inputStream, this)
         }
     }
 
@@ -314,8 +314,10 @@
 
-// Helper class that ensures a ManagerBuffer is released upon InputStream.close()
-private class WrappedInputStream(delegate: InputStream, iterator: ShuffleBlockFetcherIterator)
+/** Helper class that ensures a ManagerBuffer is released upon InputStream.close() */
+private class BufferReleasingInputStream(
+    delegate: InputStream,
+    iterator: ShuffleBlockFetcherIterator)
   extends InputStream {
   private var closed = false
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 60e6840cb00bd..f7dc651e6d5d0 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -115,7 +115,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
 
       // Make sure we release buffers when a wrapped input stream is closed.
       val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
-      val wrappedInputStream = new WrappedInputStream(mock(classOf[InputStream]), iterator)
+      val wrappedInputStream = new BufferReleasingInputStream(mock(classOf[InputStream]), iterator)
       verify(mockBuf, times(0)).release()
       wrappedInputStream.close()
       verify(mockBuf, times(1)).release()