From 11eabbe125b2ee572fad359c33c93f5e6fdf0b2d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 7 Apr 2014 23:40:21 -0700 Subject: [PATCH] [SPARK-1103] Automatic garbage collection of RDD, shuffle and broadcast data MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR allows Spark to automatically cleanup metadata and data related to persisted RDDs, shuffles and broadcast variables when the corresponding RDDs, shuffles and broadcast variables fall out of scope from the driver program. This is still a work in progress as broadcast cleanup has not been implemented. **Implementation Details** A new class `ContextCleaner` is responsible cleaning all the state. It is instantiated as part of a `SparkContext`. RDD and ShuffleDependency classes have overridden `finalize()` function that gets called whenever their instances go out of scope. The `finalize()` function enqueues the object’s identifier (i.e. RDD ID, shuffle ID, etc.) with the `ContextCleaner`, which is a very short and cheap operation and should not significantly affect the garbage collection mechanism. The `ContextCleaner`, on a different thread, performs the cleanup, whose details are given below. *RDD cleanup:* `ContextCleaner` calls `RDD.unpersist()` is used to cleanup persisted RDDs. Regarding metadata, the DAGScheduler automatically cleans up all metadata related to a RDD after all jobs have completed. Only the `SparkContext.persistentRDDs` keeps strong references to persisted RDDs. The `TimeStampedHashMap` used for that has been replaced by `TimeStampedWeakValueHashMap` that keeps only weak references to the RDDs, allowing them to be garbage collected. *Shuffle cleanup:* New BlockManager message `RemoveShuffle()` asks the `BlockManagerMaster` and currently active `BlockManager`s to delete all the disk blocks related to the shuffle ID. `ContextCleaner` cleans up shuffle data using this message and also cleans up the metadata in the `MapOutputTracker` of the driver. The `MapOutputTracker` at the workers, that caches the shuffle metadata, maintains a `BoundedHashMap` to limit the shuffle information it caches. Refetching the shuffle information from the driver is not too costly. *Broadcast cleanup:* To be done. [This PR](https://github.com/apache/incubator-spark/pull/543/) adds mechanism for explicit cleanup of broadcast variables. `Broadcast.finalize()` will enqueue its own ID with ContextCleaner and the PRs mechanism will be used to unpersist the Broadcast data. *Other cleanup:* `ShuffleMapTask` and `ResultTask` caches tasks and used TTL based cleanup (using `TimeStampedHashMap`), so nothing got cleaned up if TTL was not set. Instead, they now use `BoundedHashMap` to keep a limited number of map output information. Cost of repopulating the cache if necessary is very small. **Current state of implementation** Implemented RDD and shuffle cleanup. Things left to be done are. - Cleaning up for broadcast variable still to be done. - Automatic cleaning up keys with empty weak refs as values in `TimeStampedWeakValueHashMap` Author: Tathagata Das Author: Andrew Or Author: Roman Pastukhov Closes #126 from tdas/state-cleanup and squashes the following commits: 61b8d6e [Tathagata Das] Fixed issue with Tachyon + new BlockManager methods. f489fdc [Tathagata Das] Merge remote-tracking branch 'apache/master' into state-cleanup d25a86e [Tathagata Das] Fixed stupid typo. cff023c [Tathagata Das] Fixed issues based on Andrew's comments. 4d05314 [Tathagata Das] Scala style fix. 2b95b5e [Tathagata Das] Added more documentation on Broadcast implementations, specially which blocks are told about to the driver. Also, fixed Broadcast API to hide destroy functionality. 41c9ece [Tathagata Das] Added more unit tests for BlockManager, DiskBlockManager, and ContextCleaner. 6222697 [Tathagata Das] Fixed bug and adding unit test for removeBroadcast in BlockManagerSuite. 104a89a [Tathagata Das] Fixed failing BroadcastSuite unit tests by introducing blocking for removeShuffle and removeBroadcast in BlockManager* a430f06 [Tathagata Das] Fixed compilation errors. b27f8e8 [Tathagata Das] Merge pull request #3 from andrewor14/cleanup cd72d19 [Andrew Or] Make automatic cleanup configurable (not documented) ada45f0 [Andrew Or] Merge branch 'state-cleanup' of github.com:tdas/spark into cleanup a2cc8bc [Tathagata Das] Merge remote-tracking branch 'apache/master' into state-cleanup c5b1d98 [Andrew Or] Address Patrick's comments a6460d4 [Andrew Or] Merge github.com:apache/spark into cleanup 762a4d8 [Tathagata Das] Merge pull request #1 from andrewor14/cleanup f0aabb1 [Andrew Or] Correct semantics for TimeStampedWeakValueHashMap + add tests 5016375 [Andrew Or] Address TD's comments 7ed72fb [Andrew Or] Fix style test fail + remove verbose test message regarding broadcast 634a097 [Andrew Or] Merge branch 'state-cleanup' of github.com:tdas/spark into cleanup 7edbc98 [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into state-cleanup 8557c12 [Andrew Or] Merge github.com:apache/spark into cleanup e442246 [Andrew Or] Merge github.com:apache/spark into cleanup 88904a3 [Andrew Or] Make TimeStampedWeakValueHashMap a wrapper of TimeStampedHashMap fbfeec8 [Andrew Or] Add functionality to query executors for their local BlockStatuses 34f436f [Andrew Or] Generalize BroadcastBlockId to remove BroadcastHelperBlockId 0d17060 [Andrew Or] Import, comments, and style fixes (minor) c92e4d9 [Andrew Or] Merge github.com:apache/spark into cleanup f201a8d [Andrew Or] Test broadcast cleanup in ContextCleanerSuite + remove BoundedHashMap e95479c [Andrew Or] Add tests for unpersisting broadcast 544ac86 [Andrew Or] Clean up broadcast blocks through BlockManager* d0edef3 [Andrew Or] Add framework for broadcast cleanup ba52e00 [Andrew Or] Refactor broadcast classes c7ccef1 [Andrew Or] Merge branch 'bc-unpersist-merge' of github.com:ignatich/incubator-spark into cleanup 6c9dcf6 [Tathagata Das] Added missing Apache license d2f8b97 [Tathagata Das] Removed duplicate unpersistRDD. a007307 [Tathagata Das] Merge remote-tracking branch 'apache/master' into state-cleanup 620eca3 [Tathagata Das] Changes based on PR comments. f2881fd [Tathagata Das] Changed ContextCleaner to use ReferenceQueue instead of finalizer e1fba5f [Tathagata Das] Style fix 892b952 [Tathagata Das] Removed use of BoundedHashMap, and made BlockManagerSlaveActor cleanup shuffle metadata in MapOutputTrackerWorker. a7260d3 [Tathagata Das] Added try-catch in context cleaner and null value cleaning in TimeStampedWeakValueHashMap. e61daa0 [Tathagata Das] Modifications based on the comments on PR 126. ae9da88 [Tathagata Das] Removed unncessary TimeStampedHashMap from DAGScheduler, added try-catches in finalize() methods, and replaced ArrayBlockingQueue to LinkedBlockingQueue to avoid blocking in Java's finalizing thread. cb0a5a6 [Tathagata Das] Fixed docs and styles. a24fefc [Tathagata Das] Merge remote-tracking branch 'apache/master' into state-cleanup 8512612 [Tathagata Das] Changed TimeStampedHashMap to use WrappedJavaHashMap. e427a9e [Tathagata Das] Added ContextCleaner to automatically clean RDDs and shuffles when they fall out of scope. Also replaced TimeStampedHashMap to BoundedHashMaps and TimeStampedWeakValueHashMap for the necessary hashmap behavior. 80dd977 [Roman Pastukhov] Fix for Broadcast unpersist patch. 1e752f1 [Roman Pastukhov] Added unpersist method to Broadcast. --- .../org/apache/spark/ContextCleaner.scala | 192 ++++++++ .../scala/org/apache/spark/Dependency.scala | 2 + .../org/apache/spark/MapOutputTracker.scala | 148 ++++--- .../scala/org/apache/spark/SparkContext.scala | 23 +- .../scala/org/apache/spark/SparkEnv.scala | 25 +- .../apache/spark/broadcast/Broadcast.scala | 107 +++-- .../spark/broadcast/BroadcastFactory.scala | 3 +- .../spark/broadcast/BroadcastManager.scala | 66 +++ .../spark/broadcast/HttpBroadcast.scala | 128 ++++-- .../broadcast/HttpBroadcastFactory.scala | 45 ++ .../spark/broadcast/TorrentBroadcast.scala | 162 ++++--- .../broadcast/TorrentBroadcastFactory.scala | 46 ++ .../spark/network/ConnectionManager.scala | 1 - .../main/scala/org/apache/spark/rdd/RDD.scala | 5 +- .../apache/spark/scheduler/DAGScheduler.scala | 38 +- .../apache/spark/scheduler/ResultTask.scala | 16 +- .../spark/scheduler/ShuffleMapTask.scala | 14 +- .../spark/scheduler/TaskSchedulerImpl.scala | 2 +- .../org/apache/spark/storage/BlockId.scala | 24 +- .../apache/spark/storage/BlockManager.scala | 67 ++- .../spark/storage/BlockManagerMaster.scala | 84 +++- .../storage/BlockManagerMasterActor.scala | 107 ++++- .../spark/storage/BlockManagerMessages.scala | 20 +- .../storage/BlockManagerSlaveActor.scala | 60 ++- .../spark/storage/DiskBlockManager.scala | 14 + .../spark/storage/ShuffleBlockManager.scala | 44 +- .../apache/spark/storage/ThreadingTest.scala | 6 +- .../apache/spark/util/MetadataCleaner.scala | 19 +- .../spark/util/TimeStampedHashMap.scala | 109 ++--- .../util/TimeStampedWeakValueHashMap.scala | 170 +++++++ .../scala/org/apache/spark/util/Utils.scala | 8 +- .../org/apache/spark/AkkaUtilsSuite.scala | 8 +- .../org/apache/spark/BroadcastSuite.scala | 311 +++++++++++-- .../apache/spark/ContextCleanerSuite.scala | 415 ++++++++++++++++++ .../apache/spark/MapOutputTrackerSuite.scala | 25 +- .../spark/storage/BlockManagerSuite.scala | 243 ++++++++-- .../spark/storage/DiskBlockManagerSuite.scala | 10 +- .../apache/spark/util/JsonProtocolSuite.scala | 5 +- .../spark/util/TimeStampedHashMapSuite.scala | 264 +++++++++++ .../spark/streaming/dstream/DStream.scala | 4 +- 40 files changed, 2571 insertions(+), 469 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/ContextCleaner.scala create mode 100644 core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala create mode 100644 core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala create mode 100644 core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala create mode 100644 core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala create mode 100644 core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala new file mode 100644 index 0000000000000..54e08d7866f75 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.lang.ref.{ReferenceQueue, WeakReference} + +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD + +/** + * Classes that represent cleaning tasks. + */ +private sealed trait CleanupTask +private case class CleanRDD(rddId: Int) extends CleanupTask +private case class CleanShuffle(shuffleId: Int) extends CleanupTask +private case class CleanBroadcast(broadcastId: Long) extends CleanupTask + +/** + * A WeakReference associated with a CleanupTask. + * + * When the referent object becomes only weakly reachable, the corresponding + * CleanupTaskWeakReference is automatically added to the given reference queue. + */ +private class CleanupTaskWeakReference( + val task: CleanupTask, + referent: AnyRef, + referenceQueue: ReferenceQueue[AnyRef]) + extends WeakReference(referent, referenceQueue) + +/** + * An asynchronous cleaner for RDD, shuffle, and broadcast state. + * + * This maintains a weak reference for each RDD, ShuffleDependency, and Broadcast of interest, + * to be processed when the associated object goes out of scope of the application. Actual + * cleanup is performed in a separate daemon thread. + */ +private[spark] class ContextCleaner(sc: SparkContext) extends Logging { + + private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference] + with SynchronizedBuffer[CleanupTaskWeakReference] + + private val referenceQueue = new ReferenceQueue[AnyRef] + + private val listeners = new ArrayBuffer[CleanerListener] + with SynchronizedBuffer[CleanerListener] + + private val cleaningThread = new Thread() { override def run() { keepCleaning() }} + + /** + * Whether the cleaning thread will block on cleanup tasks. + * This is set to true only for tests. + */ + private val blockOnCleanupTasks = sc.conf.getBoolean( + "spark.cleaner.referenceTracking.blocking", false) + + @volatile private var stopped = false + + /** Attach a listener object to get information of when objects are cleaned. */ + def attachListener(listener: CleanerListener) { + listeners += listener + } + + /** Start the cleaner. */ + def start() { + cleaningThread.setDaemon(true) + cleaningThread.setName("Spark Context Cleaner") + cleaningThread.start() + } + + /** Stop the cleaner. */ + def stop() { + stopped = true + } + + /** Register a RDD for cleanup when it is garbage collected. */ + def registerRDDForCleanup(rdd: RDD[_]) { + registerForCleanup(rdd, CleanRDD(rdd.id)) + } + + /** Register a ShuffleDependency for cleanup when it is garbage collected. */ + def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) { + registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId)) + } + + /** Register a Broadcast for cleanup when it is garbage collected. */ + def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) { + registerForCleanup(broadcast, CleanBroadcast(broadcast.id)) + } + + /** Register an object for cleanup. */ + private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) { + referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue) + } + + /** Keep cleaning RDD, shuffle, and broadcast state. */ + private def keepCleaning() { + while (!stopped) { + try { + val reference = Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT)) + .map(_.asInstanceOf[CleanupTaskWeakReference]) + reference.map(_.task).foreach { task => + logDebug("Got cleaning task " + task) + referenceBuffer -= reference.get + task match { + case CleanRDD(rddId) => + doCleanupRDD(rddId, blocking = blockOnCleanupTasks) + case CleanShuffle(shuffleId) => + doCleanupShuffle(shuffleId, blocking = blockOnCleanupTasks) + case CleanBroadcast(broadcastId) => + doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks) + } + } + } catch { + case t: Throwable => logError("Error in cleaning thread", t) + } + } + } + + /** Perform RDD cleanup. */ + def doCleanupRDD(rddId: Int, blocking: Boolean) { + try { + logDebug("Cleaning RDD " + rddId) + sc.unpersistRDD(rddId, blocking) + listeners.foreach(_.rddCleaned(rddId)) + logInfo("Cleaned RDD " + rddId) + } catch { + case t: Throwable => logError("Error cleaning RDD " + rddId, t) + } + } + + /** Perform shuffle cleanup, asynchronously. */ + def doCleanupShuffle(shuffleId: Int, blocking: Boolean) { + try { + logDebug("Cleaning shuffle " + shuffleId) + mapOutputTrackerMaster.unregisterShuffle(shuffleId) + blockManagerMaster.removeShuffle(shuffleId, blocking) + listeners.foreach(_.shuffleCleaned(shuffleId)) + logInfo("Cleaned shuffle " + shuffleId) + } catch { + case t: Throwable => logError("Error cleaning shuffle " + shuffleId, t) + } + } + + /** Perform broadcast cleanup. */ + def doCleanupBroadcast(broadcastId: Long, blocking: Boolean) { + try { + logDebug("Cleaning broadcast " + broadcastId) + broadcastManager.unbroadcast(broadcastId, true, blocking) + listeners.foreach(_.broadcastCleaned(broadcastId)) + logInfo("Cleaned broadcast " + broadcastId) + } catch { + case t: Throwable => logError("Error cleaning broadcast " + broadcastId, t) + } + } + + private def blockManagerMaster = sc.env.blockManager.master + private def broadcastManager = sc.env.broadcastManager + private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + + // Used for testing. These methods explicitly blocks until cleanup is completed + // to ensure that more reliable testing. +} + +private object ContextCleaner { + private val REF_QUEUE_POLL_TIMEOUT = 100 +} + +/** + * Listener class used for testing when any item has been cleaned by the Cleaner class. + */ +private[spark] trait CleanerListener { + def rddCleaned(rddId: Int) + def shuffleCleaned(shuffleId: Int) + def broadcastCleaned(broadcastId: Long) +} diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 3132dcf745e19..1cd629c15bd46 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -55,6 +55,8 @@ class ShuffleDependency[K, V]( extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { val shuffleId: Int = rdd.context.newShuffleId() + + rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 80cbf951cb70e..ee82d9fa7874b 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -20,21 +20,21 @@ package org.apache.spark import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.HashSet +import scala.collection.mutable.{HashSet, HashMap, Map} import scala.concurrent.Await import akka.actor._ import akka.pattern.ask - import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{AkkaUtils, MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} +import org.apache.spark.util._ private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage +/** Actor class for MapOutputTrackerMaster */ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf) extends Actor with Logging { val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) @@ -65,26 +65,41 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster } } -private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { - +/** + * Class that keeps track of the location of the map output of + * a stage. This is abstract because different versions of MapOutputTracker + * (driver and worker) use different HashMap to store its metadata. + */ +private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { private val timeout = AkkaUtils.askTimeout(conf) - // Set to the MapOutputTrackerActor living on the driver + /** Set to the MapOutputTrackerActor living on the driver. */ var trackerActor: ActorRef = _ - protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] + /** + * This HashMap has different behavior for the master and the workers. + * + * On the master, it serves as the source of map outputs recorded from ShuffleMapTasks. + * On the workers, it simply serves as a cache, in which a miss triggers a fetch from the + * master's corresponding HashMap. + */ + protected val mapStatuses: Map[Int, Array[MapStatus]] - // Incremented every time a fetch fails so that client nodes know to clear - // their cache of map output locations if this happens. + /** + * Incremented every time a fetch fails so that client nodes know to clear + * their cache of map output locations if this happens. + */ protected var epoch: Long = 0 - protected val epochLock = new java.lang.Object + protected val epochLock = new AnyRef - private val metadataCleaner = - new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf) + /** Remembers which map output locations are currently being fetched on a worker. */ + private val fetching = new HashSet[Int] - // Send a message to the trackerActor and get its result within a default timeout, or - // throw a SparkException if this fails. - private def askTracker(message: Any): Any = { + /** + * Send a message to the trackerActor and get its result within a default timeout, or + * throw a SparkException if this fails. + */ + protected def askTracker(message: Any): Any = { try { val future = trackerActor.ask(message)(timeout) Await.result(future, timeout) @@ -94,17 +109,17 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { } } - // Send a one-way message to the trackerActor, to which we expect it to reply with true. - private def communicate(message: Any) { + /** Send a one-way message to the trackerActor, to which we expect it to reply with true. */ + protected def sendTracker(message: Any) { if (askTracker(message) != true) { throw new SparkException("Error reply received from MapOutputTracker") } } - // Remembers which map output locations are currently being fetched on a worker - private val fetching = new HashSet[Int] - - // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle + /** + * Called from executors to get the server URIs and output sizes of the map outputs of + * a given shuffle. + */ def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { @@ -152,8 +167,7 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { fetchedStatuses.synchronized { return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) } - } - else { + } else { throw new FetchFailedException(null, shuffleId, -1, reduceId, new Exception("Missing all output locations for shuffle " + shuffleId)) } @@ -164,27 +178,18 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { } } - protected def cleanup(cleanupTime: Long) { - mapStatuses.clearOldValues(cleanupTime) - } - - def stop() { - communicate(StopMapOutputTracker) - mapStatuses.clear() - metadataCleaner.cancel() - trackerActor = null - } - - // Called to get current epoch number + /** Called to get current epoch number. */ def getEpoch: Long = { epochLock.synchronized { return epoch } } - // Called on workers to update the epoch number, potentially clearing old outputs - // because of a fetch failure. (Each worker task calls this with the latest epoch - // number on the master at the time it was created.) + /** + * Called from executors to update the epoch number, potentially clearing old outputs + * because of a fetch failure. Each worker task calls this with the latest epoch + * number on the master at the time it was created. + */ def updateEpoch(newEpoch: Long) { epochLock.synchronized { if (newEpoch > epoch) { @@ -194,17 +199,40 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { } } } + + /** Unregister shuffle data. */ + def unregisterShuffle(shuffleId: Int) { + mapStatuses.remove(shuffleId) + } + + /** Stop the tracker. */ + def stop() { } } +/** + * MapOutputTracker for the driver. This uses TimeStampedHashMap to keep track of map + * output information, which allows old output information based on a TTL. + */ private[spark] class MapOutputTrackerMaster(conf: SparkConf) extends MapOutputTracker(conf) { - // Cache a serialized version of the output statuses for each shuffle to send them out faster + /** Cache a serialized version of the output statuses for each shuffle to send them out faster */ private var cacheEpoch = epoch - private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] + + /** + * Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the master, + * so that statuses are dropped only by explicit de-registering or by TTL-based cleaning (if set). + * Other than these two scenarios, nothing should be dropped from this HashMap. + */ + protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]() + private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]() + + // For cleaning up TimeStampedHashMaps + private val metadataCleaner = + new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf) def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) { + if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } } @@ -216,6 +244,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) } } + /** Register multiple map output information for the given shuffle */ def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) { mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses) if (changeEpoch) { @@ -223,6 +252,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) } } + /** Unregister map output information of the given shuffle, mapper and block manager */ def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { val arrayOpt = mapStatuses.get(shuffleId) if (arrayOpt.isDefined && arrayOpt.get != null) { @@ -238,6 +268,17 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) } } + /** Unregister shuffle data */ + override def unregisterShuffle(shuffleId: Int) { + mapStatuses.remove(shuffleId) + cachedSerializedStatuses.remove(shuffleId) + } + + /** Check if the given shuffle is being tracked */ + def containsShuffle(shuffleId: Int): Boolean = { + cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId) + } + def incrementEpoch() { epochLock.synchronized { epoch += 1 @@ -274,23 +315,26 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) bytes } - protected override def cleanup(cleanupTime: Long) { - super.cleanup(cleanupTime) - cachedSerializedStatuses.clearOldValues(cleanupTime) - } - override def stop() { - super.stop() + sendTracker(StopMapOutputTracker) + mapStatuses.clear() + trackerActor = null + metadataCleaner.cancel() cachedSerializedStatuses.clear() } - override def updateEpoch(newEpoch: Long) { - // This might be called on the MapOutputTrackerMaster if we're running in local mode. + private def cleanup(cleanupTime: Long) { + mapStatuses.clearOldValues(cleanupTime) + cachedSerializedStatuses.clearOldValues(cleanupTime) } +} - def has(shuffleId: Int): Boolean = { - cachedSerializedStatuses.get(shuffleId).isDefined || mapStatuses.contains(shuffleId) - } +/** + * MapOutputTracker for the workers, which fetches map output information from the driver's + * MapOutputTrackerMaster. + */ +private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { + protected val mapStatuses = new HashMap[Int, Array[MapStatus]] } private[spark] object MapOutputTracker { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e5ebd350eeced..d7124616d3bfb 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -45,7 +45,7 @@ import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, Me import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils} +import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils} /** * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark @@ -157,7 +157,7 @@ class SparkContext( private[spark] val addedJars = HashMap[String, Long]() // Keeps track of all persisted RDDs - private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]] + private[spark] val persistentRdds = new TimeStampedWeakValueHashMap[Int, RDD[_]] private[spark] val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf) @@ -233,6 +233,15 @@ class SparkContext( @volatile private[spark] var dagScheduler = new DAGScheduler(this) dagScheduler.start() + private[spark] val cleaner: Option[ContextCleaner] = { + if (conf.getBoolean("spark.cleaner.referenceTracking", true)) { + Some(new ContextCleaner(this)) + } else { + None + } + } + cleaner.foreach(_.start()) + postEnvironmentUpdate() /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ @@ -679,7 +688,11 @@ class SparkContext( * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. * The variable will be sent to each cluster only once. */ - def broadcast[T](value: T): Broadcast[T] = env.broadcastManager.newBroadcast[T](value, isLocal) + def broadcast[T](value: T): Broadcast[T] = { + val bc = env.broadcastManager.newBroadcast[T](value, isLocal) + cleaner.foreach(_.registerBroadcastForCleanup(bc)) + bc + } /** * Add a file to be downloaded with this Spark job on every node. @@ -789,8 +802,7 @@ class SparkContext( /** * Unpersist an RDD from memory and/or disk storage */ - private[spark] def unpersistRDD(rdd: RDD[_], blocking: Boolean = true) { - val rddId = rdd.id + private[spark] def unpersistRDD(rddId: Int, blocking: Boolean = true) { env.blockManager.master.removeRdd(rddId, blocking) persistentRdds.remove(rddId) listenerBus.post(SparkListenerUnpersistRDD(rddId)) @@ -869,6 +881,7 @@ class SparkContext( dagScheduler = null if (dagSchedulerCopy != null) { metadataCleaner.cancel() + cleaner.foreach(_.stop()) dagSchedulerCopy.stop() listenerBus.stop() taskScheduler = null diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 5ceac28fe7afb..9ea123f174b95 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -180,12 +180,24 @@ object SparkEnv extends Logging { } } + val mapOutputTracker = if (isDriver) { + new MapOutputTrackerMaster(conf) + } else { + new MapOutputTrackerWorker(conf) + } + + // Have to assign trackerActor after initialization as MapOutputTrackerActor + // requires the MapOutputTracker itself + mapOutputTracker.trackerActor = registerOrLookup( + "MapOutputTracker", + new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) + val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf) val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, - serializer, conf, securityManager) + serializer, conf, securityManager, mapOutputTracker) val connectionManager = blockManager.connectionManager @@ -193,17 +205,6 @@ object SparkEnv extends Logging { val cacheManager = new CacheManager(blockManager) - // Have to assign trackerActor after initialization as MapOutputTrackerActor - // requires the MapOutputTracker itself - val mapOutputTracker = if (isDriver) { - new MapOutputTrackerMaster(conf) - } else { - new MapOutputTracker(conf) - } - mapOutputTracker.trackerActor = registerOrLookup( - "MapOutputTracker", - new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) - val shuffleFetcher = instantiateClass[ShuffleFetcher]( "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher") diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index e3c3a12d16f2a..738a3b1bed7f3 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -18,9 +18,8 @@ package org.apache.spark.broadcast import java.io.Serializable -import java.util.concurrent.atomic.AtomicLong -import org.apache.spark._ +import org.apache.spark.SparkException /** * A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable @@ -29,7 +28,8 @@ import org.apache.spark._ * attempts to distribute broadcast variables using efficient broadcast algorithms to reduce * communication cost. * - * Broadcast variables are created from a variable `v` by calling [[SparkContext#broadcast]]. + * Broadcast variables are created from a variable `v` by calling + * [[org.apache.spark.SparkContext#broadcast]]. * The broadcast variable is a wrapper around `v`, and its value can be accessed by calling the * `value` method. The interpreter session below shows this: * @@ -51,49 +51,80 @@ import org.apache.spark._ * @tparam T Type of the data contained in the broadcast variable. */ abstract class Broadcast[T](val id: Long) extends Serializable { - def value: T - // We cannot have an abstract readObject here due to some weird issues with - // readObject having to be 'private' in sub-classes. + /** + * Flag signifying whether the broadcast variable is valid + * (that is, not already destroyed) or not. + */ + @volatile private var _isValid = true - override def toString = "Broadcast(" + id + ")" -} - -private[spark] -class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager) - extends Logging with Serializable { - - private var initialized = false - private var broadcastFactory: BroadcastFactory = null - - initialize() - - // Called by SparkContext or Executor before using Broadcast - private def initialize() { - synchronized { - if (!initialized) { - val broadcastFactoryClass = conf.get( - "spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") - - broadcastFactory = - Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] + /** Get the broadcasted value. */ + def value: T = { + assertValid() + getValue() + } - // Initialize appropriate BroadcastFactory and BroadcastObject - broadcastFactory.initialize(isDriver, conf, securityManager) + /** + * Asynchronously delete cached copies of this broadcast on the executors. + * If the broadcast is used after this is called, it will need to be re-sent to each executor. + */ + def unpersist() { + unpersist(blocking = false) + } - initialized = true - } - } + /** + * Delete cached copies of this broadcast on the executors. If the broadcast is used after + * this is called, it will need to be re-sent to each executor. + * @param blocking Whether to block until unpersisting has completed + */ + def unpersist(blocking: Boolean) { + assertValid() + doUnpersist(blocking) } - def stop() { - broadcastFactory.stop() + /** + * Destroy all data and metadata related to this broadcast variable. Use this with caution; + * once a broadcast variable has been destroyed, it cannot be used again. + */ + private[spark] def destroy(blocking: Boolean) { + assertValid() + _isValid = false + doDestroy(blocking) } - private val nextBroadcastId = new AtomicLong(0) + /** + * Whether this Broadcast is actually usable. This should be false once persisted state is + * removed from the driver. + */ + private[spark] def isValid: Boolean = { + _isValid + } - def newBroadcast[T](value_ : T, isLocal: Boolean) = - broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) + /** + * Actually get the broadcasted value. Concrete implementations of Broadcast class must + * define their own way to get the value. + */ + private[spark] def getValue(): T + + /** + * Actually unpersist the broadcasted value on the executors. Concrete implementations of + * Broadcast class must define their own logic to unpersist their own data. + */ + private[spark] def doUnpersist(blocking: Boolean) + + /** + * Actually destroy all data and metadata related to this broadcast variable. + * Implementation of Broadcast class must define their own logic to destroy their own + * state. + */ + private[spark] def doDestroy(blocking: Boolean) + + /** Check if this broadcast is valid. If not valid, exception is thrown. */ + private[spark] def assertValid() { + if (!_isValid) { + throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString)) + } + } - def isDriver = _isDriver + override def toString = "Broadcast(" + id + ")" } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 6beecaeced5be..c7f7c59cfb449 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -27,7 +27,8 @@ import org.apache.spark.SparkConf * entire Spark job. */ trait BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] + def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala new file mode 100644 index 0000000000000..cf62aca4d45e8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.broadcast + +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark._ + +private[spark] class BroadcastManager( + val isDriver: Boolean, + conf: SparkConf, + securityManager: SecurityManager) + extends Logging { + + private var initialized = false + private var broadcastFactory: BroadcastFactory = null + + initialize() + + // Called by SparkContext or Executor before using Broadcast + private def initialize() { + synchronized { + if (!initialized) { + val broadcastFactoryClass = + conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") + + broadcastFactory = + Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] + + // Initialize appropriate BroadcastFactory and BroadcastObject + broadcastFactory.initialize(isDriver, conf, securityManager) + + initialized = true + } + } + } + + def stop() { + broadcastFactory.stop() + } + + private val nextBroadcastId = new AtomicLong(0) + + def newBroadcast[T](value_ : T, isLocal: Boolean) = { + broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) + } + + def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { + broadcastFactory.unbroadcast(id, removeFromDriver, blocking) + } +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index e8eb04bb10469..f6a8a8af91e4b 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -17,34 +17,65 @@ package org.apache.spark.broadcast -import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream} -import java.net.{URL, URLConnection, URI} +import java.io.{File, FileOutputStream, ObjectInputStream, ObjectOutputStream, OutputStream} +import java.net.{URI, URL, URLConnection} import java.util.concurrent.TimeUnit -import it.unimi.dsi.fastutil.io.FastBufferedInputStream -import it.unimi.dsi.fastutil.io.FastBufferedOutputStream +import it.unimi.dsi.fastutil.io.{FastBufferedInputStream, FastBufferedOutputStream} -import org.apache.spark.{SparkConf, HttpServer, Logging, SecurityManager, SparkEnv} +import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv} import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils} +/** + * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses HTTP server + * as a broadcast mechanism. The first time a HTTP broadcast variable (sent as part of a + * task) is deserialized in the executor, the broadcasted data is fetched from the driver + * (through a HTTP server running at the driver) and stored in the BlockManager of the + * executor to speed up future accesses. + */ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { - def value = value_ + def getValue = value_ - def blockId = BroadcastBlockId(id) + val blockId = BroadcastBlockId(id) + /* + * Broadcasted data is also stored in the BlockManager of the driver. The BlockManagerMaster + * does not need to be told about this block as not only need to know about this data block. + */ HttpBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + SparkEnv.get.blockManager.putSingle( + blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) } if (!isLocal) { HttpBroadcast.write(id, value_) } - // Called by JVM when deserializing an object + /** + * Remove all persisted state associated with this HTTP broadcast on the executors. + */ + def doUnpersist(blocking: Boolean) { + HttpBroadcast.unpersist(id, removeFromDriver = false, blocking) + } + + /** + * Remove all persisted state associated with this HTTP broadcast on the executors and driver. + */ + def doDestroy(blocking: Boolean) { + HttpBroadcast.unpersist(id, removeFromDriver = true, blocking) + } + + /** Used by the JVM when serializing this object. */ + private def writeObject(out: ObjectOutputStream) { + assertValid() + out.defaultWriteObject() + } + + /** Used by the JVM when deserializing this object. */ private def readObject(in: ObjectInputStream) { in.defaultReadObject() HttpBroadcast.synchronized { @@ -54,7 +85,13 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea logInfo("Started reading broadcast variable " + id) val start = System.nanoTime value_ = HttpBroadcast.read[T](id) - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + /* + * We cache broadcast data in the BlockManager so that subsequent tasks using it + * do not need to re-fetch. This data is only used locally and no other node + * needs to fetch this block, so we don't notify the master. + */ + SparkEnv.get.blockManager.putSingle( + blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) val time = (System.nanoTime - start) / 1e9 logInfo("Reading broadcast variable " + id + " took " + time + " s") } @@ -63,23 +100,8 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea } } -/** - * A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium. - */ -class HttpBroadcastFactory extends BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { - HttpBroadcast.initialize(isDriver, conf, securityMgr) - } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new HttpBroadcast[T](value_, isLocal, id) - - def stop() { HttpBroadcast.stop() } -} - -private object HttpBroadcast extends Logging { +private[spark] object HttpBroadcast extends Logging { private var initialized = false - private var broadcastDir: File = null private var compress: Boolean = false private var bufferSize: Int = 65536 @@ -89,11 +111,9 @@ private object HttpBroadcast extends Logging { // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist private val files = new TimeStampedHashSet[String] - private var cleaner: MetadataCleaner = null - private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt - private var compressionCodec: CompressionCodec = null + private var cleaner: MetadataCleaner = null def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { synchronized { @@ -136,8 +156,10 @@ private object HttpBroadcast extends Logging { logInfo("Broadcast server started at " + serverUri) } + def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name) + def write(id: Long, value: Any) { - val file = new File(broadcastDir, BroadcastBlockId(id).name) + val file = getFile(id) val out: OutputStream = { if (compress) { compressionCodec.compressedOutputStream(new FileOutputStream(file)) @@ -160,7 +182,7 @@ private object HttpBroadcast extends Logging { if (securityManager.isAuthenticationEnabled()) { logDebug("broadcast security enabled") val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager) - uc = newuri.toURL().openConnection() + uc = newuri.toURL.openConnection() uc.setAllowUserInteraction(false) } else { logDebug("broadcast not using security") @@ -169,7 +191,7 @@ private object HttpBroadcast extends Logging { val in = { uc.setReadTimeout(httpReadTimeout) - val inputStream = uc.getInputStream(); + val inputStream = uc.getInputStream if (compress) { compressionCodec.compressedInputStream(inputStream) } else { @@ -183,20 +205,48 @@ private object HttpBroadcast extends Logging { obj } - def cleanup(cleanupTime: Long) { + /** + * Remove all persisted blocks associated with this HTTP broadcast on the executors. + * If removeFromDriver is true, also remove these persisted blocks on the driver + * and delete the associated broadcast file. + */ + def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized { + SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) + if (removeFromDriver) { + val file = getFile(id) + files.remove(file.toString) + deleteBroadcastFile(file) + } + } + + /** + * Periodically clean up old broadcasts by removing the associated map entries and + * deleting the associated files. + */ + private def cleanup(cleanupTime: Long) { val iterator = files.internalMap.entrySet().iterator() while(iterator.hasNext) { val entry = iterator.next() val (file, time) = (entry.getKey, entry.getValue) if (time < cleanupTime) { - try { - iterator.remove() - new File(file.toString).delete() - logInfo("Deleted broadcast file '" + file + "'") - } catch { - case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e) + iterator.remove() + deleteBroadcastFile(new File(file.toString)) + } + } + } + + private def deleteBroadcastFile(file: File) { + try { + if (file.exists) { + if (file.delete()) { + logInfo("Deleted broadcast file: %s".format(file)) + } else { + logWarning("Could not delete broadcast file: %s".format(file)) } } + } catch { + case e: Exception => + logError("Exception while deleting broadcast file: %s".format(file), e) } } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala new file mode 100644 index 0000000000000..e3f6cdc6154dd --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.broadcast + +import org.apache.spark.{SecurityManager, SparkConf} + +/** + * A [[org.apache.spark.broadcast.BroadcastFactory]] implementation that uses a + * HTTP server as the broadcast mechanism. Refer to + * [[org.apache.spark.broadcast.HttpBroadcast]] for more details about this mechanism. + */ +class HttpBroadcastFactory extends BroadcastFactory { + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { + HttpBroadcast.initialize(isDriver, conf, securityMgr) + } + + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + new HttpBroadcast[T](value_, isLocal, id) + + def stop() { HttpBroadcast.stop() } + + /** + * Remove all persisted state associated with the HTTP broadcast with the given ID. + * @param removeFromDriver Whether to remove state from the driver + * @param blocking Whether to block until unbroadcasted + */ + def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { + HttpBroadcast.unpersist(id, removeFromDriver, blocking) + } +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 2595c15104e87..2b32546c6854d 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -17,24 +17,43 @@ package org.apache.spark.broadcast -import java.io._ +import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream} import scala.math import scala.util.Random -import org.apache.spark._ -import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel} +import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} +import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.Utils +/** + * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like + * protocol to do a distributed transfer of the broadcasted data to the executors. + * The mechanism is as follows. The driver divides the serializes the broadcasted data, + * divides it into smaller chunks, and stores them in the BlockManager of the driver. + * These chunks are reported to the BlockManagerMaster so that all the executors can + * learn the location of those chunks. The first time the broadcast variable (sent as + * part of task) is deserialized at a executor, all the chunks are fetched using + * the BlockManager. When all the chunks are fetched (initially from the driver's + * BlockManager), they are combined and deserialized to recreate the broadcasted data. + * However, the chunks are also stored in the BlockManager and reported to the + * BlockManagerMaster. As more executors fetch the chunks, BlockManagerMaster learns + * multiple locations for each chunk. Hence, subsequent fetches of each chunk will be + * made to other executors who already have those chunks, resulting in a distributed + * fetching. This prevents the driver from being the bottleneck in sending out multiple + * copies of the broadcast data (one per executor) as done by the + * [[org.apache.spark.broadcast.HttpBroadcast]]. + */ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) -extends Broadcast[T](id) with Logging with Serializable { + extends Broadcast[T](id) with Logging with Serializable { - def value = value_ + def getValue = value_ - def broadcastId = BroadcastBlockId(id) + val broadcastId = BroadcastBlockId(id) TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + SparkEnv.get.blockManager.putSingle( + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) } @transient var arrayOfBlocks: Array[TorrentBlock] = null @@ -46,32 +65,52 @@ extends Broadcast[T](id) with Logging with Serializable { sendBroadcast() } - def sendBroadcast() { - var tInfo = TorrentBroadcast.blockifyObject(value_) + /** + * Remove all persisted state associated with this Torrent broadcast on the executors. + */ + def doUnpersist(blocking: Boolean) { + TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking) + } + + /** + * Remove all persisted state associated with this Torrent broadcast on the executors + * and driver. + */ + def doDestroy(blocking: Boolean) { + TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking) + } + def sendBroadcast() { + val tInfo = TorrentBroadcast.blockifyObject(value_) totalBlocks = tInfo.totalBlocks totalBytes = tInfo.totalBytes hasBlocks = tInfo.totalBlocks // Store meta-info - val metaId = BroadcastHelperBlockId(broadcastId, "meta") + val metaId = BroadcastBlockId(id, "meta") val metaInfo = TorrentInfo(null, totalBlocks, totalBytes) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( - metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true) + metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true) } // Store individual pieces for (i <- 0 until totalBlocks) { - val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i) + val pieceId = BroadcastBlockId(id, "piece" + i) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( - pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true) + pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true) } } } - // Called by JVM when deserializing an object + /** Used by the JVM when serializing this object. */ + private def writeObject(out: ObjectOutputStream) { + assertValid() + out.defaultWriteObject() + } + + /** Used by the JVM when deserializing this object. */ private def readObject(in: ObjectInputStream) { in.defaultReadObject() TorrentBroadcast.synchronized { @@ -86,18 +125,22 @@ extends Broadcast[T](id) with Logging with Serializable { // Initialize @transient variables that will receive garbage values from the master. resetWorkerVariables() - if (receiveBroadcast(id)) { + if (receiveBroadcast()) { value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - // Store the merged copy in cache so that the next worker doesn't need to rebuild it. - // This creates a tradeoff between memory usage and latency. - // Storing copy doubles the memory footprint; not storing doubles deserialization cost. + /* Store the merged copy in cache so that the next worker doesn't need to rebuild it. + * This creates a trade-off between memory usage and latency. Storing copy doubles + * the memory footprint; not storing doubles deserialization cost. Also, + * this does not need to be reported to BlockManagerMaster since other executors + * does not need to access this block (they only need to fetch the chunks, + * which are reported). + */ SparkEnv.get.blockManager.putSingle( - broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) // Remove arrayOfBlocks from memory once value_ is on local cache resetWorkerVariables() - } else { + } else { logError("Reading broadcast variable " + id + " failed") } @@ -114,9 +157,10 @@ extends Broadcast[T](id) with Logging with Serializable { hasBlocks = 0 } - def receiveBroadcast(variableID: Long): Boolean = { - // Receive meta-info - val metaId = BroadcastHelperBlockId(broadcastId, "meta") + def receiveBroadcast(): Boolean = { + // Receive meta-info about the size of broadcast data, + // the number of chunks it is divided into, etc. + val metaId = BroadcastBlockId(id, "meta") var attemptId = 10 while (attemptId > 0 && totalBlocks == -1) { TorrentBroadcast.synchronized { @@ -138,17 +182,21 @@ extends Broadcast[T](id) with Logging with Serializable { return false } - // Receive actual blocks + /* + * Fetch actual chunks of data. Note that all these chunks are stored in + * the BlockManager and reported to the master, so that other executors + * can find out and pull the chunks from this executor. + */ val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList) for (pid <- recvOrder) { - val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid) + val pieceId = BroadcastBlockId(id, "piece" + pid) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.getSingle(pieceId) match { case Some(x) => arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] hasBlocks += 1 SparkEnv.get.blockManager.putSingle( - pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true) + pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true) case None => throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) @@ -156,16 +204,16 @@ extends Broadcast[T](id) with Logging with Serializable { } } - (hasBlocks == totalBlocks) + hasBlocks == totalBlocks } } -private object TorrentBroadcast -extends Logging { - +private[spark] object TorrentBroadcast extends Logging { + private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 private var initialized = false private var conf: SparkConf = null + def initialize(_isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests synchronized { @@ -179,39 +227,37 @@ extends Logging { initialized = false } - lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 - def blockifyObject[T](obj: T): TorrentInfo = { val byteArray = Utils.serialize[T](obj) val bais = new ByteArrayInputStream(byteArray) - var blockNum = (byteArray.length / BLOCK_SIZE) + var blockNum = byteArray.length / BLOCK_SIZE if (byteArray.length % BLOCK_SIZE != 0) { blockNum += 1 } - var retVal = new Array[TorrentBlock](blockNum) - var blockID = 0 + val blocks = new Array[TorrentBlock](blockNum) + var blockId = 0 for (i <- 0 until (byteArray.length, BLOCK_SIZE)) { val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i) - var tempByteArray = new Array[Byte](thisBlockSize) - val hasRead = bais.read(tempByteArray, 0, thisBlockSize) + val tempByteArray = new Array[Byte](thisBlockSize) + bais.read(tempByteArray, 0, thisBlockSize) - retVal(blockID) = new TorrentBlock(blockID, tempByteArray) - blockID += 1 + blocks(blockId) = new TorrentBlock(blockId, tempByteArray) + blockId += 1 } bais.close() - val tInfo = TorrentInfo(retVal, blockNum, byteArray.length) - tInfo.hasBlocks = blockNum - - tInfo + val info = TorrentInfo(blocks, blockNum, byteArray.length) + info.hasBlocks = blockNum + info } - def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock], - totalBytes: Int, - totalBlocks: Int): T = { + def unBlockifyObject[T]( + arrayOfBlocks: Array[TorrentBlock], + totalBytes: Int, + totalBlocks: Int): T = { val retByteArray = new Array[Byte](totalBytes) for (i <- 0 until totalBlocks) { System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, @@ -220,6 +266,13 @@ extends Logging { Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader) } + /** + * Remove all persisted blocks associated with this torrent broadcast on the executors. + * If removeFromDriver is true, also remove these persisted blocks on the driver. + */ + def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized { + SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) + } } private[spark] case class TorrentBlock( @@ -228,25 +281,10 @@ private[spark] case class TorrentBlock( extends Serializable private[spark] case class TorrentInfo( - @transient arrayOfBlocks : Array[TorrentBlock], + @transient arrayOfBlocks: Array[TorrentBlock], totalBlocks: Int, totalBytes: Int) extends Serializable { @transient var hasBlocks = 0 } - -/** - * A [[BroadcastFactory]] that creates a torrent-based implementation of broadcast. - */ -class TorrentBroadcastFactory extends BroadcastFactory { - - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { - TorrentBroadcast.initialize(isDriver, conf) - } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new TorrentBroadcast[T](value_, isLocal, id) - - def stop() { TorrentBroadcast.stop() } -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala new file mode 100644 index 0000000000000..d216b58718148 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.broadcast + +import org.apache.spark.{SecurityManager, SparkConf} + +/** + * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like + * protocol to do a distributed transfer of the broadcasted data to the executors. Refer to + * [[org.apache.spark.broadcast.TorrentBroadcast]] for more details. + */ +class TorrentBroadcastFactory extends BroadcastFactory { + + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { + TorrentBroadcast.initialize(isDriver, conf) + } + + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + new TorrentBroadcast[T](value_, isLocal, id) + + def stop() { TorrentBroadcast.stop() } + + /** + * Remove all persisted state associated with the torrent broadcast with the given ID. + * @param removeFromDriver Whether to remove state from the driver. + * @param blocking Whether to block until unbroadcasted + */ + def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { + TorrentBroadcast.unpersist(id, removeFromDriver, blocking) + } +} diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 6b0a972f0bbe0..bdf586351ac14 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -17,7 +17,6 @@ package org.apache.spark.network -import java.net._ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index c43823bd769b7..bf3c57ad41eb2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -138,6 +138,8 @@ abstract class RDD[T: ClassTag]( "Cannot change storage level of an RDD after it was already assigned a level") } sc.persistRDD(this) + // Register the RDD with the ContextCleaner for automatic GC-based cleanup + sc.cleaner.foreach(_.registerRDDForCleanup(this)) storageLevel = newLevel this } @@ -156,7 +158,7 @@ abstract class RDD[T: ClassTag]( */ def unpersist(blocking: Boolean = true): RDD[T] = { logInfo("Removing RDD " + id + " from persistence list") - sc.unpersistRDD(this, blocking) + sc.unpersistRDD(id, blocking) storageLevel = StorageLevel.NONE this } @@ -1141,5 +1143,4 @@ abstract class RDD[T: ClassTag]( def toJavaRDD() : JavaRDD[T] = { new JavaRDD(this)(elementClassTag) } - } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 442a95bb2c44b..6368665f249ee 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -32,7 +32,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId} -import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils} +import org.apache.spark.util.Utils /** * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of @@ -80,13 +80,13 @@ class DAGScheduler( private[scheduler] def numTotalJobs: Int = nextJobId.get() private val nextStageId = new AtomicInteger(0) - private[scheduler] val jobIdToStageIds = new TimeStampedHashMap[Int, HashSet[Int]] - private[scheduler] val stageIdToJobIds = new TimeStampedHashMap[Int, HashSet[Int]] - private[scheduler] val stageIdToStage = new TimeStampedHashMap[Int, Stage] - private[scheduler] val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] + private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]] + private[scheduler] val stageIdToJobIds = new HashMap[Int, HashSet[Int]] + private[scheduler] val stageIdToStage = new HashMap[Int, Stage] + private[scheduler] val shuffleToMapStage = new HashMap[Int, Stage] private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob] private[scheduler] val resultStageToJob = new HashMap[Stage, ActiveJob] - private[scheduler] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo] + private[scheduler] val stageToInfos = new HashMap[Stage, StageInfo] // Stages we need to run whose parents aren't done private[scheduler] val waitingStages = new HashSet[Stage] @@ -98,7 +98,7 @@ class DAGScheduler( private[scheduler] val failedStages = new HashSet[Stage] // Missing tasks from each stage - private[scheduler] val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]] + private[scheduler] val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] private[scheduler] val activeJobs = new HashSet[ActiveJob] @@ -113,9 +113,6 @@ class DAGScheduler( // stray messages to detect. private val failedEpoch = new HashMap[String, Long] - private val metadataCleaner = - new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup, env.conf) - taskScheduler.setDAGScheduler(this) /** @@ -258,7 +255,7 @@ class DAGScheduler( : Stage = { val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite) - if (mapOutputTracker.has(shuffleDep.shuffleId)) { + if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) val locs = MapOutputTracker.deserializeMapStatuses(serLocs) for (i <- 0 until locs.size) { @@ -390,6 +387,9 @@ class DAGScheduler( stageIdToStage -= stageId stageIdToJobIds -= stageId + ShuffleMapTask.removeStage(stageId) + ResultTask.removeStage(stageId) + logDebug("After removal of stage %d, remaining stages = %d" .format(stageId, stageIdToStage.size)) } @@ -1084,26 +1084,10 @@ class DAGScheduler( Nil } - private def cleanup(cleanupTime: Long) { - Map( - "stageIdToStage" -> stageIdToStage, - "shuffleToMapStage" -> shuffleToMapStage, - "pendingTasks" -> pendingTasks, - "stageToInfos" -> stageToInfos, - "jobIdToStageIds" -> jobIdToStageIds, - "stageIdToJobIds" -> stageIdToJobIds). - foreach { case (s, t) => - val sizeBefore = t.size - t.clearOldValues(cleanupTime) - logInfo("%s %d --> %d".format(s, sizeBefore, t.size)) - } - } - def stop() { if (eventProcessActor != null) { eventProcessActor ! StopDAGScheduler } - metadataCleaner.cancel() taskScheduler.stop() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 3fc6cc9850feb..083fb895d8696 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -20,21 +20,17 @@ package org.apache.spark.scheduler import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} +import scala.collection.mutable.HashMap + import org.apache.spark._ -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.RDDCheckpointData -import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} +import org.apache.spark.rdd.{RDD, RDDCheckpointData} private[spark] object ResultTask { // A simple map between the stage id to the serialized byte array of a task. // Served as a cache for task serialization because serialization can be // expensive on the master node if it needs to launch thousands of tasks. - val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - - // TODO: This object shouldn't have global variables - val metadataCleaner = new MetadataCleaner( - MetadataCleanerType.RESULT_TASK, serializedInfoCache.clearOldValues, new SparkConf) + private val serializedInfoCache = new HashMap[Int, Array[Byte]] def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = { @@ -67,6 +63,10 @@ private[spark] object ResultTask { (rdd, func) } + def removeStage(stageId: Int) { + serializedInfoCache.remove(stageId) + } + def clearCache() { synchronized { serializedInfoCache.clear() diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 2a9edf4a76b97..23f3b3e824762 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -24,22 +24,16 @@ import scala.collection.mutable.HashMap import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.RDDCheckpointData +import org.apache.spark.rdd.{RDD, RDDCheckpointData} import org.apache.spark.serializer.Serializer import org.apache.spark.storage._ -import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} private[spark] object ShuffleMapTask { // A simple map between the stage id to the serialized byte array of a task. // Served as a cache for task serialization because serialization can be // expensive on the master node if it needs to launch thousands of tasks. - val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - - // TODO: This object shouldn't have global variables - val metadataCleaner = new MetadataCleaner( - MetadataCleanerType.SHUFFLE_MAP_TASK, serializedInfoCache.clearOldValues, new SparkConf) + private val serializedInfoCache = new HashMap[Int, Array[Byte]] def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { synchronized { @@ -80,6 +74,10 @@ private[spark] object ShuffleMapTask { HashMap(set.toSeq: _*) } + def removeStage(stageId: Int) { + serializedInfoCache.remove(stageId) + } + def clearCache() { synchronized { serializedInfoCache.clear() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index a92922166f595..acd152dda89d4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -42,7 +42,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode * * THREADING: SchedulerBackends and task-submitting clients can call this class from multiple * threads, so it needs locks in public API methods to maintain its state. In addition, some - * SchedulerBackends sycnchronize on themselves when they want to send events here, and then + * SchedulerBackends synchronize on themselves when they want to send events here, and then * acquire a lock on us, so we need to make sure that we don't try to lock the backend while * we are holding a lock on ourselves. */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 301d784b350a3..cffea28fbf794 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -34,7 +34,7 @@ private[spark] sealed abstract class BlockId { def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None def isRDD = isInstanceOf[RDDBlockId] def isShuffle = isInstanceOf[ShuffleBlockId] - def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId] + def isBroadcast = isInstanceOf[BroadcastBlockId] override def toString = name override def hashCode = name.hashCode @@ -48,18 +48,13 @@ private[spark] case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockI def name = "rdd_" + rddId + "_" + splitIndex } -private[spark] -case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { +private[spark] case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) + extends BlockId { def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId } -private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId { - def name = "broadcast_" + broadcastId -} - -private[spark] -case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId { - def name = broadcastId.name + "_" + hType +private[spark] case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId { + def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) } private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId { @@ -83,8 +78,7 @@ private[spark] case class TestBlockId(id: String) extends BlockId { private[spark] object BlockId { val RDD = "rdd_([0-9]+)_([0-9]+)".r val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r - val BROADCAST = "broadcast_([0-9]+)".r - val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r + val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r val TEST = "test_(.*)".r @@ -95,10 +89,8 @@ private[spark] object BlockId { RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) - case BROADCAST(broadcastId) => - BroadcastBlockId(broadcastId.toLong) - case BROADCAST_HELPER(broadcastId, hType) => - BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType) + case BROADCAST(broadcastId, field) => + BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_")) case TASKRESULT(taskId) => TaskResultBlockId(taskId.toLong) case STREAM(streamId, uniqueId) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 19138d9dde697..b021564477c47 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -19,20 +19,22 @@ package org.apache.spark.storage import java.io.{File, InputStream, OutputStream} import java.nio.{ByteBuffer, MappedByteBuffer} + import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.concurrent.{Await, Future} import scala.concurrent.duration._ import scala.util.Random + import akka.actor.{ActorSystem, Cancellable, Props} import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream} import sun.nio.ch.DirectBuffer -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException} + +import org.apache.spark.{Logging, MapOutputTracker, SecurityManager, SparkConf, SparkEnv, SparkException} import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer import org.apache.spark.util._ - sealed trait Values case class ByteBufferValues(buffer: ByteBuffer) extends Values @@ -46,7 +48,8 @@ private[spark] class BlockManager( val defaultSerializer: Serializer, maxMemory: Long, val conf: SparkConf, - securityManager: SecurityManager) + securityManager: SecurityManager, + mapOutputTracker: MapOutputTracker) extends Logging { val shuffleBlockManager = new ShuffleBlockManager(this) @@ -55,7 +58,7 @@ private[spark] class BlockManager( private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] - private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) + private[storage] val memoryStore = new MemoryStore(this, maxMemory) private[storage] val diskStore = new DiskStore(this, diskBlockManager) var tachyonInitialized = false private[storage] lazy val tachyonStore: TachyonStore = { @@ -98,7 +101,7 @@ private[spark] class BlockManager( val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf) - val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)), + val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this, mapOutputTracker)), name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) // Pending re-registration action being executed asynchronously or null if none @@ -137,9 +140,10 @@ private[spark] class BlockManager( master: BlockManagerMaster, serializer: Serializer, conf: SparkConf, - securityManager: SecurityManager) = { + securityManager: SecurityManager, + mapOutputTracker: MapOutputTracker) = { this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), - conf, securityManager) + conf, securityManager, mapOutputTracker) } /** @@ -217,9 +221,26 @@ private[spark] class BlockManager( } /** - * Get storage level of local block. If no info exists for the block, then returns null. + * Get the BlockStatus for the block identified by the given ID, if it exists. + * NOTE: This is mainly for testing, and it doesn't fetch information from Tachyon. + */ + def getStatus(blockId: BlockId): Option[BlockStatus] = { + blockInfo.get(blockId).map { info => + val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L + val diskSize = if (diskStore.contains(blockId)) diskStore.getSize(blockId) else 0L + // Assume that block is not in Tachyon + BlockStatus(info.level, memSize, diskSize, 0L) + } + } + + /** + * Get the ids of existing blocks that match the given filter. Note that this will + * query the blocks stored in the disk block manager (that the block manager + * may not know of). */ - def getLevel(blockId: BlockId): StorageLevel = blockInfo.get(blockId).map(_.level).orNull + def getMatchingBlockIds(filter: BlockId => Boolean): Seq[BlockId] = { + (blockInfo.keys ++ diskBlockManager.getAllBlocks()).filter(filter).toSeq + } /** * Tell the master about the current storage status of a block. This will send a block update @@ -525,9 +546,8 @@ private[spark] class BlockManager( /** * A short circuited method to get a block writer that can write data directly to disk. - * The Block will be appended to the File specified by filename. - * This is currently used for writing shuffle files out. Callers should handle error - * cases. + * The Block will be appended to the File specified by filename. This is currently used for + * writing shuffle files out. Callers should handle error cases. */ def getDiskWriter( blockId: BlockId, @@ -863,11 +883,22 @@ private[spark] class BlockManager( * @return The number of blocks removed. */ def removeRdd(rddId: Int): Int = { - // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps - // from RDD.id to blocks. + // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks. logInfo("Removing RDD " + rddId) val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) - blocksToRemove.foreach(blockId => removeBlock(blockId, tellMaster = false)) + blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) } + blocksToRemove.size + } + + /** + * Remove all blocks belonging to the given broadcast. + */ + def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = { + logInfo("Removing broadcast " + broadcastId) + val blocksToRemove = blockInfo.keys.collect { + case bid @ BroadcastBlockId(`broadcastId`, _) => bid + } + blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) } blocksToRemove.size } @@ -908,10 +939,10 @@ private[spark] class BlockManager( } private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)) { - val iterator = blockInfo.internalMap.entrySet().iterator() + val iterator = blockInfo.getEntrySet.iterator while (iterator.hasNext) { val entry = iterator.next() - val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2) + val (id, info, time) = (entry.getKey, entry.getValue.value, entry.getValue.timestamp) if (time < cleanupTime && shouldDrop(id)) { info.synchronized { val level = info.level @@ -935,7 +966,7 @@ private[spark] class BlockManager( def shouldCompress(blockId: BlockId): Boolean = blockId match { case ShuffleBlockId(_, _, _) => compressShuffle - case BroadcastBlockId(_) => compressBroadcast + case BroadcastBlockId(_, _) => compressBroadcast case RDDBlockId(_, _) => compressRdds case TempBlockId(_) => compressShuffleSpill case _ => false diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 4bc1b407ad106..7897fade2df2b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -81,6 +81,14 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } + /** + * Check if block manager master has a block. Note that this can be used to check for only + * those blocks that are reported to block manager master. + */ + def contains(blockId: BlockId) = { + !getLocations(blockId).isEmpty + } + /** Get ids of other nodes in the cluster from the driver */ def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = { val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers)) @@ -99,12 +107,10 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log askDriverWithReply(RemoveBlock(blockId)) } - /** - * Remove all blocks belonging to the given RDD. - */ + /** Remove all blocks belonging to the given RDD. */ def removeRdd(rddId: Int, blocking: Boolean) { val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) - future onFailure { + future.onFailure { case e: Throwable => logError("Failed to remove RDD " + rddId, e) } if (blocking) { @@ -112,6 +118,31 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log } } + /** Remove all blocks belonging to the given shuffle. */ + def removeShuffle(shuffleId: Int, blocking: Boolean) { + val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) + future.onFailure { + case e: Throwable => logError("Failed to remove shuffle " + shuffleId, e) + } + if (blocking) { + Await.result(future, timeout) + } + } + + /** Remove all blocks belonging to the given broadcast. */ + def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) { + val future = askDriverWithReply[Future[Seq[Int]]]( + RemoveBroadcast(broadcastId, removeFromMaster)) + future.onFailure { + case e: Throwable => + logError("Failed to remove broadcast " + broadcastId + + " with removeFromMaster = " + removeFromMaster, e) + } + if (blocking) { + Await.result(future, timeout) + } + } + /** * Return the memory status for each block manager, in the form of a map from * the block manager's id to two long values. The first value is the maximum @@ -126,6 +157,51 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log askDriverWithReply[Array[StorageStatus]](GetStorageStatus) } + /** + * Return the block's status on all block managers, if any. NOTE: This is a + * potentially expensive operation and should only be used for testing. + * + * If askSlaves is true, this invokes the master to query each block manager for the most + * updated block statuses. This is useful when the master is not informed of the given block + * by all block managers. + */ + def getBlockStatus( + blockId: BlockId, + askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = { + val msg = GetBlockStatus(blockId, askSlaves) + /* + * To avoid potential deadlocks, the use of Futures is necessary, because the master actor + * should not block on waiting for a block manager, which can in turn be waiting for the + * master actor for a response to a prior message. + */ + val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) + val (blockManagerIds, futures) = response.unzip + val result = Await.result(Future.sequence(futures), timeout) + if (result == null) { + throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId) + } + val blockStatus = result.asInstanceOf[Iterable[Option[BlockStatus]]] + blockManagerIds.zip(blockStatus).flatMap { case (blockManagerId, status) => + status.map { s => (blockManagerId, s) } + }.toMap + } + + /** + * Return a list of ids of existing blocks such that the ids match the given filter. NOTE: This + * is a potentially expensive operation and should only be used for testing. + * + * If askSlaves is true, this invokes the master to query each block manager for the most + * updated block statuses. This is useful when the master is not informed of the given block + * by all block managers. + */ + def getMatchingBlockIds( + filter: BlockId => Boolean, + askSlaves: Boolean): Seq[BlockId] = { + val msg = GetMatchingBlockIds(filter, askSlaves) + val future = askDriverWithReply[Future[Seq[BlockId]]](msg) + Await.result(future, timeout) + } + /** Stop the driver actor, called only on the Spark driver node */ def stop() { if (driverActor != null) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 378f4cadc17d7..c57b6e8391b13 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -94,9 +94,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case GetStorageStatus => sender ! storageStatus + case GetBlockStatus(blockId, askSlaves) => + sender ! blockStatus(blockId, askSlaves) + + case GetMatchingBlockIds(filter, askSlaves) => + sender ! getMatchingBlockIds(filter, askSlaves) + case RemoveRdd(rddId) => sender ! removeRdd(rddId) + case RemoveShuffle(shuffleId) => + sender ! removeShuffle(shuffleId) + + case RemoveBroadcast(broadcastId, removeFromDriver) => + sender ! removeBroadcast(broadcastId, removeFromDriver) + case RemoveBlock(blockId) => removeBlockFromWorkers(blockId) sender ! true @@ -140,9 +152,41 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // The dispatcher is used as an implicit argument into the Future sequence construction. import context.dispatcher val removeMsg = RemoveRdd(rddId) - Future.sequence(blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] - }.toSeq) + Future.sequence( + blockManagerInfo.values.map { bm => + bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + }.toSeq + ) + } + + private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { + // Nothing to do in the BlockManagerMasterActor data structures + import context.dispatcher + val removeMsg = RemoveShuffle(shuffleId) + Future.sequence( + blockManagerInfo.values.map { bm => + bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean] + }.toSeq + ) + } + + /** + * Delegate RemoveBroadcast messages to each BlockManager because the master may not notified + * of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed + * from the executors, but not from the driver. + */ + private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = { + // TODO: Consolidate usages of + import context.dispatcher + val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver) + val requiredBlockManagers = blockManagerInfo.values.filter { info => + removeFromDriver || info.blockManagerId.executorId != "" + } + Future.sequence( + requiredBlockManagers.map { bm => + bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + }.toSeq + ) } private def removeBlockManager(blockManagerId: BlockManagerId) { @@ -225,6 +269,61 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus }.toArray } + /** + * Return the block's status for all block managers, if any. NOTE: This is a + * potentially expensive operation and should only be used for testing. + * + * If askSlaves is true, the master queries each block manager for the most updated block + * statuses. This is useful when the master is not informed of the given block by all block + * managers. + */ + private def blockStatus( + blockId: BlockId, + askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = { + import context.dispatcher + val getBlockStatus = GetBlockStatus(blockId) + /* + * Rather than blocking on the block status query, master actor should simply return + * Futures to avoid potential deadlocks. This can arise if there exists a block manager + * that is also waiting for this master actor's response to a previous message. + */ + blockManagerInfo.values.map { info => + val blockStatusFuture = + if (askSlaves) { + info.slaveActor.ask(getBlockStatus)(akkaTimeout).mapTo[Option[BlockStatus]] + } else { + Future { info.getStatus(blockId) } + } + (info.blockManagerId, blockStatusFuture) + }.toMap + } + + /** + * Return the ids of blocks present in all the block managers that match the given filter. + * NOTE: This is a potentially expensive operation and should only be used for testing. + * + * If askSlaves is true, the master queries each block manager for the most updated block + * statuses. This is useful when the master is not informed of the given block by all block + * managers. + */ + private def getMatchingBlockIds( + filter: BlockId => Boolean, + askSlaves: Boolean): Future[Seq[BlockId]] = { + import context.dispatcher + val getMatchingBlockIds = GetMatchingBlockIds(filter) + Future.sequence( + blockManagerInfo.values.map { info => + val future = + if (askSlaves) { + info.slaveActor.ask(getMatchingBlockIds)(akkaTimeout).mapTo[Seq[BlockId]] + } else { + Future { info.blocks.keys.filter(filter).toSeq } + } + future + } + ).map(_.flatten.toSeq) + } + private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { @@ -334,6 +433,8 @@ private[spark] class BlockManagerInfo( logInfo("Registering block manager %s with %s RAM".format( blockManagerId.hostPort, Utils.bytesToString(maxMem))) + def getStatus(blockId: BlockId) = Option(_blocks.get(blockId)) + def updateLastSeenMs() { _lastSeenMs = System.currentTimeMillis() } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 8a36b5cc42dfd..2b53bf33b5fba 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -34,6 +34,13 @@ private[storage] object BlockManagerMessages { // Remove all blocks belonging to a specific RDD. case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave + // Remove all blocks belonging to a specific shuffle. + case class RemoveShuffle(shuffleId: Int) extends ToBlockManagerSlave + + // Remove all blocks belonging to a specific broadcast. + case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) + extends ToBlockManagerSlave + ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. @@ -80,7 +87,8 @@ private[storage] object BlockManagerMessages { } object UpdateBlockInfo { - def apply(blockManagerId: BlockManagerId, + def apply( + blockManagerId: BlockManagerId, blockId: BlockId, storageLevel: StorageLevel, memSize: Long, @@ -108,7 +116,13 @@ private[storage] object BlockManagerMessages { case object GetMemoryStatus extends ToBlockManagerMaster - case object ExpireDeadHosts extends ToBlockManagerMaster - case object GetStorageStatus extends ToBlockManagerMaster + + case class GetBlockStatus(blockId: BlockId, askSlaves: Boolean = true) + extends ToBlockManagerMaster + + case class GetMatchingBlockIds(filter: BlockId => Boolean, askSlaves: Boolean = true) + extends ToBlockManagerMaster + + case object ExpireDeadHosts extends ToBlockManagerMaster } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index bcfb82d3c7336..6d4db064dff58 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -17,8 +17,11 @@ package org.apache.spark.storage -import akka.actor.Actor +import scala.concurrent.Future +import akka.actor.{ActorRef, Actor} + +import org.apache.spark.{Logging, MapOutputTracker} import org.apache.spark.storage.BlockManagerMessages._ /** @@ -26,14 +29,59 @@ import org.apache.spark.storage.BlockManagerMessages._ * this is used to remove blocks from the slave's BlockManager. */ private[storage] -class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor { - override def receive = { +class BlockManagerSlaveActor( + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker) + extends Actor with Logging { + + import context.dispatcher + // Operations that involve removing blocks may be slow and should be done asynchronously + override def receive = { case RemoveBlock(blockId) => - blockManager.removeBlock(blockId) + doAsync[Boolean]("removing block " + blockId, sender) { + blockManager.removeBlock(blockId) + true + } case RemoveRdd(rddId) => - val numBlocksRemoved = blockManager.removeRdd(rddId) - sender ! numBlocksRemoved + doAsync[Int]("removing RDD " + rddId, sender) { + blockManager.removeRdd(rddId) + } + + case RemoveShuffle(shuffleId) => + doAsync[Boolean]("removing shuffle " + shuffleId, sender) { + if (mapOutputTracker != null) { + mapOutputTracker.unregisterShuffle(shuffleId) + } + blockManager.shuffleBlockManager.removeShuffle(shuffleId) + } + + case RemoveBroadcast(broadcastId, tellMaster) => + doAsync[Int]("removing broadcast " + broadcastId, sender) { + blockManager.removeBroadcast(broadcastId, tellMaster) + } + + case GetBlockStatus(blockId, _) => + sender ! blockManager.getStatus(blockId) + + case GetMatchingBlockIds(filter, _) => + sender ! blockManager.getMatchingBlockIds(filter) + } + + private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) { + val future = Future { + logDebug(actionMessage) + body + } + future.onSuccess { case response => + logDebug("Done " + actionMessage + ", response is " + response) + responseActor ! response + logDebug("Sent response: " + response + " to " + responseActor) + } + future.onFailure { case t: Throwable => + logError("Error in " + actionMessage, t) + responseActor ! null.asInstanceOf[T] + } } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index f3e1c38744d78..7a24c8f57f43b 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -90,6 +90,20 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD def getFile(blockId: BlockId): File = getFile(blockId.name) + /** Check if disk block manager has a block. */ + def containsBlock(blockId: BlockId): Boolean = { + getBlockLocation(blockId).file.exists() + } + + /** List all the blocks currently stored on disk by the disk manager. */ + def getAllBlocks(): Seq[BlockId] = { + // Get all the files inside the array of array of directories + subDirs.flatten.filter(_ != null).flatMap { dir => + val files = dir.list() + if (files != null) files else Seq.empty + }.map(BlockId.apply) + } + /** Produces a unique block id and File suitable for intermediate results. */ def createTempBlock(): (TempBlockId, File) = { var blockId = new TempBlockId(UUID.randomUUID()) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index bb07c8cb134cc..4cd4cdbd9909d 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -169,23 +169,43 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { throw new IllegalStateException("Failed to find shuffle block: " + id) } + /** Remove all the blocks / files and metadata related to a particular shuffle. */ + def removeShuffle(shuffleId: ShuffleId): Boolean = { + // Do not change the ordering of this, if shuffleStates should be removed only + // after the corresponding shuffle blocks have been removed + val cleaned = removeShuffleBlocks(shuffleId) + shuffleStates.remove(shuffleId) + cleaned + } + + /** Remove all the blocks / files related to a particular shuffle. */ + private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = { + shuffleStates.get(shuffleId) match { + case Some(state) => + if (consolidateShuffleFiles) { + for (fileGroup <- state.allFileGroups; file <- fileGroup.files) { + file.delete() + } + } else { + for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) { + val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) + blockManager.diskBlockManager.getFile(blockId).delete() + } + } + logInfo("Deleted all files for shuffle " + shuffleId) + true + case None => + logInfo("Could not find files for shuffle " + shuffleId + " for deleting") + false + } + } + private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = { "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId) } private def cleanup(cleanupTime: Long) { - shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => { - if (consolidateShuffleFiles) { - for (fileGroup <- state.allFileGroups; file <- fileGroup.files) { - file.delete() - } - } else { - for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) { - val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) - blockManager.diskBlockManager.getFile(blockId).delete() - } - } - }) + shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId)) } } diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index 226ed2a132b00..a107c5182b3be 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -22,7 +22,7 @@ import java.util.concurrent.ArrayBlockingQueue import akka.actor._ import util.Random -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer @@ -48,7 +48,7 @@ private[spark] object ThreadingTest { val block = (1 to blockSize).map(_ => Random.nextInt()) val level = randomLevel() val startTime = System.currentTimeMillis() - manager.put(blockId, block.iterator, level, true) + manager.put(blockId, block.iterator, level, tellMaster = true) println("Pushed block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms") queue.add((blockId, block)) } @@ -101,7 +101,7 @@ private[spark] object ThreadingTest { conf) val blockManager = new BlockManager( "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf, - new SecurityManager(conf)) + new SecurityManager(conf), new MapOutputTrackerMaster(conf)) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) producers.foreach(_.start) diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 0448919e09161..7ebed5105b9fd 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -62,8 +62,8 @@ private[spark] class MetadataCleaner( private[spark] object MetadataCleanerType extends Enumeration { - val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK, - SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value + val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, BLOCK_MANAGER, + SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value type MetadataCleanerType = Value @@ -78,15 +78,16 @@ private[spark] object MetadataCleaner { conf.getInt("spark.cleaner.ttl", -1) } - def getDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType): Int = - { - conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString) - .toInt + def getDelaySeconds( + conf: SparkConf, + cleanerType: MetadataCleanerType.MetadataCleanerType): Int = { + conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString).toInt } - def setDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType, - delay: Int) - { + def setDelaySeconds( + conf: SparkConf, + cleanerType: MetadataCleanerType.MetadataCleanerType, + delay: Int) { conf.set(MetadataCleanerType.systemProperty(cleanerType), delay.toString) } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index ddbd084ed7f01..8de75ba9a9c92 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -17,48 +17,54 @@ package org.apache.spark.util +import java.util.Set +import java.util.Map.Entry import java.util.concurrent.ConcurrentHashMap -import scala.collection.JavaConversions -import scala.collection.immutable -import scala.collection.mutable.Map +import scala.collection.{JavaConversions, mutable} import org.apache.spark.Logging +private[spark] case class TimeStampedValue[V](value: V, timestamp: Long) + /** * This is a custom implementation of scala.collection.mutable.Map which stores the insertion * timestamp along with each key-value pair. If specified, the timestamp of each pair can be * updated every time it is accessed. Key-value pairs whose timestamp are older than a particular * threshold time can then be removed using the clearOldValues method. This is intended to * be a drop-in replacement of scala.collection.mutable.HashMap. - * @param updateTimeStampOnGet When enabled, the timestamp of a pair will be - * updated when it is accessed + * + * @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed */ -class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false) - extends Map[A, B]() with Logging { - val internalMap = new ConcurrentHashMap[A, (B, Long)]() +private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false) + extends mutable.Map[A, B]() with Logging { + + private val internalMap = new ConcurrentHashMap[A, TimeStampedValue[B]]() def get(key: A): Option[B] = { val value = internalMap.get(key) if (value != null && updateTimeStampOnGet) { - internalMap.replace(key, value, (value._1, currentTime)) + internalMap.replace(key, value, TimeStampedValue(value.value, currentTime)) } - Option(value).map(_._1) + Option(value).map(_.value) } def iterator: Iterator[(A, B)] = { - val jIterator = internalMap.entrySet().iterator() - JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue._1)) + val jIterator = getEntrySet.iterator + JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue.value)) } - override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = { + def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet + + override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = { val newMap = new TimeStampedHashMap[A, B1] - newMap.internalMap.putAll(this.internalMap) - newMap.internalMap.put(kv._1, (kv._2, currentTime)) + val oldInternalMap = this.internalMap.asInstanceOf[ConcurrentHashMap[A, TimeStampedValue[B1]]] + newMap.internalMap.putAll(oldInternalMap) + kv match { case (a, b) => newMap.internalMap.put(a, TimeStampedValue(b, currentTime)) } newMap } - override def - (key: A): Map[A, B] = { + override def - (key: A): mutable.Map[A, B] = { val newMap = new TimeStampedHashMap[A, B] newMap.internalMap.putAll(this.internalMap) newMap.internalMap.remove(key) @@ -66,17 +72,10 @@ class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false) } override def += (kv: (A, B)): this.type = { - internalMap.put(kv._1, (kv._2, currentTime)) + kv match { case (a, b) => internalMap.put(a, TimeStampedValue(b, currentTime)) } this } - // Should we return previous value directly or as Option ? - def putIfAbsent(key: A, value: B): Option[B] = { - val prev = internalMap.putIfAbsent(key, (value, currentTime)) - if (prev != null) Some(prev._1) else None - } - - override def -= (key: A): this.type = { internalMap.remove(key) this @@ -87,53 +86,65 @@ class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false) } override def apply(key: A): B = { - val value = internalMap.get(key) - if (value == null) throw new NoSuchElementException() - value._1 + get(key).getOrElse { throw new NoSuchElementException() } } - override def filter(p: ((A, B)) => Boolean): Map[A, B] = { - JavaConversions.mapAsScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p) + override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = { + JavaConversions.mapAsScalaConcurrentMap(internalMap) + .map { case (k, TimeStampedValue(v, t)) => (k, v) } + .filter(p) } - override def empty: Map[A, B] = new TimeStampedHashMap[A, B]() + override def empty: mutable.Map[A, B] = new TimeStampedHashMap[A, B]() override def size: Int = internalMap.size override def foreach[U](f: ((A, B)) => U) { - val iterator = internalMap.entrySet().iterator() - while(iterator.hasNext) { - val entry = iterator.next() - val kv = (entry.getKey, entry.getValue._1) + val it = getEntrySet.iterator + while(it.hasNext) { + val entry = it.next() + val kv = (entry.getKey, entry.getValue.value) f(kv) } } - def toMap: immutable.Map[A, B] = iterator.toMap + def putIfAbsent(key: A, value: B): Option[B] = { + val prev = internalMap.putIfAbsent(key, TimeStampedValue(value, currentTime)) + Option(prev).map(_.value) + } + + def putAll(map: Map[A, B]) { + map.foreach { case (k, v) => update(k, v) } + } + + def toMap: Map[A, B] = iterator.toMap - /** - * Removes old key-value pairs that have timestamp earlier than `threshTime`, - * calling the supplied function on each such entry before removing. - */ def clearOldValues(threshTime: Long, f: (A, B) => Unit) { - val iterator = internalMap.entrySet().iterator() - while (iterator.hasNext) { - val entry = iterator.next() - if (entry.getValue._2 < threshTime) { - f(entry.getKey, entry.getValue._1) + val it = getEntrySet.iterator + while (it.hasNext) { + val entry = it.next() + if (entry.getValue.timestamp < threshTime) { + f(entry.getKey, entry.getValue.value) logDebug("Removing key " + entry.getKey) - iterator.remove() + it.remove() } } } - /** - * Removes old key-value pairs that have timestamp earlier than `threshTime` - */ + /** Removes old key-value pairs that have timestamp earlier than `threshTime`. */ def clearOldValues(threshTime: Long) { clearOldValues(threshTime, (_, _) => ()) } - private def currentTime: Long = System.currentTimeMillis() + private def currentTime: Long = System.currentTimeMillis + // For testing + + def getTimeStampedValue(key: A): Option[TimeStampedValue[B]] = { + Option(internalMap.get(key)) + } + + def getTimestamp(key: A): Option[Long] = { + getTimeStampedValue(key).map(_.timestamp) + } } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala new file mode 100644 index 0000000000000..b65017d6806c6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.lang.ref.WeakReference +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.mutable + +import org.apache.spark.Logging + +/** + * A wrapper of TimeStampedHashMap that ensures the values are weakly referenced and timestamped. + * + * If the value is garbage collected and the weak reference is null, get() will return a + * non-existent value. These entries are removed from the map periodically (every N inserts), as + * their values are no longer strongly reachable. Further, key-value pairs whose timestamps are + * older than a particular threshold can be removed using the clearOldValues method. + * + * TimeStampedWeakValueHashMap exposes a scala.collection.mutable.Map interface, which allows it + * to be a drop-in replacement for Scala HashMaps. Internally, it uses a Java ConcurrentHashMap, + * so all operations on this HashMap are thread-safe. + * + * @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed. + */ +private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boolean = false) + extends mutable.Map[A, B]() with Logging { + + import TimeStampedWeakValueHashMap._ + + private val internalMap = new TimeStampedHashMap[A, WeakReference[B]](updateTimeStampOnGet) + private val insertCount = new AtomicInteger(0) + + /** Return a map consisting only of entries whose values are still strongly reachable. */ + private def nonNullReferenceMap = internalMap.filter { case (_, ref) => ref.get != null } + + def get(key: A): Option[B] = internalMap.get(key) + + def iterator: Iterator[(A, B)] = nonNullReferenceMap.iterator + + override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = { + val newMap = new TimeStampedWeakValueHashMap[A, B1] + val oldMap = nonNullReferenceMap.asInstanceOf[mutable.Map[A, WeakReference[B1]]] + newMap.internalMap.putAll(oldMap.toMap) + newMap.internalMap += kv + newMap + } + + override def - (key: A): mutable.Map[A, B] = { + val newMap = new TimeStampedWeakValueHashMap[A, B] + newMap.internalMap.putAll(nonNullReferenceMap.toMap) + newMap.internalMap -= key + newMap + } + + override def += (kv: (A, B)): this.type = { + internalMap += kv + if (insertCount.incrementAndGet() % CLEAR_NULL_VALUES_INTERVAL == 0) { + clearNullValues() + } + this + } + + override def -= (key: A): this.type = { + internalMap -= key + this + } + + override def update(key: A, value: B) = this += ((key, value)) + + override def apply(key: A): B = internalMap.apply(key) + + override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = nonNullReferenceMap.filter(p) + + override def empty: mutable.Map[A, B] = new TimeStampedWeakValueHashMap[A, B]() + + override def size: Int = internalMap.size + + override def foreach[U](f: ((A, B)) => U) = nonNullReferenceMap.foreach(f) + + def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value) + + def toMap: Map[A, B] = iterator.toMap + + /** Remove old key-value pairs with timestamps earlier than `threshTime`. */ + def clearOldValues(threshTime: Long) = internalMap.clearOldValues(threshTime) + + /** Remove entries with values that are no longer strongly reachable. */ + def clearNullValues() { + val it = internalMap.getEntrySet.iterator + while (it.hasNext) { + val entry = it.next() + if (entry.getValue.value.get == null) { + logDebug("Removing key " + entry.getKey + " because it is no longer strongly reachable.") + it.remove() + } + } + } + + // For testing + + def getTimestamp(key: A): Option[Long] = { + internalMap.getTimeStampedValue(key).map(_.timestamp) + } + + def getReference(key: A): Option[WeakReference[B]] = { + internalMap.getTimeStampedValue(key).map(_.value) + } +} + +/** + * Helper methods for converting to and from WeakReferences. + */ +private object TimeStampedWeakValueHashMap { + + // Number of inserts after which entries with null references are removed + val CLEAR_NULL_VALUES_INTERVAL = 100 + + /* Implicit conversion methods to WeakReferences. */ + + implicit def toWeakReference[V](v: V): WeakReference[V] = new WeakReference[V](v) + + implicit def toWeakReferenceTuple[K, V](kv: (K, V)): (K, WeakReference[V]) = { + kv match { case (k, v) => (k, toWeakReference(v)) } + } + + implicit def toWeakReferenceFunction[K, V, R](p: ((K, V)) => R): ((K, WeakReference[V])) => R = { + (kv: (K, WeakReference[V])) => p(kv) + } + + /* Implicit conversion methods from WeakReferences. */ + + implicit def fromWeakReference[V](ref: WeakReference[V]): V = ref.get + + implicit def fromWeakReferenceOption[V](v: Option[WeakReference[V]]): Option[V] = { + v match { + case Some(ref) => Option(fromWeakReference(ref)) + case None => None + } + } + + implicit def fromWeakReferenceTuple[K, V](kv: (K, WeakReference[V])): (K, V) = { + kv match { case (k, v) => (k, fromWeakReference(v)) } + } + + implicit def fromWeakReferenceIterator[K, V]( + it: Iterator[(K, WeakReference[V])]): Iterator[(K, V)] = { + it.map(fromWeakReferenceTuple) + } + + implicit def fromWeakReferenceMap[K, V]( + map: mutable.Map[K, WeakReference[V]]) : mutable.Map[K, V] = { + mutable.Map(map.mapValues(fromWeakReference).toSeq: _*) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 4435b21a7505e..59da51f3e0297 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -499,10 +499,10 @@ private[spark] object Utils extends Logging { private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]() def parseHostPort(hostPort: String): (String, Int) = { - { - // Check cache first. - val cached = hostPortParseResults.get(hostPort) - if (cached != null) return cached + // Check cache first. + val cached = hostPortParseResults.get(hostPort) + if (cached != null) { + return cached } val indx: Int = hostPort.lastIndexOf(':') diff --git a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala index d2e303d81c4c8..c5f24c66ce0c1 100644 --- a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala @@ -56,7 +56,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf, securityManager = securityManagerBad) - val slaveTracker = new MapOutputTracker(conf) + val slaveTracker = new MapOutputTrackerWorker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") val timeout = AkkaUtils.lookupTimeout(conf) @@ -93,7 +93,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = badconf, securityManager = securityManagerBad) - val slaveTracker = new MapOutputTracker(conf) + val slaveTracker = new MapOutputTrackerWorker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") val timeout = AkkaUtils.lookupTimeout(conf) @@ -147,7 +147,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = goodconf, securityManager = securityManagerGood) - val slaveTracker = new MapOutputTracker(conf) + val slaveTracker = new MapOutputTrackerWorker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") val timeout = AkkaUtils.lookupTimeout(conf) @@ -200,7 +200,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = badconf, securityManager = securityManagerBad) - val slaveTracker = new MapOutputTracker(conf) + val slaveTracker = new MapOutputTrackerWorker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") val timeout = AkkaUtils.lookupTimeout(conf) diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index 96ba3929c1685..c9936256a5b95 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -19,68 +19,297 @@ package org.apache.spark import org.scalatest.FunSuite -class BroadcastSuite extends FunSuite with LocalSparkContext { +import org.apache.spark.storage._ +import org.apache.spark.broadcast.{Broadcast, HttpBroadcast} +import org.apache.spark.storage.BroadcastBlockId +class BroadcastSuite extends FunSuite with LocalSparkContext { - override def afterEach() { - super.afterEach() - System.clearProperty("spark.broadcast.factory") - } + private val httpConf = broadcastConf("HttpBroadcastFactory") + private val torrentConf = broadcastConf("TorrentBroadcastFactory") test("Using HttpBroadcast locally") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") - sc = new SparkContext("local", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === Set((1, 10), (2, 10))) + sc = new SparkContext("local", "test", httpConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === Set((1, 10), (2, 10))) } test("Accessing HttpBroadcast variables from multiple threads") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") - sc = new SparkContext("local[10]", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) + sc = new SparkContext("local[10]", "test", httpConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) } test("Accessing HttpBroadcast variables in a local cluster") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") val numSlaves = 4 - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", httpConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) } test("Using TorrentBroadcast locally") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") - sc = new SparkContext("local", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === Set((1, 10), (2, 10))) + sc = new SparkContext("local", "test", torrentConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === Set((1, 10), (2, 10))) } test("Accessing TorrentBroadcast variables from multiple threads") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") - sc = new SparkContext("local[10]", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) + sc = new SparkContext("local[10]", "test", torrentConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) } test("Accessing TorrentBroadcast variables in a local cluster") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") val numSlaves = 4 - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", torrentConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + } + + test("Unpersisting HttpBroadcast on executors only in local mode") { + testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false) + } + + test("Unpersisting HttpBroadcast on executors and driver in local mode") { + testUnpersistHttpBroadcast(distributed = false, removeFromDriver = true) + } + + test("Unpersisting HttpBroadcast on executors only in distributed mode") { + testUnpersistHttpBroadcast(distributed = true, removeFromDriver = false) + } + + test("Unpersisting HttpBroadcast on executors and driver in distributed mode") { + testUnpersistHttpBroadcast(distributed = true, removeFromDriver = true) + } + + test("Unpersisting TorrentBroadcast on executors only in local mode") { + testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = false) + } + + test("Unpersisting TorrentBroadcast on executors and driver in local mode") { + testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = true) + } + + test("Unpersisting TorrentBroadcast on executors only in distributed mode") { + testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = false) + } + + test("Unpersisting TorrentBroadcast on executors and driver in distributed mode") { + testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = true) + } + /** + * Verify the persistence of state associated with an HttpBroadcast in either local mode or + * local-cluster mode (when distributed = true). + * + * This test creates a broadcast variable, uses it on all executors, and then unpersists it. + * In between each step, this test verifies that the broadcast blocks and the broadcast file + * are present only on the expected nodes. + */ + private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) { + val numSlaves = if (distributed) 2 else 0 + + def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id)) + + // Verify that the broadcast file is created, and blocks are persisted only on the driver + def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { + assert(blockIds.size === 1) + val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) + assert(statuses.size === 1) + statuses.head match { case (bm, status) => + assert(bm.executorId === "", "Block should only be on the driver") + assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) + assert(status.memSize > 0, "Block should be in memory store on the driver") + assert(status.diskSize === 0, "Block should not be in disk store on the driver") + } + if (distributed) { + // this file is only generated in distributed mode + assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!") + } + } + + // Verify that blocks are persisted in both the executors and the driver + def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { + assert(blockIds.size === 1) + val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) + assert(statuses.size === numSlaves + 1) + statuses.foreach { case (_, status) => + assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) + assert(status.memSize > 0, "Block should be in memory store") + assert(status.diskSize === 0, "Block should not be in disk store") + } + } + + // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver + // is true. In the latter case, also verify that the broadcast file is deleted on the driver. + def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { + assert(blockIds.size === 1) + val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) + val expectedNumBlocks = if (removeFromDriver) 0 else 1 + val possiblyNot = if (removeFromDriver) "" else " not" + assert(statuses.size === expectedNumBlocks, + "Block should%s be unpersisted on the driver".format(possiblyNot)) + if (distributed && removeFromDriver) { + // this file is only generated in distributed mode + assert(!HttpBroadcast.getFile(blockIds.head.broadcastId).exists, + "Broadcast file should%s be deleted".format(possiblyNot)) + } + } + + testUnpersistBroadcast(distributed, numSlaves, httpConf, getBlockIds, afterCreation, + afterUsingBroadcast, afterUnpersist, removeFromDriver) + } + + /** + * Verify the persistence of state associated with an TorrentBroadcast in a local-cluster. + * + * This test creates a broadcast variable, uses it on all executors, and then unpersists it. + * In between each step, this test verifies that the broadcast blocks are present only on the + * expected nodes. + */ + private def testUnpersistTorrentBroadcast(distributed: Boolean, removeFromDriver: Boolean) { + val numSlaves = if (distributed) 2 else 0 + + def getBlockIds(id: Long) = { + val broadcastBlockId = BroadcastBlockId(id) + val metaBlockId = BroadcastBlockId(id, "meta") + // Assume broadcast value is small enough to fit into 1 piece + val pieceBlockId = BroadcastBlockId(id, "piece0") + if (distributed) { + // the metadata and piece blocks are generated only in distributed mode + Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId) + } else { + Seq[BroadcastBlockId](broadcastBlockId) + } + } + + // Verify that blocks are persisted only on the driver + def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { + blockIds.foreach { blockId => + val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) + assert(statuses.size === 1) + statuses.head match { case (bm, status) => + assert(bm.executorId === "", "Block should only be on the driver") + assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) + assert(status.memSize > 0, "Block should be in memory store on the driver") + assert(status.diskSize === 0, "Block should not be in disk store on the driver") + } + } + } + + // Verify that blocks are persisted in both the executors and the driver + def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { + blockIds.foreach { blockId => + val statuses = bmm.getBlockStatus(blockId, askSlaves = true) + if (blockId.field == "meta") { + // Meta data is only on the driver + assert(statuses.size === 1) + statuses.head match { case (bm, _) => assert(bm.executorId === "") } + } else { + // Other blocks are on both the executors and the driver + assert(statuses.size === numSlaves + 1, + blockId + " has " + statuses.size + " statuses: " + statuses.mkString(",")) + statuses.foreach { case (_, status) => + assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) + assert(status.memSize > 0, "Block should be in memory store") + assert(status.diskSize === 0, "Block should not be in disk store") + } + } + } + } + + // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver + // is true. + def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { + val expectedNumBlocks = if (removeFromDriver) 0 else 1 + val possiblyNot = if (removeFromDriver) "" else " not" + blockIds.foreach { blockId => + val statuses = bmm.getBlockStatus(blockId, askSlaves = true) + assert(statuses.size === expectedNumBlocks, + "Block should%s be unpersisted on the driver".format(possiblyNot)) + } + } + + testUnpersistBroadcast(distributed, numSlaves, torrentConf, getBlockIds, afterCreation, + afterUsingBroadcast, afterUnpersist, removeFromDriver) + } + + /** + * This test runs in 4 steps: + * + * 1) Create broadcast variable, and verify that all state is persisted on the driver. + * 2) Use the broadcast variable on all executors, and verify that all state is persisted + * on both the driver and the executors. + * 3) Unpersist the broadcast, and verify that all state is removed where they should be. + * 4) [Optional] If removeFromDriver is false, we verify that the broadcast is re-usable. + */ + private def testUnpersistBroadcast( + distributed: Boolean, + numSlaves: Int, // used only when distributed = true + broadcastConf: SparkConf, + getBlockIds: Long => Seq[BroadcastBlockId], + afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, + afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, + afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, + removeFromDriver: Boolean) { + + sc = if (distributed) { + new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf) + } else { + new SparkContext("local", "test", broadcastConf) + } + val blockManagerMaster = sc.env.blockManager.master + val list = List[Int](1, 2, 3, 4) + + // Create broadcast variable + val broadcast = sc.broadcast(list) + val blocks = getBlockIds(broadcast.id) + afterCreation(blocks, blockManagerMaster) + + // Use broadcast variable on all executors + val partitions = 10 + assert(partitions > numSlaves) + val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet) + afterUsingBroadcast(blocks, blockManagerMaster) + + // Unpersist broadcast + if (removeFromDriver) { + broadcast.destroy(blocking = true) + } else { + broadcast.unpersist(blocking = true) + } + afterUnpersist(blocks, blockManagerMaster) + + // If the broadcast is removed from driver, all subsequent uses of the broadcast variable + // should throw SparkExceptions. Otherwise, the result should be the same as before. + if (removeFromDriver) { + // Using this variable on the executors crashes them, which hangs the test. + // Instead, crash the driver by directly accessing the broadcast value. + intercept[SparkException] { broadcast.value } + intercept[SparkException] { broadcast.unpersist() } + intercept[SparkException] { broadcast.destroy(blocking = true) } + } else { + val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet) + } } + /** Helper method to create a SparkConf that uses the given broadcast factory. */ + private def broadcastConf(factoryName: String): SparkConf = { + val conf = new SparkConf + conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.%s".format(factoryName)) + conf + } } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala new file mode 100644 index 0000000000000..e50981cf6fb20 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -0,0 +1,415 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.lang.ref.WeakReference + +import scala.collection.mutable.{HashSet, SynchronizedSet} +import scala.util.Random + +import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.concurrent.Eventually +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.{BlockId, BroadcastBlockId, RDDBlockId, ShuffleBlockId} + +class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { + + implicit val defaultTimeout = timeout(10000 millis) + val conf = new SparkConf() + .setMaster("local[2]") + .setAppName("ContextCleanerSuite") + .set("spark.cleaner.referenceTracking.blocking", "true") + + before { + sc = new SparkContext(conf) + } + + after { + if (sc != null) { + sc.stop() + sc = null + } + } + + + test("cleanup RDD") { + val rdd = newRDD.persist() + val collected = rdd.collect().toList + val tester = new CleanerTester(sc, rddIds = Seq(rdd.id)) + + // Explicit cleanup + cleaner.doCleanupRDD(rdd.id, blocking = true) + tester.assertCleanup() + + // Verify that RDDs can be re-executed after cleaning up + assert(rdd.collect().toList === collected) + } + + test("cleanup shuffle") { + val (rdd, shuffleDeps) = newRDDWithShuffleDependencies + val collected = rdd.collect().toList + val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId)) + + // Explicit cleanup + shuffleDeps.foreach(s => cleaner.doCleanupShuffle(s.shuffleId, blocking = true)) + tester.assertCleanup() + + // Verify that shuffles can be re-executed after cleaning up + assert(rdd.collect().toList === collected) + } + + test("cleanup broadcast") { + val broadcast = newBroadcast + val tester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) + + // Explicit cleanup + cleaner.doCleanupBroadcast(broadcast.id, blocking = true) + tester.assertCleanup() + } + + test("automatically cleanup RDD") { + var rdd = newRDD.persist() + rdd.count() + + // Test that GC does not cause RDD cleanup due to a strong reference + val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) + runGC() + intercept[Exception] { + preGCTester.assertCleanup()(timeout(1000 millis)) + } + + // Test that GC causes RDD cleanup after dereferencing the RDD + val postGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) + rdd = null // Make RDD out of scope + runGC() + postGCTester.assertCleanup() + } + + test("automatically cleanup shuffle") { + var rdd = newShuffleRDD + rdd.count() + + // Test that GC does not cause shuffle cleanup due to a strong reference + val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) + runGC() + intercept[Exception] { + preGCTester.assertCleanup()(timeout(1000 millis)) + } + + // Test that GC causes shuffle cleanup after dereferencing the RDD + val postGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) + rdd = null // Make RDD out of scope, so that corresponding shuffle goes out of scope + runGC() + postGCTester.assertCleanup() + } + + test("automatically cleanup broadcast") { + var broadcast = newBroadcast + + // Test that GC does not cause broadcast cleanup due to a strong reference + val preGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) + runGC() + intercept[Exception] { + preGCTester.assertCleanup()(timeout(1000 millis)) + } + + // Test that GC causes broadcast cleanup after dereferencing the broadcast variable + val postGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) + broadcast = null // Make broadcast variable out of scope + runGC() + postGCTester.assertCleanup() + } + + test("automatically cleanup RDD + shuffle + broadcast") { + val numRdds = 100 + val numBroadcasts = 4 // Broadcasts are more costly + val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer + val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer + val rddIds = sc.persistentRdds.keys.toSeq + val shuffleIds = 0 until sc.newShuffleId + val broadcastIds = 0L until numBroadcasts + + val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + runGC() + intercept[Exception] { + preGCTester.assertCleanup()(timeout(1000 millis)) + } + + // Test that GC triggers the cleanup of all variables after the dereferencing them + val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + broadcastBuffer.clear() + rddBuffer.clear() + runGC() + postGCTester.assertCleanup() + } + + test("automatically cleanup RDD + shuffle + broadcast in distributed mode") { + sc.stop() + + val conf2 = new SparkConf() + .setMaster("local-cluster[2, 1, 512]") + .setAppName("ContextCleanerSuite") + .set("spark.cleaner.referenceTracking.blocking", "true") + sc = new SparkContext(conf2) + + val numRdds = 10 + val numBroadcasts = 4 // Broadcasts are more costly + val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer + val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer + val rddIds = sc.persistentRdds.keys.toSeq + val shuffleIds = 0 until sc.newShuffleId + val broadcastIds = 0L until numBroadcasts + + val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + runGC() + intercept[Exception] { + preGCTester.assertCleanup()(timeout(1000 millis)) + } + + // Test that GC triggers the cleanup of all variables after the dereferencing them + val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + broadcastBuffer.clear() + rddBuffer.clear() + runGC() + postGCTester.assertCleanup() + } + + //------ Helper functions ------ + + def newRDD = sc.makeRDD(1 to 10) + def newPairRDD = newRDD.map(_ -> 1) + def newShuffleRDD = newPairRDD.reduceByKey(_ + _) + def newBroadcast = sc.broadcast(1 to 100) + def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _]]) = { + def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = { + rdd.dependencies ++ rdd.dependencies.flatMap { dep => + getAllDependencies(dep.rdd) + } + } + val rdd = newShuffleRDD + + // Get all the shuffle dependencies + val shuffleDeps = getAllDependencies(rdd) + .filter(_.isInstanceOf[ShuffleDependency[_, _]]) + .map(_.asInstanceOf[ShuffleDependency[_, _]]) + (rdd, shuffleDeps) + } + + def randomRdd = { + val rdd: RDD[_] = Random.nextInt(3) match { + case 0 => newRDD + case 1 => newShuffleRDD + case 2 => newPairRDD.join(newPairRDD) + } + if (Random.nextBoolean()) rdd.persist() + rdd.count() + rdd + } + + def randomBroadcast = { + sc.broadcast(Random.nextInt(Int.MaxValue)) + } + + /** Run GC and make sure it actually has run */ + def runGC() { + val weakRef = new WeakReference(new Object()) + val startTime = System.currentTimeMillis + System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. + // Wait until a weak reference object has been GCed + while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { + System.gc() + Thread.sleep(200) + } + } + + def cleaner = sc.cleaner.get +} + + +/** Class to test whether RDDs, shuffles, etc. have been successfully cleaned. */ +class CleanerTester( + sc: SparkContext, + rddIds: Seq[Int] = Seq.empty, + shuffleIds: Seq[Int] = Seq.empty, + broadcastIds: Seq[Long] = Seq.empty) + extends Logging { + + val toBeCleanedRDDIds = new HashSet[Int] with SynchronizedSet[Int] ++= rddIds + val toBeCleanedShuffleIds = new HashSet[Int] with SynchronizedSet[Int] ++= shuffleIds + val toBeCleanedBroadcstIds = new HashSet[Long] with SynchronizedSet[Long] ++= broadcastIds + val isDistributed = !sc.isLocal + + val cleanerListener = new CleanerListener { + def rddCleaned(rddId: Int): Unit = { + toBeCleanedRDDIds -= rddId + logInfo("RDD "+ rddId + " cleaned") + } + + def shuffleCleaned(shuffleId: Int): Unit = { + toBeCleanedShuffleIds -= shuffleId + logInfo("Shuffle " + shuffleId + " cleaned") + } + + def broadcastCleaned(broadcastId: Long): Unit = { + toBeCleanedBroadcstIds -= broadcastId + logInfo("Broadcast" + broadcastId + " cleaned") + } + } + + val MAX_VALIDATION_ATTEMPTS = 10 + val VALIDATION_ATTEMPT_INTERVAL = 100 + + logInfo("Attempting to validate before cleanup:\n" + uncleanedResourcesToString) + preCleanupValidate() + sc.cleaner.get.attachListener(cleanerListener) + + /** Assert that all the stuff has been cleaned up */ + def assertCleanup()(implicit waitTimeout: Eventually.Timeout) { + try { + eventually(waitTimeout, interval(100 millis)) { + assert(isAllCleanedUp) + } + postCleanupValidate() + } finally { + logInfo("Resources left from cleaning up:\n" + uncleanedResourcesToString) + } + } + + /** Verify that RDDs, shuffles, etc. occupy resources */ + private def preCleanupValidate() { + assert(rddIds.nonEmpty || shuffleIds.nonEmpty || broadcastIds.nonEmpty, "Nothing to cleanup") + + // Verify the RDDs have been persisted and blocks are present + rddIds.foreach { rddId => + assert( + sc.persistentRdds.contains(rddId), + "RDD " + rddId + " have not been persisted, cannot start cleaner test" + ) + + assert( + !getRDDBlocks(rddId).isEmpty, + "Blocks of RDD " + rddId + " cannot be found in block manager, " + + "cannot start cleaner test" + ) + } + + // Verify the shuffle ids are registered and blocks are present + shuffleIds.foreach { shuffleId => + assert( + mapOutputTrackerMaster.containsShuffle(shuffleId), + "Shuffle " + shuffleId + " have not been registered, cannot start cleaner test" + ) + + assert( + !getShuffleBlocks(shuffleId).isEmpty, + "Blocks of shuffle " + shuffleId + " cannot be found in block manager, " + + "cannot start cleaner test" + ) + } + + // Verify that the broadcast blocks are present + broadcastIds.foreach { broadcastId => + assert( + !getBroadcastBlocks(broadcastId).isEmpty, + "Blocks of broadcast " + broadcastId + "cannot be found in block manager, " + + "cannot start cleaner test" + ) + } + } + + /** + * Verify that RDDs, shuffles, etc. do not occupy resources. Tests multiple times as there is + * as there is not guarantee on how long it will take clean up the resources. + */ + private def postCleanupValidate() { + // Verify the RDDs have been persisted and blocks are present + rddIds.foreach { rddId => + assert( + !sc.persistentRdds.contains(rddId), + "RDD " + rddId + " was not cleared from sc.persistentRdds" + ) + + assert( + getRDDBlocks(rddId).isEmpty, + "Blocks of RDD " + rddId + " were not cleared from block manager" + ) + } + + // Verify the shuffle ids are registered and blocks are present + shuffleIds.foreach { shuffleId => + assert( + !mapOutputTrackerMaster.containsShuffle(shuffleId), + "Shuffle " + shuffleId + " was not deregistered from map output tracker" + ) + + assert( + getShuffleBlocks(shuffleId).isEmpty, + "Blocks of shuffle " + shuffleId + " were not cleared from block manager" + ) + } + + // Verify that the broadcast blocks are present + broadcastIds.foreach { broadcastId => + assert( + getBroadcastBlocks(broadcastId).isEmpty, + "Blocks of broadcast " + broadcastId + " were not cleared from block manager" + ) + } + } + + private def uncleanedResourcesToString = { + s""" + |\tRDDs = ${toBeCleanedRDDIds.toSeq.sorted.mkString("[", ", ", "]")} + |\tShuffles = ${toBeCleanedShuffleIds.toSeq.sorted.mkString("[", ", ", "]")} + |\tBroadcasts = ${toBeCleanedBroadcstIds.toSeq.sorted.mkString("[", ", ", "]")} + """.stripMargin + } + + private def isAllCleanedUp = + toBeCleanedRDDIds.isEmpty && + toBeCleanedShuffleIds.isEmpty && + toBeCleanedBroadcstIds.isEmpty + + private def getRDDBlocks(rddId: Int): Seq[BlockId] = { + blockManager.master.getMatchingBlockIds( _ match { + case RDDBlockId(`rddId`, _) => true + case _ => false + }, askSlaves = true) + } + + private def getShuffleBlocks(shuffleId: Int): Seq[BlockId] = { + blockManager.master.getMatchingBlockIds( _ match { + case ShuffleBlockId(`shuffleId`, _, _) => true + case _ => false + }, askSlaves = true) + } + + private def getBroadcastBlocks(broadcastId: Long): Seq[BlockId] = { + blockManager.master.getMatchingBlockIds( _ match { + case BroadcastBlockId(`broadcastId`, _) => true + case _ => false + }, askSlaves = true) + } + + private def blockManager = sc.env.blockManager + private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] +} diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index a5bd72eb0a122..6b2571cd9295e 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -57,12 +57,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { tracker.stop() } - test("master register and fetch") { + test("master register shuffle and fetch") { val actorSystem = ActorSystem("test") val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) tracker.registerShuffle(10, 2) + assert(tracker.containsShuffle(10)) val compressedSize1000 = MapOutputTracker.compressSize(1000L) val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) @@ -77,7 +78,25 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { tracker.stop() } - test("master register and unregister and fetch") { + test("master register and unregister shuffle") { + val actorSystem = ActorSystem("test") + val tracker = new MapOutputTrackerMaster(conf) + tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.registerShuffle(10, 2) + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val compressedSize10000 = MapOutputTracker.compressSize(10000L) + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0), + Array(compressedSize1000, compressedSize10000))) + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0), + Array(compressedSize10000, compressedSize1000))) + assert(tracker.containsShuffle(10)) + assert(tracker.getServerStatuses(10, 0).nonEmpty) + tracker.unregisterShuffle(10) + assert(!tracker.containsShuffle(10)) + assert(tracker.getServerStatuses(10, 0).isEmpty) + } + + test("master register shuffle and unregister map output and fetch") { val actorSystem = ActorSystem("test") val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = @@ -114,7 +133,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf, securityManager = new SecurityManager(conf)) - val slaveTracker = new MapOutputTracker(conf) + val slaveTracker = new MapOutputTrackerWorker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") val timeout = AkkaUtils.lookupTimeout(conf) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index b6dd0526105a0..e10ec7d2624a0 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.concurrent.Timeouts._ import org.scalatest.matchers.ShouldMatchers._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} @@ -42,6 +42,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT var oldArch: String = null conf.set("spark.authenticate", "false") val securityMgr = new SecurityManager(conf) + val mapOutputTracker = new MapOutputTrackerMaster(conf) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test conf.set("spark.kryoserializer.buffer.mb", "1") @@ -130,7 +131,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("master + 1 manager interaction") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -160,9 +162,10 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("master + 2 managers interaction") { - store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf, - securityMgr) + securityMgr, mapOutputTracker) val peers = master.getPeers(store.blockManagerId, 1) assert(peers.size === 1, "master did not return the other manager as a peer") @@ -177,7 +180,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("removing block") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -225,7 +229,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("removing rdd") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -257,9 +262,82 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT master.getLocations(rdd(0, 1)) should have size 0 } + test("removing broadcast") { + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) + val driverStore = store + val executorStore = new BlockManager("executor", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + val a4 = new Array[Byte](400) + + val broadcast0BlockId = BroadcastBlockId(0) + val broadcast1BlockId = BroadcastBlockId(1) + val broadcast2BlockId = BroadcastBlockId(2) + val broadcast2BlockId2 = BroadcastBlockId(2, "_") + + // insert broadcast blocks in both the stores + Seq(driverStore, executorStore).foreach { case s => + s.putSingle(broadcast0BlockId, a1, StorageLevel.DISK_ONLY) + s.putSingle(broadcast1BlockId, a2, StorageLevel.DISK_ONLY) + s.putSingle(broadcast2BlockId, a3, StorageLevel.DISK_ONLY) + s.putSingle(broadcast2BlockId2, a4, StorageLevel.DISK_ONLY) + } + + // verify whether the blocks exist in both the stores + Seq(driverStore, executorStore).foreach { case s => + s.getLocal(broadcast0BlockId) should not be (None) + s.getLocal(broadcast1BlockId) should not be (None) + s.getLocal(broadcast2BlockId) should not be (None) + s.getLocal(broadcast2BlockId2) should not be (None) + } + + // remove broadcast 0 block only from executors + master.removeBroadcast(0, removeFromMaster = false, blocking = true) + + // only broadcast 0 block should be removed from the executor store + executorStore.getLocal(broadcast0BlockId) should be (None) + executorStore.getLocal(broadcast1BlockId) should not be (None) + executorStore.getLocal(broadcast2BlockId) should not be (None) + + // nothing should be removed from the driver store + driverStore.getLocal(broadcast0BlockId) should not be (None) + driverStore.getLocal(broadcast1BlockId) should not be (None) + driverStore.getLocal(broadcast2BlockId) should not be (None) + + // remove broadcast 0 block from the driver as well + master.removeBroadcast(0, removeFromMaster = true, blocking = true) + driverStore.getLocal(broadcast0BlockId) should be (None) + driverStore.getLocal(broadcast1BlockId) should not be (None) + + // remove broadcast 1 block from both the stores asynchronously + // and verify all broadcast 1 blocks have been removed + master.removeBroadcast(1, removeFromMaster = true, blocking = false) + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + driverStore.getLocal(broadcast1BlockId) should be (None) + executorStore.getLocal(broadcast1BlockId) should be (None) + } + + // remove broadcast 2 from both the stores asynchronously + // and verify all broadcast 2 blocks have been removed + master.removeBroadcast(2, removeFromMaster = true, blocking = false) + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + driverStore.getLocal(broadcast2BlockId) should be (None) + driverStore.getLocal(broadcast2BlockId2) should be (None) + executorStore.getLocal(broadcast2BlockId) should be (None) + executorStore.getLocal(broadcast2BlockId2) should be (None) + } + executorStore.stop() + driverStore.stop() + store = null + } + test("reregistration on heart beat") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) @@ -275,7 +353,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("reregistration on block update") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) @@ -294,7 +373,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("reregistration doesn't dead lock") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = List(new Array[Byte](400)) @@ -331,7 +411,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -350,7 +431,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU storage with serialization") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -369,7 +451,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU for partitions of same RDD") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -388,7 +471,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU for partitions of multiple RDDs") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 2), new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(1, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) @@ -414,7 +498,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT // TODO Make the spark.test.tachyon.enable true after using tachyon 0.5.0 testing jar. val tachyonUnitTestEnabled = conf.getBoolean("spark.test.tachyon.enable", false) if (tachyonUnitTestEnabled) { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -430,7 +515,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("on-disk storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -443,7 +529,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -458,7 +545,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with getLocalBytes") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -473,7 +561,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with serialization") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -488,7 +577,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with serialization and getLocalBytes") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -503,7 +593,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("LRU with mixed storage levels") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -525,7 +616,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU with streams") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -549,7 +641,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("LRU with mixed storage levels and streams") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -595,7 +688,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("overly large block") { - store = new BlockManager("", actorSystem, master, serializer, 500, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 500, conf, + securityMgr, mapOutputTracker) store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.getSingle("a1") === None, "a1 was in store") store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK) @@ -606,7 +700,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("block compression") { try { conf.set("spark.shuffle.compress", "true") - store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100, "shuffle_0_0_0 was not compressed") @@ -614,7 +709,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store = null conf.set("spark.shuffle.compress", "false") - store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 1000, "shuffle_0_0_0 was compressed") @@ -622,7 +718,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store = null conf.set("spark.broadcast.compress", "true") - store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 100, "broadcast_0 was not compressed") @@ -630,28 +727,32 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store = null conf.set("spark.broadcast.compress", "false") - store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 1000, "broadcast_0 was compressed") store.stop() store = null conf.set("spark.rdd.compress", "true") - store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(rdd(0, 0)) <= 100, "rdd_0_0 was not compressed") store.stop() store = null conf.set("spark.rdd.compress", "false") - store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(rdd(0, 0)) >= 1000, "rdd_0_0 was compressed") store.stop() store = null // Check that any other block types are also kept uncompressed - store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed") store.stop() @@ -666,7 +767,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("block store put failure") { // Use Java serializer so we can create an unserializable error. store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf, - securityMgr) + securityMgr, mapOutputTracker) // The put should fail since a1 is not serializable. class UnserializableClass @@ -682,7 +783,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("updated block statuses") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val list = List.fill(2)(new Array[Byte](200)) val bigList = List.fill(8)(new Array[Byte](200)) @@ -735,8 +837,83 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(!store.get("list5").isDefined, "list5 was in store") } + test("query block statuses") { + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) + val list = List.fill(2)(new Array[Byte](200)) + + // Tell master. By LRU, only list2 and list3 remains. + store.put("list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.put("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.put("list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + + // getLocations and getBlockStatus should yield the same locations + assert(store.master.getLocations("list1").size === 0) + assert(store.master.getLocations("list2").size === 1) + assert(store.master.getLocations("list3").size === 1) + assert(store.master.getBlockStatus("list1", askSlaves = false).size === 0) + assert(store.master.getBlockStatus("list2", askSlaves = false).size === 1) + assert(store.master.getBlockStatus("list3", askSlaves = false).size === 1) + assert(store.master.getBlockStatus("list1", askSlaves = true).size === 0) + assert(store.master.getBlockStatus("list2", askSlaves = true).size === 1) + assert(store.master.getBlockStatus("list3", askSlaves = true).size === 1) + + // This time don't tell master and see what happens. By LRU, only list5 and list6 remains. + store.put("list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) + store.put("list5", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + store.put("list6", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) + + // getLocations should return nothing because the master is not informed + // getBlockStatus without asking slaves should have the same result + // getBlockStatus with asking slaves, however, should return the actual block statuses + assert(store.master.getLocations("list4").size === 0) + assert(store.master.getLocations("list5").size === 0) + assert(store.master.getLocations("list6").size === 0) + assert(store.master.getBlockStatus("list4", askSlaves = false).size === 0) + assert(store.master.getBlockStatus("list5", askSlaves = false).size === 0) + assert(store.master.getBlockStatus("list6", askSlaves = false).size === 0) + assert(store.master.getBlockStatus("list4", askSlaves = true).size === 0) + assert(store.master.getBlockStatus("list5", askSlaves = true).size === 1) + assert(store.master.getBlockStatus("list6", askSlaves = true).size === 1) + } + + test("get matching blocks") { + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) + val list = List.fill(2)(new Array[Byte](10)) + + // insert some blocks + store.put("list1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.put("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.put("list3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + + // getLocations and getBlockStatus should yield the same locations + assert(store.master.getMatchingBlockIds(_.toString.contains("list"), askSlaves = false).size === 3) + assert(store.master.getMatchingBlockIds(_.toString.contains("list1"), askSlaves = false).size === 1) + + // insert some more blocks + store.put("newlist1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.put("newlist2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + store.put("newlist3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + + // getLocations and getBlockStatus should yield the same locations + assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = false).size === 1) + assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = true).size === 3) + + val blockIds = Seq(RDDBlockId(1, 0), RDDBlockId(1, 1), RDDBlockId(2, 0)) + blockIds.foreach { blockId => + store.put(blockId, list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + } + val matchedBlockIds = store.master.getMatchingBlockIds(_ match { + case RDDBlockId(1, _) => true + case _ => false + }, askSlaves = true) + assert(matchedBlockIds.toSet === Set(RDDBlockId(1, 0), RDDBlockId(1, 1))) + } + test("SPARK-1194 regression: fix the same-RDD rule for cache replacement") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) store.putSingle(rdd(0, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(1, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY) // Access rdd_1_0 to ensure it's not least recently used. diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 62f9b3cc7b2c1..808ddfdcf45d8 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -59,8 +59,16 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach { val newFile = diskBlockManager.getFile(blockId) writeToFile(newFile, 10) assertSegmentEquals(blockId, blockId.name, 0, 10) - + assert(diskBlockManager.containsBlock(blockId)) newFile.delete() + assert(!diskBlockManager.containsBlock(blockId)) + } + + test("enumerating blocks") { + val ids = (1 to 100).map(i => TestBlockId("test_" + i)) + val files = ids.map(id => diskBlockManager.getFile(id)) + files.foreach(file => writeToFile(file, 10)) + assert(diskBlockManager.getAllBlocks.toSet === ids.toSet) } test("block appending") { diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 054eb01a64c11..7bab7da8fed68 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -108,8 +108,7 @@ class JsonProtocolSuite extends FunSuite { // BlockId testBlockId(RDDBlockId(1, 2)) testBlockId(ShuffleBlockId(1, 2, 3)) - testBlockId(BroadcastBlockId(1L)) - testBlockId(BroadcastHelperBlockId(BroadcastBlockId(2L), "Spark")) + testBlockId(BroadcastBlockId(1L, "insert_words_of_wisdom_here")) testBlockId(TaskResultBlockId(1L)) testBlockId(StreamBlockId(1, 2L)) } @@ -555,4 +554,4 @@ class JsonProtocolSuite extends FunSuite { {"Event":"SparkListenerUnpersistRDD","RDD ID":12345} """ - } +} diff --git a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala new file mode 100644 index 0000000000000..6a5653ed2fb54 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.lang.ref.WeakReference + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +import org.scalatest.FunSuite + +class TimeStampedHashMapSuite extends FunSuite { + + // Test the testMap function - a Scala HashMap should obviously pass + testMap(new mutable.HashMap[String, String]()) + + // Test TimeStampedHashMap basic functionality + testMap(new TimeStampedHashMap[String, String]()) + testMapThreadSafety(new TimeStampedHashMap[String, String]()) + + // Test TimeStampedWeakValueHashMap basic functionality + testMap(new TimeStampedWeakValueHashMap[String, String]()) + testMapThreadSafety(new TimeStampedWeakValueHashMap[String, String]()) + + test("TimeStampedHashMap - clearing by timestamp") { + // clearing by insertion time + val map = new TimeStampedHashMap[String, String](updateTimeStampOnGet = false) + map("k1") = "v1" + assert(map("k1") === "v1") + Thread.sleep(10) + val threshTime = System.currentTimeMillis + assert(map.getTimestamp("k1").isDefined) + assert(map.getTimestamp("k1").get < threshTime) + map.clearOldValues(threshTime) + assert(map.get("k1") === None) + + // clearing by modification time + val map1 = new TimeStampedHashMap[String, String](updateTimeStampOnGet = true) + map1("k1") = "v1" + map1("k2") = "v2" + assert(map1("k1") === "v1") + Thread.sleep(10) + val threshTime1 = System.currentTimeMillis + Thread.sleep(10) + assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime + assert(map1.getTimestamp("k1").isDefined) + assert(map1.getTimestamp("k1").get < threshTime1) + assert(map1.getTimestamp("k2").isDefined) + assert(map1.getTimestamp("k2").get >= threshTime1) + map1.clearOldValues(threshTime1) //should only clear k1 + assert(map1.get("k1") === None) + assert(map1.get("k2").isDefined) + } + + test("TimeStampedWeakValueHashMap - clearing by timestamp") { + // clearing by insertion time + val map = new TimeStampedWeakValueHashMap[String, String](updateTimeStampOnGet = false) + map("k1") = "v1" + assert(map("k1") === "v1") + Thread.sleep(10) + val threshTime = System.currentTimeMillis + assert(map.getTimestamp("k1").isDefined) + assert(map.getTimestamp("k1").get < threshTime) + map.clearOldValues(threshTime) + assert(map.get("k1") === None) + + // clearing by modification time + val map1 = new TimeStampedWeakValueHashMap[String, String](updateTimeStampOnGet = true) + map1("k1") = "v1" + map1("k2") = "v2" + assert(map1("k1") === "v1") + Thread.sleep(10) + val threshTime1 = System.currentTimeMillis + Thread.sleep(10) + assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime + assert(map1.getTimestamp("k1").isDefined) + assert(map1.getTimestamp("k1").get < threshTime1) + assert(map1.getTimestamp("k2").isDefined) + assert(map1.getTimestamp("k2").get >= threshTime1) + map1.clearOldValues(threshTime1) //should only clear k1 + assert(map1.get("k1") === None) + assert(map1.get("k2").isDefined) + } + + test("TimeStampedWeakValueHashMap - clearing weak references") { + var strongRef = new Object + val weakRef = new WeakReference(strongRef) + val map = new TimeStampedWeakValueHashMap[String, Object] + map("k1") = strongRef + map("k2") = "v2" + map("k3") = "v3" + assert(map("k1") === strongRef) + + // clear strong reference to "k1" + strongRef = null + val startTime = System.currentTimeMillis + System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. + System.runFinalization() // Make a best effort to call finalizer on all cleaned objects. + while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { + System.gc() + System.runFinalization() + Thread.sleep(100) + } + assert(map.getReference("k1").isDefined) + val ref = map.getReference("k1").get + assert(ref.get === null) + assert(map.get("k1") === None) + + // operations should only display non-null entries + assert(map.iterator.forall { case (k, v) => k != "k1" }) + assert(map.filter { case (k, v) => k != "k2" }.size === 1) + assert(map.filter { case (k, v) => k != "k2" }.head._1 === "k3") + assert(map.toMap.size === 2) + assert(map.toMap.forall { case (k, v) => k != "k1" }) + val buffer = new ArrayBuffer[String] + map.foreach { case (k, v) => buffer += v.toString } + assert(buffer.size === 2) + assert(buffer.forall(_ != "k1")) + val plusMap = map + (("k4", "v4")) + assert(plusMap.size === 3) + assert(plusMap.forall { case (k, v) => k != "k1" }) + val minusMap = map - "k2" + assert(minusMap.size === 1) + assert(minusMap.head._1 == "k3") + + // clear null values - should only clear k1 + map.clearNullValues() + assert(map.getReference("k1") === None) + assert(map.get("k1") === None) + assert(map.get("k2").isDefined) + assert(map.get("k2").get === "v2") + } + + /** Test basic operations of a Scala mutable Map. */ + def testMap(hashMapConstructor: => mutable.Map[String, String]) { + def newMap() = hashMapConstructor + val testMap1 = newMap() + val testMap2 = newMap() + val name = testMap1.getClass.getSimpleName + + test(name + " - basic test") { + // put, get, and apply + testMap1 += (("k1", "v1")) + assert(testMap1.get("k1").isDefined) + assert(testMap1.get("k1").get === "v1") + testMap1("k2") = "v2" + assert(testMap1.get("k2").isDefined) + assert(testMap1.get("k2").get === "v2") + assert(testMap1("k2") === "v2") + testMap1.update("k3", "v3") + assert(testMap1.get("k3").isDefined) + assert(testMap1.get("k3").get === "v3") + + // remove + testMap1.remove("k1") + assert(testMap1.get("k1").isEmpty) + testMap1.remove("k2") + intercept[NoSuchElementException] { + testMap1("k2") // Map.apply() causes exception + } + testMap1 -= "k3" + assert(testMap1.get("k3").isEmpty) + + // multi put + val keys = (1 to 100).map(_.toString) + val pairs = keys.map(x => (x, x * 2)) + assert((testMap2 ++ pairs).iterator.toSet === pairs.toSet) + testMap2 ++= pairs + + // iterator + assert(testMap2.iterator.toSet === pairs.toSet) + + // filter + val filtered = testMap2.filter { case (_, v) => v.toInt % 2 == 0 } + val evenPairs = pairs.filter { case (_, v) => v.toInt % 2 == 0 } + assert(filtered.iterator.toSet === evenPairs.toSet) + + // foreach + val buffer = new ArrayBuffer[(String, String)] + testMap2.foreach(x => buffer += x) + assert(testMap2.toSet === buffer.toSet) + + // multi remove + testMap2("k1") = "v1" + testMap2 --= keys + assert(testMap2.size === 1) + assert(testMap2.iterator.toSeq.head === ("k1", "v1")) + + // + + val testMap3 = testMap2 + (("k0", "v0")) + assert(testMap3.size === 2) + assert(testMap3.get("k1").isDefined) + assert(testMap3.get("k1").get === "v1") + assert(testMap3.get("k0").isDefined) + assert(testMap3.get("k0").get === "v0") + + // - + val testMap4 = testMap3 - "k0" + assert(testMap4.size === 1) + assert(testMap4.get("k1").isDefined) + assert(testMap4.get("k1").get === "v1") + } + } + + /** Test thread safety of a Scala mutable map. */ + def testMapThreadSafety(hashMapConstructor: => mutable.Map[String, String]) { + def newMap() = hashMapConstructor + val name = newMap().getClass.getSimpleName + val testMap = newMap() + @volatile var error = false + + def getRandomKey(m: mutable.Map[String, String]): Option[String] = { + val keys = testMap.keysIterator.toSeq + if (keys.nonEmpty) { + Some(keys(Random.nextInt(keys.size))) + } else { + None + } + } + + val threads = (1 to 25).map(i => new Thread() { + override def run() { + try { + for (j <- 1 to 1000) { + Random.nextInt(3) match { + case 0 => + testMap(Random.nextString(10)) = Random.nextDouble().toString // put + case 1 => + getRandomKey(testMap).map(testMap.get) // get + case 2 => + getRandomKey(testMap).map(testMap.remove) // remove + } + } + } catch { + case t: Throwable => + error = true + throw t + } + } + }) + + test(name + " - threading safety test") { + threads.map(_.start) + threads.map(_.join) + assert(!error) + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index d48b51aa69565..d043200f71a0b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -341,9 +341,11 @@ abstract class DStream[T: ClassTag] ( */ private[streaming] def clearMetadata(time: Time) { val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration)) + logDebug("Clearing references to old RDDs: [" + + oldRDDs.map(x => s"${x._1} -> ${x._2.id}").mkString(", ") + "]") generatedRDDs --= oldRDDs.keys if (ssc.conf.getBoolean("spark.streaming.unpersist", false)) { - logDebug("Unpersisting old RDDs: " + oldRDDs.keys.mkString(", ")) + logDebug("Unpersisting old RDDs: " + oldRDDs.values.map(_.id).mkString(", ")) oldRDDs.values.foreach(_.unpersist(false)) } logDebug("Cleared " + oldRDDs.size + " RDDs that were older than " +