From 1e752f1a5c5f3887df2ca20d63a9d30f1d32f9d1 Mon Sep 17 00:00:00 2001 From: Roman Pastukhov Date: Wed, 5 Feb 2014 20:11:56 +0400 Subject: [PATCH 01/14] Added unpersist method to Broadcast. --- .../scala/org/apache/spark/SparkContext.scala | 7 ++- .../apache/spark/broadcast/Broadcast.scala | 13 ++++- .../spark/broadcast/BroadcastFactory.scala | 2 +- .../spark/broadcast/HttpBroadcast.scala | 45 ++++++++++++----- .../spark/broadcast/TorrentBroadcast.scala | 43 +++++++++++++---- .../apache/spark/storage/BlockManager.scala | 12 +++++ .../apache/spark/storage/MemoryStore.scala | 31 +++++++----- .../org/apache/spark/BroadcastSuite.scala | 48 +++++++++++++++++++ 8 files changed, 163 insertions(+), 38 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 566472e597958..f42589c3900d0 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -613,8 +613,13 @@ class SparkContext( * Broadcast a read-only variable to the cluster, returning a * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. * The variable will be sent to each cluster only once. + * + * If `registerBlocks` is true, workers will notify driver about blocks they create + * and these blocks will be dropped when `unpersist` method of the broadcast variable is called. */ - def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) + def broadcast[T](value: T, registerBlocks: Boolean = false) = { + env.broadcastManager.newBroadcast[T](value, isLocal, registerBlocks) + } /** * Add a file to be downloaded with this Spark job on every node. diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index d113d4040594d..076d98f8de991 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -53,6 +53,15 @@ import org.apache.spark._ abstract class Broadcast[T](val id: Long) extends Serializable { def value: T + /** + * Removes all blocks of this broadcast from memory (and disk if removeSource is true). + * + * @param removeSource Whether to remove data from disk as well. + * Will cause errors if broadcast is accessed on workers afterwards + * (e.g. in case of RDD re-computation due to executor failure). + */ + def unpersist(removeSource: Boolean = false) + // We cannot have an abstract readObject here due to some weird issues with // readObject having to be 'private' in sub-classes. 
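Taken together, the API added by this patch is used roughly as follows. This is a minimal usage sketch only: `sc` is assumed to be an existing SparkContext, and the broadcast value and RDD are illustrative.

    // Register worker-side blocks so they can be dropped later.
    val numbers = sc.broadcast(Seq(1, 2, 3), registerBlocks = true)
    val sums = sc.parallelize(1 to 4).map(_ + numbers.value.sum).collect()
    // Drop the cached copies on the driver and workers once the broadcast is no
    // longer needed. Passing removeSource = true would also delete the
    // driver-side source, after which tasks that still read the broadcast fail.
    numbers.unpersist()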
@@ -91,8 +100,8 @@ class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging private val nextBroadcastId = new AtomicLong(0) - def newBroadcast[T](value_ : T, isLocal: Boolean) = - broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) + def newBroadcast[T](value_ : T, isLocal: Boolean, registerBlocks: Boolean) = + broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement(), registerBlocks) def isDriver = _isDriver } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 940e5ab805100..e38283f244ea1 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -27,6 +27,6 @@ import org.apache.spark.SparkConf */ trait BroadcastFactory { def initialize(isDriver: Boolean, conf: SparkConf): Unit - def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] + def newBroadcast[T](value: T, isLocal: Boolean, id: Long, registerBlocks: Boolean): Broadcast[T] def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 39ee0dbb92841..53fcc2748b4e0 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -29,11 +29,20 @@ import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils} -private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) +private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) extends Broadcast[T](id) with Logging with Serializable { def value = value_ + def unpersist(removeSource: Boolean) { + SparkEnv.get.blockManager.master.removeBlock(blockId) + SparkEnv.get.blockManager.removeBlock(blockId) + + if (removeSource) { + HttpBroadcast.cleanupById(id) + } + } + def blockId = BroadcastBlockId(id) HttpBroadcast.synchronized { @@ -54,7 +63,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea logInfo("Started reading broadcast variable " + id) val start = System.nanoTime value_ = HttpBroadcast.read[T](id) - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks) val time = (System.nanoTime - start) / 1e9 logInfo("Reading broadcast variable " + id + " took " + time + " s") } @@ -69,8 +78,8 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea class HttpBroadcastFactory extends BroadcastFactory { def initialize(isDriver: Boolean, conf: SparkConf) { HttpBroadcast.initialize(isDriver, conf) } - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new HttpBroadcast[T](value_, isLocal, id) + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) = + new HttpBroadcast[T](value_, isLocal, id, registerBlocks) def stop() { HttpBroadcast.stop() } } @@ -132,8 +141,10 @@ private object HttpBroadcast extends Logging { logInfo("Broadcast server started at " + serverUri) } + def getFile(id: Long) = new File(broadcastDir, 
BroadcastBlockId(id).name) + def write(id: Long, value: Any) { - val file = new File(broadcastDir, BroadcastBlockId(id).name) + val file = getFile(id) val out: OutputStream = { if (compress) { compressionCodec.compressedOutputStream(new FileOutputStream(file)) @@ -167,20 +178,30 @@ private object HttpBroadcast extends Logging { obj } + def deleteFile(fileName: String) { + try { + new File(fileName).delete() + logInfo("Deleted broadcast file '" + fileName + "'") + } catch { + case e: Exception => logWarning("Could not delete broadcast file '" + fileName + "'", e) + } + } + def cleanup(cleanupTime: Long) { val iterator = files.internalMap.entrySet().iterator() while(iterator.hasNext) { val entry = iterator.next() val (file, time) = (entry.getKey, entry.getValue) if (time < cleanupTime) { - try { - iterator.remove() - new File(file.toString).delete() - logInfo("Deleted broadcast file '" + file + "'") - } catch { - case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e) - } + iterator.remove() + deleteFile(file) } } } + + def cleanupById(id: Long) { + val file = getFile(id).getAbsolutePath + files.internalMap.remove(file) + deleteFile(file) + } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index d351dfc1f56a2..11e74675491c6 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -23,16 +23,36 @@ import scala.math import scala.util.Random import org.apache.spark._ -import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel} +import org.apache.spark.storage.{BlockId, BroadcastBlockId, BroadcastHelperBlockId, StorageLevel} import org.apache.spark.util.Utils -private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) +private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) extends Broadcast[T](id) with Logging with Serializable { def value = value_ + def unpersist(removeSource: Boolean) { + SparkEnv.get.blockManager.master.removeBlock(broadcastId) + SparkEnv.get.blockManager.removeBlock(broadcastId) + + if (removeSource) { + for (pid <- pieceIds) { + SparkEnv.get.blockManager.removeBlock(pieceBlockId(pid)) + } + SparkEnv.get.blockManager.removeBlock(metaId) + } else { + for (pid <- pieceIds) { + SparkEnv.get.blockManager.dropFromMemory(pieceBlockId(pid)) + } + SparkEnv.get.blockManager.dropFromMemory(metaId) + } + } + def broadcastId = BroadcastBlockId(id) + private def metaId = BroadcastHelperBlockId(broadcastId, "meta") + private def pieceBlockId(pid: Int) = BroadcastHelperBlockId(broadcastId, "piece" + pid) + private def pieceIds = Array.iterate(0, totalBlocks)(_ + 1).toList TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) @@ -55,7 +75,6 @@ extends Broadcast[T](id) with Logging with Serializable { hasBlocks = tInfo.totalBlocks // Store meta-info - val metaId = BroadcastHelperBlockId(broadcastId, "meta") val metaInfo = TorrentInfo(null, totalBlocks, totalBytes) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( @@ -64,7 +83,7 @@ extends Broadcast[T](id) with Logging with Serializable { // Store individual pieces for (i <- 0 until totalBlocks) { - val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i) + val pieceId = 
pieceBlockId(i) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true) @@ -94,7 +113,7 @@ extends Broadcast[T](id) with Logging with Serializable { // This creates a tradeoff between memory usage and latency. // Storing copy doubles the memory footprint; not storing doubles deserialization cost. SparkEnv.get.blockManager.putSingle( - broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks) // Remove arrayOfBlocks from memory once value_ is on local cache resetWorkerVariables() @@ -109,6 +128,11 @@ extends Broadcast[T](id) with Logging with Serializable { } private def resetWorkerVariables() { + if (arrayOfBlocks != null) { + for (pid <- pieceIds) { + SparkEnv.get.blockManager.removeBlock(pieceBlockId(pid)) + } + } arrayOfBlocks = null totalBytes = -1 totalBlocks = -1 @@ -117,7 +141,6 @@ extends Broadcast[T](id) with Logging with Serializable { def receiveBroadcast(variableID: Long): Boolean = { // Receive meta-info - val metaId = BroadcastHelperBlockId(broadcastId, "meta") var attemptId = 10 while (attemptId > 0 && totalBlocks == -1) { TorrentBroadcast.synchronized { @@ -140,9 +163,9 @@ extends Broadcast[T](id) with Logging with Serializable { } // Receive actual blocks - val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList) + val recvOrder = new Random().shuffle(pieceIds) for (pid <- recvOrder) { - val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid) + val pieceId = pieceBlockId(pid) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.getSingle(pieceId) match { case Some(x) => @@ -243,8 +266,8 @@ class TorrentBroadcastFactory extends BroadcastFactory { def initialize(isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.initialize(isDriver, conf) } - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new TorrentBroadcast[T](value_, isLocal, id) + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) = + new TorrentBroadcast[T](value_, isLocal, id, registerBlocks) def stop() { TorrentBroadcast.stop() } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index ed53558566edf..f8c121615567f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -196,6 +196,11 @@ private[spark] class BlockManager( } } + /** + * For testing. Returns number of blocks BlockManager knows about that are in memory. + */ + def numberOfBlocksInMemory() = blockInfo.keys.count(memoryStore.contains(_)) + /** * Get storage level of local block. If no info exists for the block, then returns null. */ @@ -720,6 +725,13 @@ private[spark] class BlockManager( } /** + * Drop a block from memory, possibly putting it on disk if applicable. + */ + def dropFromMemory(blockId: BlockId) { + memoryStore.asInstanceOf[MemoryStore].dropFromMemory(blockId) + } + + /** * Remove all blocks belonging to the given RDD. * @return The number of blocks removed. 
*/ diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index eb5a18521683e..4e47a06c1fed2 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -182,6 +182,24 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } + /** + * Drop a block from memory, possibly putting it on disk if applicable. + */ + def dropFromMemory(blockId: BlockId) { + val entry = entries.synchronized { entries.get(blockId) } + // This should never be null as only one thread should be dropping + // blocks and removing entries. However the check is still here for + // future safety. + if (entry != null) { + val data = if (entry.deserialized) { + Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) + } else { + Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) + } + blockManager.dropFromMemory(blockId, data) + } + } + /** * Tries to free up a given amount of space to store a particular block, but can fail and return * false if either the block is bigger than our memory or it would require replacing another @@ -227,18 +245,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) if (maxMemory - (currentMemory - selectedMemory) >= space) { logInfo(selectedBlocks.size + " blocks selected for dropping") for (blockId <- selectedBlocks) { - val entry = entries.synchronized { entries.get(blockId) } - // This should never be null as only one thread should be dropping - // blocks and removing entries. However the check is still here for - // future safety. - if (entry != null) { - val data = if (entry.deserialized) { - Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) - } else { - Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) - } - blockManager.dropFromMemory(blockId, data) - } + dropFromMemory(blockId) } return true } else { diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index e022accee6d08..a657753144b24 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -18,6 +18,11 @@ package org.apache.spark import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.time.{Millis, Span} +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ +import org.scalatest.matchers.ShouldMatchers._ class BroadcastSuite extends FunSuite with LocalSparkContext { @@ -82,4 +87,47 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) } + def blocksExist(sc: SparkContext, numSlaves: Int) = { + val rdd = sc.parallelize(1 to numSlaves, numSlaves) + val workerBlocks = rdd.mapPartitions(_ => { + val blocks = SparkEnv.get.blockManager.numberOfBlocksInMemory() + Seq(blocks).iterator + }) + val totalKnown = workerBlocks.reduce(_ + _) + sc.env.blockManager.numberOfBlocksInMemory() + + totalKnown > 0 + } + + def testUnpersist(bcFactory: String, removeSource: Boolean) { + test("Broadcast unpersist(" + removeSource + ") with " + bcFactory) { + val numSlaves = 2 + System.setProperty("spark.broadcast.factory", bcFactory) + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") + val list = List(1, 2, 3, 4) + + assert(!blocksExist(sc, numSlaves)) + + val listBroadcast = sc.broadcast(list, true) + val 
results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) + assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + + assert(blocksExist(sc, numSlaves)) + + listBroadcast.unpersist(removeSource) + + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + blocksExist(sc, numSlaves) should be (false) + } + + if (!removeSource) { + val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) + assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + } + } + } + + for (removeSource <- Seq(true, false)) { + testUnpersist("org.apache.spark.broadcast.HttpBroadcastFactory", removeSource) + testUnpersist("org.apache.spark.broadcast.TorrentBroadcastFactory", removeSource) + } } From 80dd9778d2e7338bc93bc7de95ecc6776b0d9e8b Mon Sep 17 00:00:00 2001 From: Roman Pastukhov Date: Fri, 7 Feb 2014 02:53:29 +0400 Subject: [PATCH 02/14] Fix for Broadcast unpersist patch. Updated comment in MemoryStore.dropFromMemory Keep TorrentBroadcast piece blocks until unpersist is called --- .../spark/broadcast/HttpBroadcast.scala | 10 +++- .../spark/broadcast/TorrentBroadcast.scala | 57 ++++++++++++++----- .../apache/spark/storage/MemoryStore.scala | 6 +- 3 files changed, 54 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 53fcc2748b4e0..7f056b8feae27 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -35,11 +35,15 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea def value = value_ def unpersist(removeSource: Boolean) { - SparkEnv.get.blockManager.master.removeBlock(blockId) - SparkEnv.get.blockManager.removeBlock(blockId) + HttpBroadcast.synchronized { + SparkEnv.get.blockManager.master.removeBlock(blockId) + SparkEnv.get.blockManager.removeBlock(blockId) + } if (removeSource) { - HttpBroadcast.cleanupById(id) + HttpBroadcast.synchronized { + HttpBroadcast.cleanupById(id) + } } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 11e74675491c6..e6a8ae199e723 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -33,19 +33,55 @@ extends Broadcast[T](id) with Logging with Serializable { def value = value_ def unpersist(removeSource: Boolean) { - SparkEnv.get.blockManager.master.removeBlock(broadcastId) - SparkEnv.get.blockManager.removeBlock(broadcastId) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.master.removeBlock(broadcastId) + SparkEnv.get.blockManager.removeBlock(broadcastId) + } + + if (!removeSource) { + //We can't tell BlockManager master to remove blocks from all nodes except driver, + //so we need to save them here in order to store them on disk later. + //This may be inefficient if blocks were already dropped to disk, + //but since unpersist is supposed to be called right after working with + //a broadcast this should not happen (and getting them from memory is cheap). 
+ arrayOfBlocks = new Array[TorrentBlock](totalBlocks) + + for (pid <- 0 until totalBlocks) { + val pieceId = pieceBlockId(pid) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.getSingle(pieceId) match { + case Some(x) => + arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] + case None => + throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) + } + } + } + } + + for (pid <- 0 until totalBlocks) { + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.master.removeBlock(pieceBlockId(pid)) + } + } if (removeSource) { - for (pid <- pieceIds) { - SparkEnv.get.blockManager.removeBlock(pieceBlockId(pid)) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.removeBlock(metaId) } - SparkEnv.get.blockManager.removeBlock(metaId) } else { - for (pid <- pieceIds) { - SparkEnv.get.blockManager.dropFromMemory(pieceBlockId(pid)) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.dropFromMemory(metaId) } - SparkEnv.get.blockManager.dropFromMemory(metaId) + + for (i <- 0 until totalBlocks) { + val pieceId = pieceBlockId(i) + TorrentBroadcast.synchronized { + SparkEnv.get.blockManager.putSingle( + pieceId, arrayOfBlocks(i), StorageLevel.DISK_ONLY, true) + } + } + arrayOfBlocks = null } } @@ -128,11 +164,6 @@ extends Broadcast[T](id) with Logging with Serializable { } private def resetWorkerVariables() { - if (arrayOfBlocks != null) { - for (pid <- pieceIds) { - SparkEnv.get.blockManager.removeBlock(pieceBlockId(pid)) - } - } arrayOfBlocks = null totalBytes = -1 totalBlocks = -1 diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 4e47a06c1fed2..5dff0e95b31ba 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -187,9 +187,9 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) */ def dropFromMemory(blockId: BlockId) { val entry = entries.synchronized { entries.get(blockId) } - // This should never be null as only one thread should be dropping - // blocks and removing entries. However the check is still here for - // future safety. + // This should never be null if called from ensureFreeSpace as only one + // thread should be dropping blocks and removing entries. + // However the check is required in other cases. 
if (entry != null) { val data = if (entry.deserialized) { Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) From ba52e00303896e46ce9cb5122e78e12d7cae7864 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 26 Mar 2014 14:43:52 -0700 Subject: [PATCH 03/14] Refactor broadcast classes --- .../scala/org/apache/spark/SparkContext.scala | 7 +- .../apache/spark/broadcast/Broadcast.scala | 51 ----------- .../spark/broadcast/BroadcastFactory.scala | 2 +- .../spark/broadcast/BroadcastManager.scala | 63 ++++++++++++++ .../spark/broadcast/HttpBroadcast.scala | 59 +++---------- .../broadcast/HttpBroadcastFactory.scala | 34 ++++++++ .../spark/broadcast/TorrentBroadcast.scala | 86 ++----------------- .../broadcast/TorrentBroadcastFactory.scala | 36 ++++++++ .../apache/spark/storage/BlockManager.scala | 12 --- .../apache/spark/storage/MemoryStore.scala | 38 ++++---- .../org/apache/spark/BroadcastSuite.scala | 49 ----------- 11 files changed, 169 insertions(+), 268 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala create mode 100644 core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala create mode 100644 core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3e4b40a7f7b4d..5cd2caed10297 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -641,13 +641,8 @@ class SparkContext( * Broadcast a read-only variable to the cluster, returning a * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. * The variable will be sent to each cluster only once. - * - * If `registerBlocks` is true, workers will notify driver about blocks they create - * and these blocks will be dropped when `unpersist` method of the broadcast variable is called. */ - def broadcast[T](value: T, registerBlocks: Boolean = false) = { - env.broadcastManager.newBroadcast[T](value, isLocal, registerBlocks) - } + def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) /** * Add a file to be downloaded with this Spark job on every node. diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index 516e6ba4005c8..e3e1e4f29b107 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -18,9 +18,6 @@ package org.apache.spark.broadcast import java.io.Serializable -import java.util.concurrent.atomic.AtomicLong - -import org.apache.spark._ /** * A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable @@ -53,56 +50,8 @@ import org.apache.spark._ abstract class Broadcast[T](val id: Long) extends Serializable { def value: T - /** - * Removes all blocks of this broadcast from memory (and disk if removeSource is true). - * - * @param removeSource Whether to remove data from disk as well. - * Will cause errors if broadcast is accessed on workers afterwards - * (e.g. in case of RDD re-computation due to executor failure). - */ - def unpersist(removeSource: Boolean = false) - // We cannot have an abstract readObject here due to some weird issues with // readObject having to be 'private' in sub-classes. 
override def toString = "Broadcast(" + id + ")" } - -private[spark] -class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager) - extends Logging with Serializable { - - private var initialized = false - private var broadcastFactory: BroadcastFactory = null - - initialize() - - // Called by SparkContext or Executor before using Broadcast - private def initialize() { - synchronized { - if (!initialized) { - val broadcastFactoryClass = conf.get( - "spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") - - broadcastFactory = - Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] - - // Initialize appropriate BroadcastFactory and BroadcastObject - broadcastFactory.initialize(isDriver, conf, securityManager) - - initialized = true - } - } - } - - def stop() { - broadcastFactory.stop() - } - - private val nextBroadcastId = new AtomicLong(0) - - def newBroadcast[T](value_ : T, isLocal: Boolean, registerBlocks: Boolean) = - broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement(), registerBlocks) - - def isDriver = _isDriver -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 7aff8d7bb670b..0a0bb6cca336c 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -28,6 +28,6 @@ import org.apache.spark.SparkConf */ trait BroadcastFactory { def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit - def newBroadcast[T](value: T, isLocal: Boolean, id: Long, registerBlocks: Boolean): Broadcast[T] + def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala new file mode 100644 index 0000000000000..746e23e81931a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.broadcast + +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark._ + +private[spark] class BroadcastManager( + val isDriver: Boolean, + conf: SparkConf, + securityManager: SecurityManager) + extends Logging with Serializable { + + private var initialized = false + private var broadcastFactory: BroadcastFactory = null + + initialize() + + // Called by SparkContext or Executor before using Broadcast + private def initialize() { + synchronized { + if (!initialized) { + val broadcastFactoryClass = + conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") + + broadcastFactory = + Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] + + // Initialize appropriate BroadcastFactory and BroadcastObject + broadcastFactory.initialize(isDriver, conf, securityManager) + + initialized = true + } + } + } + + def stop() { + broadcastFactory.stop() + } + + private val nextBroadcastId = new AtomicLong(0) + + def newBroadcast[T](value_ : T, isLocal: Boolean) = { + broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) + } + +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 6c2413cea526a..374180e472805 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -29,24 +29,11 @@ import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils} -private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) +private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { def value = value_ - def unpersist(removeSource: Boolean) { - HttpBroadcast.synchronized { - SparkEnv.get.blockManager.master.removeBlock(blockId) - SparkEnv.get.blockManager.removeBlock(blockId) - } - - if (removeSource) { - HttpBroadcast.synchronized { - HttpBroadcast.cleanupById(id) - } - } - } - def blockId = BroadcastBlockId(id) HttpBroadcast.synchronized { @@ -67,7 +54,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea logInfo("Started reading broadcast variable " + id) val start = System.nanoTime value_ = HttpBroadcast.read[T](id) - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks) + SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) val time = (System.nanoTime - start) / 1e9 logInfo("Reading broadcast variable " + id + " took " + time + " s") } @@ -76,20 +63,6 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea } } -/** - * A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium. 
- */ -class HttpBroadcastFactory extends BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { - HttpBroadcast.initialize(isDriver, conf, securityMgr) - } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) = - new HttpBroadcast[T](value_, isLocal, id, registerBlocks) - - def stop() { HttpBroadcast.stop() } -} - private object HttpBroadcast extends Logging { private var initialized = false @@ -149,10 +122,8 @@ private object HttpBroadcast extends Logging { logInfo("Broadcast server started at " + serverUri) } - def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name) - def write(id: Long, value: Any) { - val file = getFile(id) + val file = new File(broadcastDir, BroadcastBlockId(id).name) val out: OutputStream = { if (compress) { compressionCodec.compressedOutputStream(new FileOutputStream(file)) @@ -198,30 +169,20 @@ private object HttpBroadcast extends Logging { obj } - def deleteFile(fileName: String) { - try { - new File(fileName).delete() - logInfo("Deleted broadcast file '" + fileName + "'") - } catch { - case e: Exception => logWarning("Could not delete broadcast file '" + fileName + "'", e) - } - } - def cleanup(cleanupTime: Long) { val iterator = files.internalMap.entrySet().iterator() while(iterator.hasNext) { val entry = iterator.next() val (file, time) = (entry.getKey, entry.getValue) if (time < cleanupTime) { - iterator.remove() - deleteFile(file) + try { + iterator.remove() + new File(file.toString).delete() + logInfo("Deleted broadcast file '" + file + "'") + } catch { + case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e) + } } } } - - def cleanupById(id: Long) { - val file = getFile(id).getAbsolutePath - files.internalMap.remove(file) - deleteFile(file) - } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala new file mode 100644 index 0000000000000..c4f0f149534a5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.broadcast + +import org.apache.spark.{SecurityManager, SparkConf} + +/** + * A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium. 
+ */ +class HttpBroadcastFactory extends BroadcastFactory { + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { + HttpBroadcast.initialize(isDriver, conf, securityMgr) + } + + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + new HttpBroadcast[T](value_, isLocal, id) + + def stop() { HttpBroadcast.stop() } +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 206765679e9ed..0828035c5d217 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -26,68 +26,12 @@ import org.apache.spark._ import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel} import org.apache.spark.util.Utils -private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) +private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { def value = value_ - def unpersist(removeSource: Boolean) { - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.master.removeBlock(broadcastId) - SparkEnv.get.blockManager.removeBlock(broadcastId) - } - - if (!removeSource) { - //We can't tell BlockManager master to remove blocks from all nodes except driver, - //so we need to save them here in order to store them on disk later. - //This may be inefficient if blocks were already dropped to disk, - //but since unpersist is supposed to be called right after working with - //a broadcast this should not happen (and getting them from memory is cheap). - arrayOfBlocks = new Array[TorrentBlock](totalBlocks) - - for (pid <- 0 until totalBlocks) { - val pieceId = pieceBlockId(pid) - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.getSingle(pieceId) match { - case Some(x) => - arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] - case None => - throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) - } - } - } - } - - for (pid <- 0 until totalBlocks) { - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.master.removeBlock(pieceBlockId(pid)) - } - } - - if (removeSource) { - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.removeBlock(metaId) - } - } else { - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.dropFromMemory(metaId) - } - - for (i <- 0 until totalBlocks) { - val pieceId = pieceBlockId(i) - TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle( - pieceId, arrayOfBlocks(i), StorageLevel.DISK_ONLY, true) - } - } - arrayOfBlocks = null - } - } - def broadcastId = BroadcastBlockId(id) - private def metaId = BroadcastHelperBlockId(broadcastId, "meta") - private def pieceBlockId(pid: Int) = BroadcastHelperBlockId(broadcastId, "piece" + pid) - private def pieceIds = Array.iterate(0, totalBlocks)(_ + 1).toList TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) @@ -110,6 +54,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo hasBlocks = tInfo.totalBlocks // Store meta-info + val metaId = BroadcastHelperBlockId(broadcastId, "meta") val metaInfo = TorrentInfo(null, totalBlocks, totalBytes) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( @@ -118,7 +63,7 @@ private[spark] class 
TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo // Store individual pieces for (i <- 0 until totalBlocks) { - val pieceId = pieceBlockId(i) + val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true) @@ -148,7 +93,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo // This creates a tradeoff between memory usage and latency. // Storing copy doubles the memory footprint; not storing doubles deserialization cost. SparkEnv.get.blockManager.putSingle( - broadcastId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks) + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) // Remove arrayOfBlocks from memory once value_ is on local cache resetWorkerVariables() @@ -171,6 +116,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo def receiveBroadcast(variableID: Long): Boolean = { // Receive meta-info + val metaId = BroadcastHelperBlockId(broadcastId, "meta") var attemptId = 10 while (attemptId > 0 && totalBlocks == -1) { TorrentBroadcast.synchronized { @@ -193,9 +139,9 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo } // Receive actual blocks - val recvOrder = new Random().shuffle(pieceIds) + val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList) for (pid <- recvOrder) { - val pieceId = pieceBlockId(pid) + val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.getSingle(pieceId) match { case Some(x) => @@ -215,8 +161,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo } -private object TorrentBroadcast -extends Logging { +private object TorrentBroadcast extends Logging { private var initialized = false private var conf: SparkConf = null @@ -289,18 +234,3 @@ private[spark] case class TorrentInfo( @transient var hasBlocks = 0 } - -/** - * A [[BroadcastFactory]] that creates a torrent-based implementation of broadcast. - */ -class TorrentBroadcastFactory extends BroadcastFactory { - - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { - TorrentBroadcast.initialize(isDriver, conf) - } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) = - new TorrentBroadcast[T](value_, isLocal, id, registerBlocks) - - def stop() { TorrentBroadcast.stop() } -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala new file mode 100644 index 0000000000000..a51c438c57717 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.broadcast + +import org.apache.spark.{SecurityManager, SparkConf} + +/** + * A [[BroadcastFactory]] that creates a torrent-based implementation of broadcast. + */ +class TorrentBroadcastFactory extends BroadcastFactory { + + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { + TorrentBroadcast.initialize(isDriver, conf) + } + + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + new TorrentBroadcast[T](value_, isLocal, id) + + def stop() { TorrentBroadcast.stop() } + +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 84c87949adae4..ca23513c4dc64 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -209,11 +209,6 @@ private[spark] class BlockManager( } } - /** - * For testing. Returns number of blocks BlockManager knows about that are in memory. - */ - def numberOfBlocksInMemory() = blockInfo.keys.count(memoryStore.contains(_)) - /** * Get storage level of local block. If no info exists for the block, then returns null. */ @@ -817,13 +812,6 @@ private[spark] class BlockManager( } /** - * Drop a block from memory, possibly putting it on disk if applicable. - */ - def dropFromMemory(blockId: BlockId) { - memoryStore.asInstanceOf[MemoryStore].dropFromMemory(blockId) - } - - /** * Remove all blocks belonging to the given RDD. * @return The number of blocks removed. */ diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 7d614aa4726b2..488f1ea9628f5 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -210,27 +210,9 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } /** - * Drop a block from memory, possibly putting it on disk if applicable. - */ - def dropFromMemory(blockId: BlockId) { - val entry = entries.synchronized { entries.get(blockId) } - // This should never be null if called from ensureFreeSpace as only one - // thread should be dropping blocks and removing entries. - // However the check is required in other cases. 
- if (entry != null) { - val data = if (entry.deserialized) { - Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) - } else { - Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) - } - blockManager.dropFromMemory(blockId, data) - } - } - - /** - * Tries to free up a given amount of space to store a particular block, but can fail and return - * false if either the block is bigger than our memory or it would require replacing another - * block from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that + * Try to free up a given amount of space to store a particular block, but can fail if + * either the block is bigger than our memory or it would require replacing another block + * from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that * don't fit into memory that we want to avoid). * * Assume that a lock is held by the caller to ensure only one thread is dropping blocks. @@ -272,7 +254,19 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) if (maxMemory - (currentMemory - selectedMemory) >= space) { logInfo(selectedBlocks.size + " blocks selected for dropping") for (blockId <- selectedBlocks) { - dropFromMemory(blockId) + val entry = entries.synchronized { entries.get(blockId) } + // This should never be null as only one thread should be dropping + // blocks and removing entries. However the check is still here for + // future safety. + if (entry != null) { + val data = if (entry.deserialized) { + Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) + } else { + Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) + } + val droppedBlockStatus = blockManager.dropFromMemory(blockId, data) + droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } + } } return ResultWithDroppedBlocks(success = true, droppedBlocks) } else { diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index dad330d6513da..e022accee6d08 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -18,15 +18,9 @@ package org.apache.spark import org.scalatest.FunSuite -import org.scalatest.concurrent.Timeouts._ -import org.scalatest.time.{Millis, Span} -import org.scalatest.concurrent.Eventually._ -import org.scalatest.time.SpanSugar._ -import org.scalatest.matchers.ShouldMatchers._ class BroadcastSuite extends FunSuite with LocalSparkContext { - override def afterEach() { super.afterEach() System.clearProperty("spark.broadcast.factory") @@ -88,47 +82,4 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) } - def blocksExist(sc: SparkContext, numSlaves: Int) = { - val rdd = sc.parallelize(1 to numSlaves, numSlaves) - val workerBlocks = rdd.mapPartitions(_ => { - val blocks = SparkEnv.get.blockManager.numberOfBlocksInMemory() - Seq(blocks).iterator - }) - val totalKnown = workerBlocks.reduce(_ + _) + sc.env.blockManager.numberOfBlocksInMemory() - - totalKnown > 0 - } - - def testUnpersist(bcFactory: String, removeSource: Boolean) { - test("Broadcast unpersist(" + removeSource + ") with " + bcFactory) { - val numSlaves = 2 - System.setProperty("spark.broadcast.factory", bcFactory) - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") - val list = List(1, 2, 3, 4) - - assert(!blocksExist(sc, numSlaves)) - - val listBroadcast = sc.broadcast(list, true) - val 
results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) - - assert(blocksExist(sc, numSlaves)) - - listBroadcast.unpersist(removeSource) - - eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { - blocksExist(sc, numSlaves) should be (false) - } - - if (!removeSource) { - val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) - } - } - } - - for (removeSource <- Seq(true, false)) { - testUnpersist("org.apache.spark.broadcast.HttpBroadcastFactory", removeSource) - testUnpersist("org.apache.spark.broadcast.TorrentBroadcastFactory", removeSource) - } } From d0edef3dda333b5bf43a320acd214f276b8a5b3e Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 26 Mar 2014 14:57:07 -0700 Subject: [PATCH 04/14] Add framework for broadcast cleanup As of this commit, Spark does not clean up broadcast blocks. This will be done in the next commit. --- .../org/apache/spark/ContextCleaner.scala | 134 +++++++++++------- .../scala/org/apache/spark/SparkContext.scala | 6 +- .../apache/spark/broadcast/Broadcast.scala | 6 + .../spark/broadcast/BroadcastFactory.scala | 1 + .../spark/broadcast/BroadcastManager.scala | 4 + .../spark/broadcast/HttpBroadcast.scala | 81 ++++++++--- .../broadcast/HttpBroadcastFactory.scala | 8 ++ .../spark/broadcast/TorrentBroadcast.scala | 86 ++++++----- .../broadcast/TorrentBroadcastFactory.scala | 7 + .../spark/storage/BlockManagerMessages.scala | 2 +- .../storage/BlockManagerSlaveActor.scala | 5 +- .../apache/spark/ContextCleanerSuite.scala | 21 ++- 12 files changed, 249 insertions(+), 112 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index deabf6f5c8c5f..f856a13f84dec 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -21,27 +21,41 @@ import java.lang.ref.{ReferenceQueue, WeakReference} import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -/** Listener class used for testing when any item has been cleaned by the Cleaner class */ -private[spark] trait CleanerListener { - def rddCleaned(rddId: Int) - def shuffleCleaned(shuffleId: Int) -} +/** + * Classes that represent cleaning tasks. + */ +private sealed trait CleanupTask +private case class CleanRDD(rddId: Int) extends CleanupTask +private case class CleanShuffle(shuffleId: Int) extends CleanupTask +private case class CleanBroadcast(broadcastId: Long) extends CleanupTask /** - * Cleans RDDs and shuffle data. + * A WeakReference associated with a CleanupTask. + * + * When the referent object becomes only weakly reachable, the corresponding + * CleanupTaskWeakReference is automatically added to the given reference queue. + */ +private class CleanupTaskWeakReference( + val task: CleanupTask, + referent: AnyRef, + referenceQueue: ReferenceQueue[AnyRef]) + extends WeakReference(referent, referenceQueue) + +/** + * An asynchronous cleaner for RDD, shuffle, and broadcast state. + * + * This maintains a weak reference for each RDD, ShuffleDependency, and Broadcast of interest, + * to be processed when the associated object goes out of scope of the application. Actual + * cleanup is performed in a separate daemon thread. 
*/ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { - /** Classes to represent cleaning tasks */ - private sealed trait CleanupTask - private case class CleanRDD(rddId: Int) extends CleanupTask - private case class CleanShuffle(shuffleId: Int) extends CleanupTask - // TODO: add CleanBroadcast + private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference] + with SynchronizedBuffer[CleanupTaskWeakReference] - private val referenceBuffer = new ArrayBuffer[WeakReferenceWithCleanupTask] - with SynchronizedBuffer[WeakReferenceWithCleanupTask] private val referenceQueue = new ReferenceQueue[AnyRef] private val listeners = new ArrayBuffer[CleanerListener] @@ -49,77 +63,64 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private val cleaningThread = new Thread() { override def run() { keepCleaning() }} - private val REF_QUEUE_POLL_TIMEOUT = 100 - @volatile private var stopped = false - private class WeakReferenceWithCleanupTask(referent: AnyRef, val task: CleanupTask) - extends WeakReference(referent, referenceQueue) + /** Attach a listener object to get information of when objects are cleaned. */ + def attachListener(listener: CleanerListener) { + listeners += listener + } - /** Start the cleaner */ + /** Start the cleaner. */ def start() { cleaningThread.setDaemon(true) cleaningThread.setName("ContextCleaner") cleaningThread.start() } - /** Stop the cleaner */ + /** Stop the cleaner. */ def stop() { stopped = true cleaningThread.interrupt() } - /** - * Register a RDD for cleanup when it is garbage collected. - */ + /** Register a RDD for cleanup when it is garbage collected. */ def registerRDDForCleanup(rdd: RDD[_]) { registerForCleanup(rdd, CleanRDD(rdd.id)) } - /** - * Register a shuffle dependency for cleanup when it is garbage collected. - */ + /** Register a ShuffleDependency for cleanup when it is garbage collected. */ def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) { registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId)) } - /** Cleanup RDD. */ - def cleanupRDD(rdd: RDD[_]) { - doCleanupRDD(rdd.id) - } - - /** Cleanup shuffle. */ - def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) { - doCleanupShuffle(shuffleDependency.shuffleId) - } - - /** Attach a listener object to get information of when objects are cleaned. */ - def attachListener(listener: CleanerListener) { - listeners += listener + /** Register a Broadcast for cleanup when it is garbage collected. */ + def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) { + registerForCleanup(broadcast, CleanBroadcast(broadcast.id)) } /** Register an object for cleanup. */ private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) { - referenceBuffer += new WeakReferenceWithCleanupTask(objectForCleanup, task) + referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue) } - /** Keep cleaning RDDs and shuffle data */ + /** Keep cleaning RDD, shuffle, and broadcast state. 
*/ private def keepCleaning() { - while (!isStopped) { + while (!stopped) { try { - val reference = Option(referenceQueue.remove(REF_QUEUE_POLL_TIMEOUT)) - .map(_.asInstanceOf[WeakReferenceWithCleanupTask]) + val reference = Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT)) + .map(_.asInstanceOf[CleanupTaskWeakReference]) reference.map(_.task).foreach { task => logDebug("Got cleaning task " + task) referenceBuffer -= reference.get task match { case CleanRDD(rddId) => doCleanupRDD(rddId) case CleanShuffle(shuffleId) => doCleanupShuffle(shuffleId) + case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId) } } } catch { case ie: InterruptedException => - if (!isStopped) logWarning("Cleaning thread interrupted") + if (!stopped) logWarning("Cleaning thread interrupted") case t: Throwable => logError("Error in cleaning thread", t) } } @@ -129,7 +130,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { private def doCleanupRDD(rddId: Int) { try { logDebug("Cleaning RDD " + rddId) - sc.unpersistRDD(rddId, false) + sc.unpersistRDD(rddId, blocking = false) listeners.foreach(_.rddCleaned(rddId)) logInfo("Cleaned RDD " + rddId) } catch { @@ -150,10 +151,47 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { } } - private def mapOutputTrackerMaster = - sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + /** Perform broadcast cleanup. */ + private def doCleanupBroadcast(broadcastId: Long) { + try { + logDebug("Cleaning broadcast " + broadcastId) + broadcastManager.unbroadcast(broadcastId, removeFromDriver = true) + listeners.foreach(_.broadcastCleaned(broadcastId)) + logInfo("Cleaned broadcast " + broadcastId) + } catch { + case t: Throwable => logError("Error cleaning broadcast " + broadcastId, t) + } + } private def blockManagerMaster = sc.env.blockManager.master + private def broadcastManager = sc.env.broadcastManager + private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + + // Used for testing + + private[spark] def cleanupRDD(rdd: RDD[_]) { + doCleanupRDD(rdd.id) + } + + private[spark] def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) { + doCleanupShuffle(shuffleDependency.shuffleId) + } - private def isStopped = stopped + private[spark] def cleanupBroadcast[T](broadcast: Broadcast[T]) { + doCleanupBroadcast(broadcast.id) + } + +} + +private object ContextCleaner { + private val REF_QUEUE_POLL_TIMEOUT = 100 +} + +/** + * Listener class used for testing when any item has been cleaned by the Cleaner class. + */ +private[spark] trait CleanerListener { + def rddCleaned(rddId: Int) + def shuffleCleaned(shuffleId: Int) + def broadcastCleaned(broadcastId: Long) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 5cd2caed10297..689180fcd719b 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -642,7 +642,11 @@ class SparkContext( * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. * The variable will be sent to each cluster only once. */ - def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal) + def broadcast[T](value: T) = { + val bc = env.broadcastManager.newBroadcast[T](value, isLocal) + cleaner.registerBroadcastForCleanup(bc) + bc + } /** * Add a file to be downloaded with this Spark job on every node. 
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index e3e1e4f29b107..d75b9acfb7aa0 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -50,6 +50,12 @@ import java.io.Serializable abstract class Broadcast[T](val id: Long) extends Serializable { def value: T + /** + * Remove all persisted state associated with this broadcast. + * @param removeFromDriver Whether to remove state from the driver. + */ + def unpersist(removeFromDriver: Boolean) + // We cannot have an abstract readObject here due to some weird issues with // readObject having to be 'private' in sub-classes. diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 0a0bb6cca336c..850650951e603 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -29,5 +29,6 @@ import org.apache.spark.SparkConf trait BroadcastFactory { def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] + def unbroadcast(id: Long, removeFromDriver: Boolean) def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index 746e23e81931a..85d62aae03959 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -60,4 +60,8 @@ private[spark] class BroadcastManager( broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) } + def unbroadcast(id: Long, removeFromDriver: Boolean) { + broadcastFactory.unbroadcast(id, removeFromDriver) + } + } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 374180e472805..89361efec44a4 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -21,10 +21,9 @@ import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream} import java.net.{URL, URLConnection, URI} import java.util.concurrent.TimeUnit -import it.unimi.dsi.fastutil.io.FastBufferedInputStream -import it.unimi.dsi.fastutil.io.FastBufferedOutputStream +import it.unimi.dsi.fastutil.io.{FastBufferedInputStream, FastBufferedOutputStream} -import org.apache.spark.{SparkConf, HttpServer, Logging, SecurityManager, SparkEnv} +import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv} import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils} @@ -32,18 +31,27 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { - def value = value_ + override def value = value_ - def blockId = BroadcastBlockId(id) + val blockId = BroadcastBlockId(id) HttpBroadcast.synchronized { - 
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + SparkEnv.get.blockManager.putSingle( + blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) } if (!isLocal) { HttpBroadcast.write(id, value_) } + /** + * Remove all persisted state associated with this HTTP broadcast. + * @param removeFromDriver Whether to remove state from the driver. + */ + override def unpersist(removeFromDriver: Boolean) { + HttpBroadcast.unpersist(id, removeFromDriver) + } + // Called by JVM when deserializing an object private def readObject(in: ObjectInputStream) { in.defaultReadObject() @@ -54,7 +62,8 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea logInfo("Started reading broadcast variable " + id) val start = System.nanoTime value_ = HttpBroadcast.read[T](id) - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + SparkEnv.get.blockManager.putSingle( + blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) val time = (System.nanoTime - start) / 1e9 logInfo("Reading broadcast variable " + id + " took " + time + " s") } @@ -63,7 +72,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea } } -private object HttpBroadcast extends Logging { +private[spark] object HttpBroadcast extends Logging { private var initialized = false private var broadcastDir: File = null @@ -74,7 +83,7 @@ private object HttpBroadcast extends Logging { private var securityManager: SecurityManager = null // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist - private val files = new TimeStampedHashSet[String] + val files = new TimeStampedHashSet[String] private var cleaner: MetadataCleaner = null private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt @@ -122,8 +131,10 @@ private object HttpBroadcast extends Logging { logInfo("Broadcast server started at " + serverUri) } + def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name) + def write(id: Long, value: Any) { - val file = new File(broadcastDir, BroadcastBlockId(id).name) + val file = getFile(id) val out: OutputStream = { if (compress) { compressionCodec.compressedOutputStream(new FileOutputStream(file)) @@ -146,7 +157,7 @@ private object HttpBroadcast extends Logging { if (securityManager.isAuthenticationEnabled()) { logDebug("broadcast security enabled") val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager) - uc = newuri.toURL().openConnection() + uc = newuri.toURL.openConnection() uc.setAllowUserInteraction(false) } else { logDebug("broadcast not using security") @@ -155,7 +166,7 @@ private object HttpBroadcast extends Logging { val in = { uc.setReadTimeout(httpReadTimeout) - val inputStream = uc.getInputStream(); + val inputStream = uc.getInputStream if (compress) { compressionCodec.compressedInputStream(inputStream) } else { @@ -169,20 +180,50 @@ private object HttpBroadcast extends Logging { obj } - def cleanup(cleanupTime: Long) { + /** + * Remove all persisted blocks associated with this HTTP broadcast on the executors. + * If removeFromDriver is true, also remove these persisted blocks on the driver + * and delete the associated broadcast file. 
+ */ + def unpersist(id: Long, removeFromDriver: Boolean) = synchronized { + //SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) + if (removeFromDriver) { + val file = new File(broadcastDir, BroadcastBlockId(id).name) + files.remove(file.toString) + deleteBroadcastFile(file) + } + } + + /** + * Periodically clean up old broadcasts by removing the associated map entries and + * deleting the associated files. + */ + private def cleanup(cleanupTime: Long) { val iterator = files.internalMap.entrySet().iterator() while(iterator.hasNext) { val entry = iterator.next() val (file, time) = (entry.getKey, entry.getValue) if (time < cleanupTime) { - try { - iterator.remove() - new File(file.toString).delete() - logInfo("Deleted broadcast file '" + file + "'") - } catch { - case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e) - } + iterator.remove() + deleteBroadcastFile(new File(file.toString)) } } } + + /** Delete the given broadcast file. */ + private def deleteBroadcastFile(file: File) { + try { + if (!file.exists()) { + logWarning("Broadcast file to be deleted does not exist: %s".format(file)) + } else if (file.delete()) { + logInfo("Deleted broadcast file: %s".format(file)) + } else { + logWarning("Could not delete broadcast file: %s".format(file)) + } + } catch { + case e: Exception => + logWarning("Exception while deleting broadcast file: %s".format(file), e) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala index c4f0f149534a5..4affa922156c9 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala @@ -31,4 +31,12 @@ class HttpBroadcastFactory extends BroadcastFactory { new HttpBroadcast[T](value_, isLocal, id) def stop() { HttpBroadcast.stop() } + + /** + * Remove all persisted state associated with the HTTP broadcast with the given ID. + * @param removeFromDriver Whether to remove state from the driver. 
+ */ + def unbroadcast(id: Long, removeFromDriver: Boolean) { + HttpBroadcast.unpersist(id, removeFromDriver) + } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 0828035c5d217..07ef54bb120b9 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -29,12 +29,13 @@ import org.apache.spark.util.Utils private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { - def value = value_ + override def value = value_ - def broadcastId = BroadcastBlockId(id) + val broadcastId = BroadcastBlockId(id) TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + SparkEnv.get.blockManager.putSingle( + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) } @transient var arrayOfBlocks: Array[TorrentBlock] = null @@ -47,8 +48,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo } def sendBroadcast() { - var tInfo = TorrentBroadcast.blockifyObject(value_) - + val tInfo = TorrentBroadcast.blockifyObject(value_) totalBlocks = tInfo.totalBlocks totalBytes = tInfo.totalBytes hasBlocks = tInfo.totalBlocks @@ -58,7 +58,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo val metaInfo = TorrentInfo(null, totalBlocks, totalBytes) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( - metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true) + metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true) } // Store individual pieces @@ -66,11 +66,19 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( - pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true) + pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true) } } } + /** + * Remove all persisted state associated with this torrent broadcast. + * @param removeFromDriver Whether to remove state from the driver. + */ + override def unpersist(removeFromDriver: Boolean) { + TorrentBroadcast.unpersist(id, removeFromDriver) + } + // Called by JVM when deserializing an object private def readObject(in: ObjectInputStream) { in.defaultReadObject() @@ -86,18 +94,18 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo // Initialize @transient variables that will receive garbage values from the master. resetWorkerVariables() - if (receiveBroadcast(id)) { + if (receiveBroadcast()) { value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - // Store the merged copy in cache so that the next worker doesn't need to rebuild it. - // This creates a tradeoff between memory usage and latency. - // Storing copy doubles the memory footprint; not storing doubles deserialization cost. + /* Store the merged copy in cache so that the next worker doesn't need to rebuild it. + * This creates a trade-off between memory usage and latency. Storing copy doubles + * the memory footprint; not storing doubles deserialization cost. 
*/ SparkEnv.get.blockManager.putSingle( - broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) // Remove arrayOfBlocks from memory once value_ is on local cache resetWorkerVariables() - } else { + } else { logError("Reading broadcast variable " + id + " failed") } @@ -114,7 +122,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo hasBlocks = 0 } - def receiveBroadcast(variableID: Long): Boolean = { + def receiveBroadcast(): Boolean = { // Receive meta-info val metaId = BroadcastHelperBlockId(broadcastId, "meta") var attemptId = 10 @@ -148,7 +156,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] hasBlocks += 1 SparkEnv.get.blockManager.putSingle( - pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true) + pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true) case None => throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) @@ -156,15 +164,17 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo } } - (hasBlocks == totalBlocks) + hasBlocks == totalBlocks } } -private object TorrentBroadcast extends Logging { - +private[spark] object TorrentBroadcast extends Logging { private var initialized = false private var conf: SparkConf = null + + lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 + def initialize(_isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.conf = conf //TODO: we might have to fix it in tests synchronized { @@ -178,39 +188,37 @@ private object TorrentBroadcast extends Logging { initialized = false } - lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 - def blockifyObject[T](obj: T): TorrentInfo = { val byteArray = Utils.serialize[T](obj) val bais = new ByteArrayInputStream(byteArray) - var blockNum = (byteArray.length / BLOCK_SIZE) + var blockNum = byteArray.length / BLOCK_SIZE if (byteArray.length % BLOCK_SIZE != 0) { blockNum += 1 } - var retVal = new Array[TorrentBlock](blockNum) - var blockID = 0 + val blocks = new Array[TorrentBlock](blockNum) + var blockId = 0 for (i <- 0 until (byteArray.length, BLOCK_SIZE)) { val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i) - var tempByteArray = new Array[Byte](thisBlockSize) - val hasRead = bais.read(tempByteArray, 0, thisBlockSize) + val tempByteArray = new Array[Byte](thisBlockSize) + bais.read(tempByteArray, 0, thisBlockSize) - retVal(blockID) = new TorrentBlock(blockID, tempByteArray) - blockID += 1 + blocks(blockId) = new TorrentBlock(blockId, tempByteArray) + blockId += 1 } bais.close() - val tInfo = TorrentInfo(retVal, blockNum, byteArray.length) - tInfo.hasBlocks = blockNum - - tInfo + val info = TorrentInfo(blocks, blockNum, byteArray.length) + info.hasBlocks = blockNum + info } - def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock], - totalBytes: Int, - totalBlocks: Int): T = { + def unBlockifyObject[T]( + arrayOfBlocks: Array[TorrentBlock], + totalBytes: Int, + totalBlocks: Int): T = { val retByteArray = new Array[Byte](totalBytes) for (i <- 0 until totalBlocks) { System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, @@ -219,6 +227,14 @@ private object TorrentBroadcast extends Logging { Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader) } + /** + * Remove all persisted blocks associated with this torrent broadcast on the executors. 
+ * If removeFromDriver is true, also remove these persisted blocks on the driver. + */ + def unpersist(id: Long, removeFromDriver: Boolean) = synchronized { + //SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) + } + } private[spark] case class TorrentBlock( @@ -227,7 +243,7 @@ private[spark] case class TorrentBlock( extends Serializable private[spark] case class TorrentInfo( - @transient arrayOfBlocks : Array[TorrentBlock], + @transient arrayOfBlocks: Array[TorrentBlock], totalBlocks: Int, totalBytes: Int) extends Serializable { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index a51c438c57717..eabe792b550bb 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -33,4 +33,11 @@ class TorrentBroadcastFactory extends BroadcastFactory { def stop() { TorrentBroadcast.stop() } + /** + * Remove all persisted state associated with the torrent broadcast with the given ID. + * @param removeFromDriver Whether to remove state from the driver. + */ + def unbroadcast(id: Long, removeFromDriver: Boolean) { + TorrentBroadcast.unpersist(id, removeFromDriver) + } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 50ea4e31ce509..4c5b31d0abe44 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -35,7 +35,7 @@ private[storage] object BlockManagerMessages { case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave // Remove all blocks belonging to a specific shuffle. 
- case class RemoveShuffle(shuffleId: Int) + case class RemoveShuffle(shuffleId: Int) extends ToBlockManagerSlave ////////////////////////////////////////////////////////////////////////////////// diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index a6ff147c1d3e6..9a12481b7f6d5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -29,8 +29,9 @@ import org.apache.spark.storage.BlockManagerMessages._ private[storage] class BlockManagerSlaveActor( blockManager: BlockManager, - mapOutputTracker: MapOutputTracker - ) extends Actor { + mapOutputTracker: MapOutputTracker) + extends Actor { + override def receive = { case RemoveBlock(blockId) => diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index b07f8817b7974..11e22145ebb88 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -17,7 +17,10 @@ package org.apache.spark +import java.lang.ref.WeakReference + import scala.collection.mutable.{ArrayBuffer, HashSet, SynchronizedSet} +import scala.util.Random import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Eventually @@ -26,9 +29,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext._ import org.apache.spark.storage.{RDDBlockId, ShuffleBlockId} -import org.apache.spark.rdd.{ShuffleCoGroupSplitDep, RDD} -import scala.util.Random -import java.lang.ref.WeakReference +import org.apache.spark.rdd.RDD class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { @@ -67,7 +68,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo test("automatically cleanup RDD") { var rdd = newRDD.persist() rdd.count() - + // test that GC does not cause RDD cleanup due to a strong reference val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) runGC() @@ -171,11 +172,16 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo /** Class to test whether RDDs, shuffles, etc. have been successfully cleaned. 
*/ -class CleanerTester(sc: SparkContext, rddIds: Seq[Int] = Nil, shuffleIds: Seq[Int] = Nil) +class CleanerTester( + sc: SparkContext, + rddIds: Seq[Int] = Seq.empty, + shuffleIds: Seq[Int] = Seq.empty, + broadcastIds: Seq[Long] = Seq.empty) extends Logging { val toBeCleanedRDDIds = new HashSet[Int] with SynchronizedSet[Int] ++= rddIds val toBeCleanedShuffleIds = new HashSet[Int] with SynchronizedSet[Int] ++= shuffleIds + val toBeCleanedBroadcastIds = new HashSet[Long] with SynchronizedSet[Long] ++= broadcastIds val cleanerListener = new CleanerListener { def rddCleaned(rddId: Int): Unit = { @@ -187,6 +193,11 @@ class CleanerTester(sc: SparkContext, rddIds: Seq[Int] = Nil, shuffleIds: Seq[In toBeCleanedShuffleIds -= shuffleId logInfo("Shuffle " + shuffleId + " cleaned") } + + def broadcastCleaned(broadcastId: Long): Unit = { + toBeCleanedBroadcastIds -= broadcastId + logInfo("Broadcast " + broadcastId + " cleaned") + } } val MAX_VALIDATION_ATTEMPTS = 10 From 544ac866edf21230140fe56ee7a428fe0ab86329 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 26 Mar 2014 15:11:42 -0700 Subject: [PATCH 05/14] Clean up broadcast blocks through BlockManager* --- .../apache/spark/broadcast/HttpBroadcast.scala | 2 +- .../spark/broadcast/TorrentBroadcast.scala | 2 +- .../org/apache/spark/storage/BlockManager.scala | 14 +++++++++++++- .../spark/storage/BlockManagerMaster.scala | 7 +++++++ .../spark/storage/BlockManagerMasterActor.scala | 16 +++++++++++++--- .../spark/storage/BlockManagerMessages.scala | 13 ++++++++++--- .../spark/storage/BlockManagerSlaveActor.scala | 3 +++ .../main/scala/org/apache/spark/util/Utils.scala | 8 ++++---- .../org/apache/spark/ContextCleanerSuite.scala | 2 +- 9 files changed, 53 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 89361efec44a4..4985d4202ed6b 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -186,7 +186,7 @@ private[spark] object HttpBroadcast extends Logging { * and delete the associated broadcast file. */ def unpersist(id: Long, removeFromDriver: Boolean) = synchronized { - //SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) + SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) if (removeFromDriver) { val file = new File(broadcastDir, BroadcastBlockId(id).name) files.remove(file.toString) deleteBroadcastFile(file) } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 07ef54bb120b9..51f1592cef752 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -232,7 +232,7 @@ private[spark] object TorrentBroadcast extends Logging { * If removeFromDriver is true, also remove these persisted blocks on the driver. 
*/ def unpersist(id: Long, removeFromDriver: Boolean) = synchronized { - //SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) + SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index ca23513c4dc64..3c0941e195724 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -820,10 +820,22 @@ private[spark] class BlockManager( // from RDD.id to blocks. logInfo("Removing RDD " + rddId) val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) - blocksToRemove.foreach(blockId => removeBlock(blockId, tellMaster = false)) + blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) } blocksToRemove.size } + /** + * Remove all blocks belonging to the given broadcast. + */ + def removeBroadcast(broadcastId: Long) { + logInfo("Removing broadcast " + broadcastId) + val blocksToRemove = blockInfo.keys.filter(_.isBroadcast).collect { + case bid: BroadcastBlockId if bid.broadcastId == broadcastId => bid + case bid: BroadcastHelperBlockId if bid.broadcastId.broadcastId == broadcastId => bid + } + blocksToRemove.foreach { blockId => removeBlock(blockId) } + } + /** * Remove a block from both memory and disk. */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index ff3f22b3b092a..4579c0d959553 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -126,6 +126,13 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log askDriverWithReply(RemoveShuffle(shuffleId)) } + /** + * Remove all blocks belonging to the given broadcast. + */ + def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean) { + askDriverWithReply(RemoveBroadcast(broadcastId, removeFromMaster)) + } + /** * Return the memory status for each block manager, in the form of a map from * the block manager's id to two long values. The first value is the maximum diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 646ccb7fa74f6..4cc4227fd87e2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -100,6 +100,10 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus removeShuffle(shuffleId) sender ! true + case RemoveBroadcast(broadcastId, removeFromDriver) => + removeBroadcast(broadcastId, removeFromDriver) + sender ! true + case RemoveBlock(blockId) => removeBlockFromWorkers(blockId) sender ! true @@ -151,9 +155,15 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private def removeShuffle(shuffleId: Int) { // Nothing to do in the BlockManagerMasterActor data structures val removeMsg = RemoveShuffle(shuffleId) - blockManagerInfo.values.foreach { bm => - bm.slaveActor ! removeMsg - } + blockManagerInfo.values.foreach { bm => bm.slaveActor ! 
removeMsg } + } + + private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) { + // TODO(aor): Consolidate usages of + val removeMsg = RemoveBroadcast(broadcastId) + blockManagerInfo.values + .filter { info => removeFromDriver || info.blockManagerId.executorId != "" } + .foreach { bm => bm.slaveActor ! removeMsg } } private def removeBlockManager(blockManagerId: BlockManagerId) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 4c5b31d0abe44..3ea710ebc786e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -22,9 +22,11 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import akka.actor.ActorRef private[storage] object BlockManagerMessages { + ////////////////////////////////////////////////////////////////////////////////// // Messages from the master to slaves. ////////////////////////////////////////////////////////////////////////////////// + sealed trait ToBlockManagerSlave // Remove a block from the slaves that have it. This can only be used to remove @@ -37,10 +39,15 @@ private[storage] object BlockManagerMessages { // Remove all blocks belonging to a specific shuffle. case class RemoveShuffle(shuffleId: Int) extends ToBlockManagerSlave + // Remove all blocks belonging to a specific broadcast. + case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) + extends ToBlockManagerSlave + ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. ////////////////////////////////////////////////////////////////////////////////// + sealed trait ToBlockManagerMaster case class RegisterBlockManager( @@ -57,8 +64,7 @@ private[storage] object BlockManagerMessages { var storageLevel: StorageLevel, var memSize: Long, var diskSize: Long) - extends ToBlockManagerMaster - with Externalizable { + extends ToBlockManagerMaster with Externalizable { def this() = this(null, null, null, 0, 0) // For deserialization only @@ -80,7 +86,8 @@ private[storage] object BlockManagerMessages { } object UpdateBlockInfo { - def apply(blockManagerId: BlockManagerId, + def apply( + blockManagerId: BlockManagerId, blockId: BlockId, storageLevel: StorageLevel, memSize: Long, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 9a12481b7f6d5..8c2ccbe6a7e66 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -46,5 +46,8 @@ class BlockManagerSlaveActor( if (mapOutputTracker != null) { mapOutputTracker.unregisterShuffle(shuffleId) } + + case RemoveBroadcast(broadcastId, _) => + blockManager.removeBroadcast(broadcastId) } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index ad87fda140476..e541591ee7582 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -461,10 +461,10 @@ private[spark] object Utils extends Logging { private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]() def parseHostPort(hostPort: String): (String, Int) = { - { - // Check cache first. 
- val cached = hostPortParseResults.get(hostPort) - if (cached != null) return cached + // Check cache first. + val cached = hostPortParseResults.get(hostPort) + if (cached != null) { + return cached } val indx: Int = hostPort.lastIndexOf(':') diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 11e22145ebb88..77d9825434706 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -28,8 +28,8 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext._ -import org.apache.spark.storage.{RDDBlockId, ShuffleBlockId} import org.apache.spark.rdd.RDD +import org.apache.spark.storage.{RDDBlockId, ShuffleBlockId} class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { From e95479cd63b3259beddea278befd0bdee89bb17e Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 27 Mar 2014 14:37:51 -0700 Subject: [PATCH 06/14] Add tests for unpersisting broadcast There is not currently a way to query the blocks on the executors, an operation that is deceptively simple to accomplish. This commit adds this mechanism in order to verify that blocks are in fact persisted/unpersisted on the executors in the tests. --- .../apache/spark/broadcast/Broadcast.scala | 16 +- .../spark/broadcast/HttpBroadcast.scala | 13 +- .../spark/broadcast/TorrentBroadcast.scala | 13 +- .../apache/spark/storage/BlockManager.scala | 20 +- .../spark/storage/BlockManagerMaster.scala | 18 ++ .../storage/BlockManagerMasterActor.scala | 24 +- .../spark/storage/BlockManagerMessages.scala | 7 + .../storage/BlockManagerSlaveActor.scala | 7 +- .../org/apache/spark/BroadcastSuite.scala | 254 +++++++++++++++--- 9 files changed, 309 insertions(+), 63 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index d75b9acfb7aa0..3a2fef05861e6 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -48,16 +48,26 @@ import java.io.Serializable * @tparam T Type of the data contained in the broadcast variable. */ abstract class Broadcast[T](val id: Long) extends Serializable { + + /** + * Whether this Broadcast is actually usable. This should be false once persisted state is + * removed from the driver. + */ + protected var isValid: Boolean = true + def value: T /** - * Remove all persisted state associated with this broadcast. + * Remove all persisted state associated with this broadcast. Overriding implementations + * should set isValid to false if persisted state is also removed from the driver. + * * @param removeFromDriver Whether to remove state from the driver. + * If true, the resulting broadcast should no longer be valid. */ def unpersist(removeFromDriver: Boolean) - // We cannot have an abstract readObject here due to some weird issues with - // readObject having to be 'private' in sub-classes. + // We cannot define abstract readObject and writeObject here due to some weird issues + // with these methods having to be 'private' in sub-classes. 
override def toString = "Broadcast(" + id + ")" } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 4985d4202ed6b..d5e3d60a5b2b7 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -17,8 +17,8 @@ package org.apache.spark.broadcast -import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream} -import java.net.{URL, URLConnection, URI} +import java.io.{File, FileOutputStream, ObjectInputStream, ObjectOutputStream, OutputStream} +import java.net.{URI, URL, URLConnection} import java.util.concurrent.TimeUnit import it.unimi.dsi.fastutil.io.{FastBufferedInputStream, FastBufferedOutputStream} @@ -49,10 +49,17 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea * @param removeFromDriver Whether to remove state from the driver. */ override def unpersist(removeFromDriver: Boolean) { + isValid = !removeFromDriver HttpBroadcast.unpersist(id, removeFromDriver) } - // Called by JVM when deserializing an object + // Used by the JVM when serializing this object + private def writeObject(out: ObjectOutputStream) { + assert(isValid, "Attempted to serialize a broadcast variable that has been destroyed!") + out.defaultWriteObject() + } + + // Used by the JVM when deserializing this object private def readObject(in: ObjectInputStream) { in.defaultReadObject() HttpBroadcast.synchronized { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 51f1592cef752..ace71575f5390 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -17,12 +17,12 @@ package org.apache.spark.broadcast -import java.io._ +import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream} import scala.math import scala.util.Random -import org.apache.spark._ +import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel} import org.apache.spark.util.Utils @@ -76,10 +76,17 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo * @param removeFromDriver Whether to remove state from the driver. 
*/ override def unpersist(removeFromDriver: Boolean) { + isValid = !removeFromDriver TorrentBroadcast.unpersist(id, removeFromDriver) } - // Called by JVM when deserializing an object + // Used by the JVM when serializing this object + private def writeObject(out: ObjectOutputStream) { + assert(isValid, "Attempted to serialize a broadcast variable that has been destroyed!") + out.defaultWriteObject() + } + + // Used by the JVM when deserializing this object private def readObject(in: ObjectInputStream) { in.defaultReadObject() TorrentBroadcast.synchronized { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 3c0941e195724..78dc32b4b1525 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props} import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream} import sun.nio.ch.DirectBuffer -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException, MapOutputTracker} +import org.apache.spark._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer @@ -58,7 +58,7 @@ private[spark] class BlockManager( private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] - private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) + private[storage] val memoryStore = new MemoryStore(this, maxMemory) private[storage] val diskStore = new DiskStore(this, diskBlockManager) // If we use Netty for shuffle, start a new Netty-based shuffle sender service. @@ -210,9 +210,9 @@ private[spark] class BlockManager( } /** - * Get storage level of local block. If no info exists for the block, then returns null. + * Get storage level of local block. If no info exists for the block, return None. */ - def getLevel(blockId: BlockId): StorageLevel = blockInfo.get(blockId).map(_.level).orNull + def getLevel(blockId: BlockId): Option[StorageLevel] = blockInfo.get(blockId).map(_.level) /** * Tell the master about the current storage status of a block. This will send a block update @@ -496,9 +496,8 @@ private[spark] class BlockManager( /** * A short circuited method to get a block writer that can write data directly to disk. - * The Block will be appended to the File specified by filename. - * This is currently used for writing shuffle files out. Callers should handle error - * cases. + * The Block will be appended to the File specified by filename. This is currently used for + * writing shuffle files out. Callers should handle error cases. */ def getDiskWriter( blockId: BlockId, @@ -816,8 +815,7 @@ private[spark] class BlockManager( * @return The number of blocks removed. */ def removeRdd(rddId: Int): Int = { - // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps - // from RDD.id to blocks. + // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks. logInfo("Removing RDD " + rddId) val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) } @@ -827,13 +825,13 @@ private[spark] class BlockManager( /** * Remove all blocks belonging to the given broadcast. 
*/ - def removeBroadcast(broadcastId: Long) { + def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) { logInfo("Removing broadcast " + broadcastId) val blocksToRemove = blockInfo.keys.filter(_.isBroadcast).collect { case bid: BroadcastBlockId if bid.broadcastId == broadcastId => bid case bid: BroadcastHelperBlockId if bid.broadcastId.broadcastId == broadcastId => bid } - blocksToRemove.foreach { blockId => removeBlock(blockId) } + blocksToRemove.foreach { blockId => removeBlock(blockId, removeFromDriver) } } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 4579c0d959553..674322e3034c8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -147,6 +147,24 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log askDriverWithReply[Array[StorageStatus]](GetStorageStatus) } + /** + * Mainly for testing. Ask the driver to query all executors for their storage levels + * regarding this block. This provides an avenue for the driver to learn the storage + * levels of blocks it has not been informed of. + * + * WARNING: This could lead to deadlocks if there are any outstanding messages the + * executors are already expecting from the driver. In this case, while the driver is + * waiting for the executors to respond to its GetStorageLevel query, the executors + * are also waiting for a response from the driver to a prior message. + * + * The interim solution is to wait for a brief window of time to pass before asking. + * This should suffice, since this mechanism is largely introduced for testing only. + */ + def askForStorageLevels(blockId: BlockId, waitTimeMs: Long = 1000) = { + Thread.sleep(waitTimeMs) + askDriverWithReply[Map[BlockManagerId, StorageLevel]](AskForStorageLevels(blockId)) + } + /** Stop the driver actor, called only on the Spark driver node */ def stop() { if (driverActor != null) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 4cc4227fd87e2..f83c26dafe2e9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -21,7 +21,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.mutable import scala.collection.JavaConversions._ -import scala.concurrent.Future +import scala.concurrent.{Await, Future} import scala.concurrent.duration._ import akka.actor.{Actor, ActorRef, Cancellable} @@ -126,6 +126,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case HeartBeat(blockManagerId) => sender ! heartBeat(blockManagerId) + case AskForStorageLevels(blockId) => + sender ! askForStorageLevels(blockId) + case other => logWarning("Got unknown message: " + other) } @@ -158,6 +161,11 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus blockManagerInfo.values.foreach { bm => bm.slaveActor ! removeMsg } } + /** + * Delegate RemoveBroadcast messages to each BlockManager because the master may not be notified + * of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed + * from the executors, but not from the driver. 
+ */ private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) { // TODO(aor): Consolidate usages of val removeMsg = RemoveBroadcast(broadcastId) @@ -246,6 +254,19 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus }.toArray } + // For testing. Ask all block managers for the given block's local storage level, if any. + private def askForStorageLevels(blockId: BlockId): Map[BlockManagerId, StorageLevel] = { + val getStorageLevel = GetStorageLevel(blockId) + blockManagerInfo.values.flatMap { info => + val future = info.slaveActor.ask(getStorageLevel)(akkaTimeout) + val result = Await.result(future, akkaTimeout) + if (result != null) { + // If the block does not exist on the slave, the slave replies None + result.asInstanceOf[Option[StorageLevel]].map { reply => (info.blockManagerId, reply) } + } else None + }.toMap + } + private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { @@ -329,6 +350,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Note that this logic will select the same node multiple times if there aren't enough peers Array.tabulate[BlockManagerId](size) { i => peers((selfIndex + i + 1) % peers.length) }.toSeq } + } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 3ea710ebc786e..1d3e94c4b6533 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -43,6 +43,9 @@ private[storage] object BlockManagerMessages { case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) extends ToBlockManagerSlave + // For testing. Ask the slave for the block's storage level. + case class GetStorageLevel(blockId: BlockId) extends ToBlockManagerSlave + ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. @@ -116,4 +119,8 @@ private[storage] object BlockManagerMessages { case object ExpireDeadHosts extends ToBlockManagerMaster case object GetStorageStatus extends ToBlockManagerMaster + + // For testing. Have the master ask all slaves for the given block's storage level. + case class AskForStorageLevels(blockId: BlockId) extends ToBlockManagerMaster + } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 8c2ccbe6a7e66..85b8ec40c0ea3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -47,7 +47,10 @@ class BlockManagerSlaveActor( mapOutputTracker.unregisterShuffle(shuffleId) } - case RemoveBroadcast(broadcastId, _) => - blockManager.removeBroadcast(broadcastId) + case RemoveBroadcast(broadcastId, removeFromDriver) => + blockManager.removeBroadcast(broadcastId, removeFromDriver) + + case GetStorageLevel(blockId) => + sender ! 
blockManager.getLevel(blockId) } } diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index e022accee6d08..a462654197ea0 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -19,67 +19,241 @@ package org.apache.spark import org.scalatest.FunSuite +import org.apache.spark.storage._ +import org.apache.spark.broadcast.HttpBroadcast +import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId} + class BroadcastSuite extends FunSuite with LocalSparkContext { - override def afterEach() { - super.afterEach() - System.clearProperty("spark.broadcast.factory") - } + private val httpConf = broadcastConf("HttpBroadcastFactory") + private val torrentConf = broadcastConf("TorrentBroadcastFactory") test("Using HttpBroadcast locally") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") - sc = new SparkContext("local", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === Set((1, 10), (2, 10))) + sc = new SparkContext("local", "test", httpConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === Set((1, 10), (2, 10))) } test("Accessing HttpBroadcast variables from multiple threads") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") - sc = new SparkContext("local[10]", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) + sc = new SparkContext("local[10]", "test", httpConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) } test("Accessing HttpBroadcast variables in a local cluster") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") val numSlaves = 4 - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", httpConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) } test("Using TorrentBroadcast locally") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") - sc = new SparkContext("local", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === Set((1, 10), (2, 10))) + sc = new SparkContext("local", "test", torrentConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = 
sc.broadcast(list) + val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === Set((1, 10), (2, 10))) } test("Accessing TorrentBroadcast variables from multiple threads") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") - sc = new SparkContext("local[10]", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) + sc = new SparkContext("local[10]", "test", torrentConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) } test("Accessing TorrentBroadcast variables in a local cluster") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") val numSlaves = 4 - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", torrentConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + } + + test("Unpersisting HttpBroadcast on executors only") { + testUnpersistHttpBroadcast(2, removeFromDriver = false) + } + + test("Unpersisting HttpBroadcast on executors and driver") { + testUnpersistHttpBroadcast(2, removeFromDriver = true) + } + + test("Unpersisting TorrentBroadcast on executors only") { + testUnpersistTorrentBroadcast(2, removeFromDriver = false) + } + + test("Unpersisting TorrentBroadcast on executors and driver") { + testUnpersistTorrentBroadcast(2, removeFromDriver = true) + } + + /** + * Verify the persistence of state associated with an HttpBroadcast in a local-cluster. + * + * This test creates a broadcast variable, uses it on all executors, and then unpersists it. + * In between each step, this test verifies that the broadcast blocks and the broadcast file + * are present only on the expected nodes. 
+ */ + private def testUnpersistHttpBroadcast(numSlaves: Int, removeFromDriver: Boolean) { + def getBlockIds(id: Long) = Seq[BlockId](BroadcastBlockId(id)) + + // Verify that the broadcast file is created, and blocks are persisted only on the driver + def afterCreation(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + assert(blockIds.size === 1) + val broadcastBlockId = blockIds.head.asInstanceOf[BroadcastBlockId] + val levels = bmm.askForStorageLevels(broadcastBlockId, waitTimeMs = 0) + assert(levels.size === 1) + levels.head match { case (bm, level) => + assert(bm.executorId === "") + assert(level === StorageLevel.MEMORY_AND_DISK) + } + assert(HttpBroadcast.getFile(broadcastBlockId.broadcastId).exists) + } + + // Verify that blocks are persisted in both the executors and the driver + def afterUsingBroadcast(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + assert(blockIds.size === 1) + val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0) + assert(levels.size === numSlaves + 1) + levels.foreach { case (_, level) => + assert(level === StorageLevel.MEMORY_AND_DISK) + } + } + + // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver + // is true. In the latter case, also verify that the broadcast file is deleted on the driver. + def afterUnpersist(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + assert(blockIds.size === 1) + val broadcastBlockId = blockIds.head.asInstanceOf[BroadcastBlockId] + val levels = bmm.askForStorageLevels(broadcastBlockId, waitTimeMs = 0) + assert(levels.size === (if (removeFromDriver) 0 else 1)) + assert(removeFromDriver === !HttpBroadcast.getFile(broadcastBlockId.broadcastId).exists) + } + + testUnpersistBroadcast(numSlaves, httpConf, getBlockIds, afterCreation, + afterUsingBroadcast, afterUnpersist, removeFromDriver) + } + + /** + * Verify the persistence of state associated with a TorrentBroadcast in a local-cluster. + * + * This test creates a broadcast variable, uses it on all executors, and then unpersists it. + * In between each step, this test verifies that the broadcast blocks are present only on the + * expected nodes. 
+ */ + private def testUnpersistTorrentBroadcast(numSlaves: Int, removeFromDriver: Boolean) { + def getBlockIds(id: Long) = { + val broadcastBlockId = BroadcastBlockId(id) + val metaBlockId = BroadcastHelperBlockId(broadcastBlockId, "meta") + // Assume broadcast value is small enough to fit into 1 piece + val pieceBlockId = BroadcastHelperBlockId(broadcastBlockId, "piece0") + Seq[BlockId](broadcastBlockId, metaBlockId, pieceBlockId) + } + + // Verify that blocks are persisted only on the driver + def afterCreation(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + blockIds.foreach { blockId => + val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0) + assert(levels.size === 1) + levels.head match { case (bm, level) => + assert(bm.executorId === "") + assert(level === StorageLevel.MEMORY_AND_DISK) + } + } + } + + // Verify that blocks are persisted in both the executors and the driver + def afterUsingBroadcast(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + blockIds.foreach { blockId => + val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0) + blockId match { + case BroadcastHelperBlockId(_, "meta") => + // Meta data is only on the driver + assert(levels.size === 1) + levels.head match { case (bm, _) => assert(bm.executorId === "") } + case _ => + // Other blocks are on both the executors and the driver + assert(levels.size === numSlaves + 1) + levels.foreach { case (_, level) => + assert(level === StorageLevel.MEMORY_AND_DISK) + } + } + } + } + + // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver + // is true. + def afterUnpersist(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + val expectedNumBlocks = if (removeFromDriver) 0 else 1 + var waitTimeMs = 1000L + blockIds.foreach { blockId => + // Allow a second for the messages triggered by unpersist to propagate to prevent deadlocks + val levels = bmm.askForStorageLevels(blockId, waitTimeMs) + assert(levels.size === expectedNumBlocks) + waitTimeMs = 0L + } + } + + testUnpersistBroadcast(numSlaves, torrentConf, getBlockIds, afterCreation, + afterUsingBroadcast, afterUnpersist, removeFromDriver) + } + + /** + * This test runs in 4 steps: + * + * 1) Create a broadcast variable, and verify that all state is persisted on the driver. + * 2) Use the broadcast variable on all executors, and verify that all state is persisted + * on both the driver and the executors. + * 3) Unpersist the broadcast, and verify that all state is removed where it should be. + * 4) [Optional] If removeFromDriver is false, we verify that the broadcast is re-usable. 
+ */ + private def testUnpersistBroadcast( + numSlaves: Int, + broadcastConf: SparkConf, + getBlockIds: Long => Seq[BlockId], + afterCreation: (Seq[BlockId], BlockManagerMaster) => Unit, + afterUsingBroadcast: (Seq[BlockId], BlockManagerMaster) => Unit, + afterUnpersist: (Seq[BlockId], BlockManagerMaster) => Unit, + removeFromDriver: Boolean) { + + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf) + val blockManagerMaster = sc.env.blockManager.master + val list = List[Int](1, 2, 3, 4) + + // Create broadcast variable + val broadcast = sc.broadcast(list) + val blocks = getBlockIds(broadcast.id) + afterCreation(blocks, blockManagerMaster) + + // Use broadcast variable on all executors + val results = sc.parallelize(1 to numSlaves, numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + afterUsingBroadcast(blocks, blockManagerMaster) + + // Unpersist broadcast + broadcast.unpersist(removeFromDriver) + afterUnpersist(blocks, blockManagerMaster) + + if (!removeFromDriver) { + // The broadcast variable is not completely destroyed (i.e. state still exists on driver) + // Using the variable again should yield the same answer as before. + val results = sc.parallelize(1 to numSlaves, numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + } + } + + /** Helper method to create a SparkConf that uses the given broadcast factory. */ + private def broadcastConf(factoryName: String): SparkConf = { + val conf = new SparkConf + conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.%s".format(factoryName)) + conf } } From f201a8d3c2f3c95da986760ac7ce4acb199f4e71 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 27 Mar 2014 15:39:51 -0700 Subject: [PATCH 07/14] Test broadcast cleanup in ContextCleanerSuite + remove BoundedHashMap --- .../apache/spark/util/BoundedHashMap.scala | 67 -------- .../apache/spark/ContextCleanerSuite.scala | 147 +++++++++++------- .../spark/util/WrappedJavaHashMapSuite.scala | 5 - 3 files changed, 94 insertions(+), 125 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala diff --git a/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala b/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala deleted file mode 100644 index 888a06b2408c9..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.util - -import scala.collection.mutable.{ArrayBuffer, SynchronizedMap} - -import java.util.{Collections, LinkedHashMap} -import java.util.Map.{Entry => JMapEntry} -import scala.reflect.ClassTag - -/** - * A map that upper bounds the number of key-value pairs present in it. It can be configured to - * drop the least recently user pair or the earliest inserted pair. It exposes a - * scala.collection.mutable.Map interface to allow it to be a drop-in replacement for Scala - * HashMaps. - * - * Internally, a Java LinkedHashMap is used to get insert-order or access-order behavior. - * Note that the LinkedHashMap is not thread-safe and hence, it is wrapped in a - * Collections.synchronizedMap. However, getting the Java HashMap's iterator and - * using it can still lead to ConcurrentModificationExceptions. Hence, the iterator() - * function is overridden to copy the all pairs into an ArrayBuffer and then return the - * iterator to the ArrayBuffer. Also, the class apply the trait SynchronizedMap which - * ensures that all calls to the Scala Map API are synchronized. This together ensures - * that ConcurrentModificationException is never thrown. - * - * @param bound max number of key-value pairs - * @param useLRU true = least recently used/accessed will be dropped when bound is reached, - * false = earliest inserted will be dropped - */ -private[spark] class BoundedHashMap[A, B](bound: Int, useLRU: Boolean) - extends WrappedJavaHashMap[A, B, A, B] with SynchronizedMap[A, B] { - - private[util] val internalJavaMap = Collections.synchronizedMap(new LinkedHashMap[A, B]( - bound / 8, (0.75).toFloat, useLRU) { - override protected def removeEldestEntry(eldest: JMapEntry[A, B]): Boolean = { - size() > bound - } - }) - - private[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { - new BoundedHashMap[K1, V1](bound, useLRU) - } - - /** - * Overriding iterator to make sure that the internal Java HashMap's iterator - * is not concurrently modified. This can be a performance issue and this should be overridden - * if it is known that this map will not be used in a multi-threaded environment. 
- */ - override def iterator: Iterator[(A, B)] = { - (new ArrayBuffer[(A, B)] ++= super.iterator).iterator - } -} diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 77d9825434706..6a12cb6603700 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark import java.lang.ref.WeakReference -import scala.collection.mutable.{ArrayBuffer, HashSet, SynchronizedSet} +import scala.collection.mutable.{HashSet, SynchronizedSet} import scala.util.Random import org.scalatest.{BeforeAndAfter, FunSuite} @@ -29,7 +29,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD -import org.apache.spark.storage.{RDDBlockId, ShuffleBlockId} +import org.apache.spark.storage.{BroadcastBlockId, RDDBlockId, ShuffleBlockId} class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { @@ -46,9 +46,9 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo // Explicit cleanup cleaner.cleanupRDD(rdd) - tester.assertCleanup + tester.assertCleanup() - // verify that RDDs can be re-executed after cleaning up + // Verify that RDDs can be re-executed after cleaning up assert(rdd.collect().toList === collected) } @@ -59,87 +59,101 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo // Explicit cleanup shuffleDeps.foreach(s => cleaner.cleanupShuffle(s)) - tester.assertCleanup + tester.assertCleanup() // Verify that shuffles can be re-executed after cleaning up assert(rdd.collect().toList === collected) } + test("cleanup broadcast") { + val broadcast = newBroadcast + val tester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) + + // Explicit cleanup + cleaner.cleanupBroadcast(broadcast) + tester.assertCleanup() + } + test("automatically cleanup RDD") { var rdd = newRDD.persist() rdd.count() - // test that GC does not cause RDD cleanup due to a strong reference + // Test that GC does not cause RDD cleanup due to a strong reference val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) runGC() intercept[Exception] { - preGCTester.assertCleanup(timeout(1000 millis)) + preGCTester.assertCleanup()(timeout(1000 millis)) } - // test that GC causes RDD cleanup after dereferencing the RDD + // Test that GC causes RDD cleanup after dereferencing the RDD val postGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) - rdd = null // make RDD out of scope + rdd = null // Make RDD out of scope runGC() - postGCTester.assertCleanup + postGCTester.assertCleanup() } test("automatically cleanup shuffle") { var rdd = newShuffleRDD rdd.count() - // test that GC does not cause shuffle cleanup due to a strong reference - val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) + // Test that GC does not cause shuffle cleanup due to a strong reference + val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) runGC() intercept[Exception] { - preGCTester.assertCleanup(timeout(1000 millis)) + preGCTester.assertCleanup()(timeout(1000 millis)) } - // test that GC causes shuffle cleanup after dereferencing the RDD + // Test that GC causes shuffle cleanup after dereferencing the RDD val postGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) - rdd = null // make RDD out of scope, so that corresponding shuffle goes out of scope + rdd = null // Make RDD out of scope, so 
that corresponding shuffle goes out of scope runGC() - postGCTester.assertCleanup + postGCTester.assertCleanup() } - test("automatically cleanup RDD + shuffle") { + test("automatically cleanup broadcast") { + var broadcast = newBroadcast - def randomRDD: RDD[_] = { - val rdd: RDD[_] = Random.nextInt(3) match { - case 0 => newRDD - case 1 => newShuffleRDD - case 2 => newPairRDD.join(newPairRDD) - } - if (Random.nextBoolean()) rdd.persist() - rdd.count() - rdd + // Test that GC does not cause broadcast cleanup due to a strong reference + val preGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) + runGC() + intercept[Exception] { + preGCTester.assertCleanup()(timeout(1000 millis)) } - val buffer = new ArrayBuffer[RDD[_]] - for (i <- 1 to 500) { - buffer += randomRDD - } + // Test that GC causes broadcast cleanup after dereferencing the broadcast variable + val postGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) + broadcast = null // Make broadcast variable out of scope + runGC() + postGCTester.assertCleanup() + } + test("automatically cleanup RDD + shuffle + broadcast") { + val numRdds = 100 + val numBroadcasts = 4 // Broadcasts are more costly + val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer + val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer val rddIds = sc.persistentRdds.keys.toSeq val shuffleIds = 0 until sc.newShuffleId + val broadcastIds = 0L until numBroadcasts - val preGCTester = new CleanerTester(sc, rddIds, shuffleIds) + val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) runGC() intercept[Exception] { - preGCTester.assertCleanup(timeout(1000 millis)) + preGCTester.assertCleanup()(timeout(1000 millis)) } - // test that GC causes shuffle cleanup after dereferencing the RDD - val postGCTester = new CleanerTester(sc, rddIds, shuffleIds) - buffer.clear() + + // Test that GC triggers the cleanup of all variables after the dereferencing them + val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + broadcastBuffer.clear() + rddBuffer.clear() runGC() - postGCTester.assertCleanup + postGCTester.assertCleanup() } def newRDD = sc.makeRDD(1 to 10) - def newPairRDD = newRDD.map(_ -> 1) - def newShuffleRDD = newPairRDD.reduceByKey(_ + _) - + def newBroadcast = sc.broadcast(1 to 100) def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _]]) = { def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = { rdd.dependencies ++ rdd.dependencies.flatMap { dep => @@ -149,11 +163,27 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo val rdd = newShuffleRDD // Get all the shuffle dependencies - val shuffleDeps = getAllDependencies(rdd).filter(_.isInstanceOf[ShuffleDependency[_, _]]) + val shuffleDeps = getAllDependencies(rdd) + .filter(_.isInstanceOf[ShuffleDependency[_, _]]) .map(_.asInstanceOf[ShuffleDependency[_, _]]) (rdd, shuffleDeps) } + def randomRdd = { + val rdd: RDD[_] = Random.nextInt(3) match { + case 0 => newRDD + case 1 => newShuffleRDD + case 2 => newPairRDD.join(newPairRDD) + } + if (Random.nextBoolean()) rdd.persist() + rdd.count() + rdd + } + + def randomBroadcast = { + sc.broadcast(Random.nextInt(Int.MaxValue)) + } + /** Run GC and make sure it actually has run */ def runGC() { val weakRef = new WeakReference(new Object()) @@ -208,7 +238,7 @@ class CleanerTester( sc.cleaner.attachListener(cleanerListener) /** Assert that all the stuff has been cleaned up */ - def assertCleanup(implicit waitTimeout: Eventually.Timeout) { + 
def assertCleanup()(implicit waitTimeout: Eventually.Timeout) { try { eventually(waitTimeout, interval(10 millis)) { assert(isAllCleanedUp) @@ -222,7 +252,7 @@ class CleanerTester( /** Verify that RDDs, shuffles, etc. occupy resources */ private def preCleanupValidate() { - assert(rddIds.nonEmpty || shuffleIds.nonEmpty, "Nothing to cleanup") + assert(rddIds.nonEmpty || shuffleIds.nonEmpty || broadcastIds.nonEmpty, "Nothing to cleanup") // Verify the RDDs have been persisted and blocks are present assert(rddIds.forall(sc.persistentRdds.contains), @@ -233,8 +263,12 @@ class CleanerTester( // Verify the shuffle ids are registered and blocks are present assert(shuffleIds.forall(mapOutputTrackerMaster.containsShuffle), "One or more shuffles have not been registered cannot start cleaner test") - assert(shuffleIds.forall(shuffleId => diskBlockManager.containsBlock(shuffleBlockId(shuffleId))), + assert(shuffleIds.forall(sid => diskBlockManager.containsBlock(shuffleBlockId(sid))), "One or more shuffles' blocks cannot be found in disk manager, cannot start cleaner test") + + // Verify that the broadcast is in the driver's block manager + assert(broadcastIds.forall(bid => blockManager.getLevel(broadcastBlockId(bid)).isDefined), + "One ore more broadcasts have not been persisted in the driver's block manager") } /** @@ -247,14 +281,19 @@ class CleanerTester( attempts += 1 logInfo("Attempt: " + attempts) try { - // Verify all the RDDs have been unpersisted + // Verify all RDDs have been unpersisted assert(rddIds.forall(!sc.persistentRdds.contains(_))) assert(rddIds.forall(rddId => !blockManager.master.contains(rddBlockId(rddId)))) - // Verify all the shuffle have been deregistered and cleaned up + // Verify all shuffles have been deregistered and cleaned up assert(shuffleIds.forall(!mapOutputTrackerMaster.containsShuffle(_))) - assert(shuffleIds.forall(shuffleId => - !diskBlockManager.containsBlock(shuffleBlockId(shuffleId)))) + assert(shuffleIds.forall(sid => !diskBlockManager.containsBlock(shuffleBlockId(sid)))) + + // Verify all broadcasts have been unpersisted + assert(broadcastIds.forall { bid => + blockManager.master.askForStorageLevels(broadcastBlockId(bid)).isEmpty + }) + return } catch { case t: Throwable => @@ -271,18 +310,20 @@ class CleanerTester( s""" |\tRDDs = ${toBeCleanedRDDIds.mkString("[", ", ", "]")} |\tShuffles = ${toBeCleanedShuffleIds.mkString("[", ", ", "]")} + |\tBroadcasts = ${toBeCleanedBroadcstIds.mkString("[", ", ", "]")} """.stripMargin } - private def isAllCleanedUp = toBeCleanedRDDIds.isEmpty && toBeCleanedShuffleIds.isEmpty - - private def shuffleBlockId(shuffleId: Int) = ShuffleBlockId(shuffleId, 0, 0) + private def isAllCleanedUp = + toBeCleanedRDDIds.isEmpty && + toBeCleanedShuffleIds.isEmpty && + toBeCleanedBroadcstIds.isEmpty private def rddBlockId(rddId: Int) = RDDBlockId(rddId, 0) + private def shuffleBlockId(shuffleId: Int) = ShuffleBlockId(shuffleId, 0, 0) + private def broadcastBlockId(broadcastId: Long) = BroadcastBlockId(broadcastId) private def blockManager = sc.env.blockManager - private def diskBlockManager = blockManager.diskBlockManager - private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala index e446c7f75dc0b..0b9847174ac84 100644 --- a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala @@ -33,11 +33,6 @@ class WrappedJavaHashMapSuite extends FunSuite { // Test a simple WrappedJavaHashMap testMap(new TestMap[String, String]()) - // Test BoundedHashMap - testMap(new BoundedHashMap[String, String](100, true)) - - testMapThreadSafety(new BoundedHashMap[String, String](100, true)) - // Test TimeStampedHashMap testMap(new TimeStampedHashMap[String, String]) From 0d170606469ad1d58f7743f9cd57247d45082fad Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 27 Mar 2014 19:07:55 -0700 Subject: [PATCH 08/14] Import, comments, and style fixes (minor) --- .../scala/org/apache/spark/MapOutputTracker.scala | 11 +++++------ .../scala/org/apache/spark/SparkContext.scala | 3 ++- .../main/scala/org/apache/spark/SparkEnv.scala | 1 + .../apache/spark/broadcast/BroadcastFactory.scala | 4 ++-- .../apache/spark/broadcast/HttpBroadcast.scala | 9 ++++----- .../apache/spark/broadcast/TorrentBroadcast.scala | 7 +++---- .../src/main/scala/org/apache/spark/rdd/RDD.scala | 3 ++- .../org/apache/spark/scheduler/DAGScheduler.scala | 1 - .../org/apache/spark/storage/BlockManager.scala | 4 ++-- .../apache/spark/storage/BlockManagerMaster.scala | 2 +- .../spark/storage/BlockManagerMasterActor.scala | 3 +-- .../spark/storage/BlockManagerMessages.scala | 4 ---- .../apache/spark/storage/DiskBlockManager.scala | 2 +- .../spark/storage/ShuffleBlockManager.scala | 4 ++-- .../org/apache/spark/storage/ThreadingTest.scala | 2 +- .../org/apache/spark/util/MetadataCleaner.scala | 15 ++++++++------- .../spark/util/TimeStampedWeakValueHashMap.scala | 8 ++++---- .../org/apache/spark/MapOutputTrackerSuite.scala | 2 +- .../apache/spark/storage/BlockManagerSuite.scala | 3 +-- .../spark/util/WrappedJavaHashMapSuite.scala | 2 +- 20 files changed, 42 insertions(+), 48 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index e1a273593cce5..c45c5c90048f3 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -112,8 +112,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } /** - * Called from executors to get the server URIs and - * output sizes of the map outputs of a given shuffle + * Called from executors to get the server URIs and output sizes of the map outputs of + * a given shuffle. */ def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { val statuses = mapStatuses.get(shuffleId).orNull @@ -218,10 +218,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) private var cacheEpoch = epoch /** - * Timestamp based HashMap for storing mapStatuses and cached serialized statuses - * in the master, so that statuses are dropped only by explicit deregistering or - * by TTL-based cleaning (if set). Other than these two - * scenarios, nothing should be dropped from this HashMap. + * Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the master, + * so that statuses are dropped only by explicit deregistering or by TTL-based cleaning (if set). + * Other than these two scenarios, nothing should be dropped from this HashMap. 
*/ protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]() private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]() diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index fe84b812ba8d0..79574c271cfb6 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -230,6 +230,7 @@ class SparkContext( private[spark] val cleaner = new ContextCleaner(this) cleaner.start() + postEnvironmentUpdate() /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ @@ -773,7 +774,7 @@ class SparkContext( * filesystems), an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node. */ def addJar(path: String) { - if (path == null) { + if (path == null) { logWarning("null specified as parameter to addJar") } else { var key = "" diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 521182021dd4b..62398dc930993 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -185,6 +185,7 @@ object SparkEnv extends Logging { } else { new MapOutputTrackerWorker(conf) } + // Have to assign trackerActor after initialization as MapOutputTrackerActor // requires the MapOutputTracker itself mapOutputTracker.trackerActor = registerOrLookup( diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 850650951e603..9ff1675e76a5e 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -27,8 +27,8 @@ import org.apache.spark.SparkConf * entire Spark job. 
*/ trait BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] def unbroadcast(id: Long, removeFromDriver: Boolean) - def stop(): Unit + def stop() } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index d5e3d60a5b2b7..d8981bb42e684 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -90,7 +90,7 @@ private[spark] object HttpBroadcast extends Logging { private var securityManager: SecurityManager = null // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist - val files = new TimeStampedHashSet[String] + private val files = new TimeStampedHashSet[String] private var cleaner: MetadataCleaner = null private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt @@ -195,7 +195,7 @@ private[spark] object HttpBroadcast extends Logging { def unpersist(id: Long, removeFromDriver: Boolean) = synchronized { SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) if (removeFromDriver) { - val file = new File(broadcastDir, BroadcastBlockId(id).name) + val file = getFile(id) files.remove(file.toString) deleteBroadcastFile(file) } @@ -217,10 +217,9 @@ private[spark] object HttpBroadcast extends Logging { } } - /** Delete the given broadcast file. */ private def deleteBroadcastFile(file: File) { try { - if (!file.exists()) { + if (!file.exists) { logWarning("Broadcast file to be deleted does not exist: %s".format(file)) } else if (file.delete()) { logInfo("Deleted broadcast file: %s".format(file)) @@ -229,7 +228,7 @@ private[spark] object HttpBroadcast extends Logging { } } catch { case e: Exception => - logWarning("Exception while deleting broadcast file: %s".format(file), e) + logError("Exception while deleting broadcast file: %s".format(file), e) } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index ace71575f5390..ab280fad4e28f 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -72,7 +72,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo } /** - * Remove all persisted state associated with this HTTP broadcast. + * Remove all persisted state associated with this Torrent broadcast. * @param removeFromDriver Whether to remove state from the driver. 
*/ override def unpersist(removeFromDriver: Boolean) { @@ -177,13 +177,12 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo } private[spark] object TorrentBroadcast extends Logging { + private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 private var initialized = false private var conf: SparkConf = null - lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 - def initialize(_isDriver: Boolean, conf: SparkConf) { - TorrentBroadcast.conf = conf //TODO: we might have to fix it in tests + TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests synchronized { if (!initialized) { initialized = true diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e5638d0132e88..e8d36e6bfc810 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -158,7 +158,7 @@ abstract class RDD[T: ClassTag]( */ def unpersist(blocking: Boolean = true): RDD[T] = { logInfo("Removing RDD " + id + " from persistence list") - sc.unpersistRDD(this.id, blocking) + sc.unpersistRDD(id, blocking) storageLevel = StorageLevel.NONE this } @@ -1128,4 +1128,5 @@ abstract class RDD[T: ClassTag]( def toJavaRDD() : JavaRDD[T] = { new JavaRDD(this)(elementClassTag) } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index edef40e7309f6..f31f0580c36fe 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1090,7 +1090,6 @@ class DAGScheduler( eventProcessActor ! StopDAGScheduler } taskScheduler.stop() - listenerBus.stop() } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 78dc32b4b1525..24ec8d3ab44bf 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -49,8 +49,8 @@ private[spark] class BlockManager( maxMemory: Long, val conf: SparkConf, securityManager: SecurityManager, - mapOutputTracker: MapOutputTracker - ) extends Logging { + mapOutputTracker: MapOutputTracker) + extends Logging { val shuffleBlockManager = new ShuffleBlockManager(this) val diskBlockManager = new DiskBlockManager(shuffleBlockManager, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 674322e3034c8..5c9ea88d6b1a4 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -82,7 +82,7 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log /** * Check if block manager master has a block. Note that this can be used to check for only - * those blocks that are expected to be reported to block manager master. + * those blocks that are reported to block manager master. 
*/ def contains(blockId: BlockId) = { !getLocations(blockId).isEmpty diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index f83c26dafe2e9..3271d4f1375ef 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -167,7 +167,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus * from the executors, but not from the driver. */ private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) { - // TODO(aor): Consolidate usages of + // TODO: Consolidate usages of val removeMsg = RemoveBroadcast(broadcastId) blockManagerInfo.values .filter { info => removeFromDriver || info.blockManagerId.executorId != "" } @@ -350,7 +350,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Note that this logic will select the same node multiple times if there aren't enough peers Array.tabulate[BlockManagerId](size) { i => peers((selfIndex + i + 1) % peers.length) }.toSeq } - } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 1d3e94c4b6533..9a29c39a28ab1 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -22,11 +22,9 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import akka.actor.ActorRef private[storage] object BlockManagerMessages { - ////////////////////////////////////////////////////////////////////////////////// // Messages from the master to slaves. ////////////////////////////////////////////////////////////////////////////////// - sealed trait ToBlockManagerSlave // Remove a block from the slaves that have it. This can only be used to remove @@ -50,7 +48,6 @@ private[storage] object BlockManagerMessages { ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. ////////////////////////////////////////////////////////////////////////////////// - sealed trait ToBlockManagerMaster case class RegisterBlockManager( @@ -122,5 +119,4 @@ private[storage] object BlockManagerMessages { // For testing. Have the master ask all slaves for the given block's storage level. case class AskForStorageLevels(blockId: BlockId) extends ToBlockManagerMaster - } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index a57e6f710305a..fcad84669c79a 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -90,7 +90,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD def getFile(blockId: BlockId): File = getFile(blockId.name) - /** Check if disk block manager has a block */ + /** Check if disk block manager has a block. 
*/ def containsBlock(blockId: BlockId): Boolean = { getBlockLocation(blockId).file.exists() } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index cf83a60ffb9e8..06233153c56d4 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -169,13 +169,13 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { throw new IllegalStateException("Failed to find shuffle block: " + id) } - /** Remove all the blocks / files and metadata related to a particular shuffle */ + /** Remove all the blocks / files and metadata related to a particular shuffle. */ def removeShuffle(shuffleId: ShuffleId) { removeShuffleBlocks(shuffleId) shuffleStates.remove(shuffleId) } - /** Remove all the blocks / files related to a particular shuffle */ + /** Remove all the blocks / files related to a particular shuffle. */ private def removeShuffleBlocks(shuffleId: ShuffleId) { shuffleStates.get(shuffleId) match { case Some(state) => diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index 7b75215846a9a..a107c5182b3be 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -48,7 +48,7 @@ private[spark] object ThreadingTest { val block = (1 to blockSize).map(_ => Random.nextInt()) val level = randomLevel() val startTime = System.currentTimeMillis() - manager.put(blockId, block.iterator, level, true) + manager.put(blockId, block.iterator, level, tellMaster = true) println("Pushed block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms") queue.add((blockId, block)) } diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 2ef853710a554..7ebed5105b9fd 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -78,15 +78,16 @@ private[spark] object MetadataCleaner { conf.getInt("spark.cleaner.ttl", -1) } - def getDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType): Int = - { - conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString) - .toInt + def getDelaySeconds( + conf: SparkConf, + cleanerType: MetadataCleanerType.MetadataCleanerType): Int = { + conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString).toInt } - def setDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType, - delay: Int) - { + def setDelaySeconds( + conf: SparkConf, + cleanerType: MetadataCleanerType.MetadataCleanerType, + delay: Int) { conf.set(MetadataCleanerType.systemProperty(cleanerType), delay.toString) } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala index 09a6faf33ec60..9f3247a27ba38 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala @@ -17,14 +17,14 @@ package org.apache.spark.util -import scala.collection.{JavaConversions, immutable} - -import java.util import 
java.lang.ref.WeakReference +import java.util import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.JavaConversions import org.apache.spark.Logging -import java.util.concurrent.atomic.AtomicInteger private[util] case class TimeStampedWeakValue[T](timestamp: Long, weakValue: WeakReference[T]) { def this(timestamp: Long, value: T) = this(timestamp, new WeakReference[T](value)) diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index b83033c35f6b7..6b2571cd9295e 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -96,7 +96,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { assert(tracker.getServerStatuses(10, 0).isEmpty) } - test("master register shuffle and unregister mapoutput and fetch") { + test("master register shuffle and unregister map output and fetch") { val actorSystem = ActorSystem("test") val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 04e64ee7a45b3..1f5bcca64fc39 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -28,8 +28,7 @@ import org.scalatest.concurrent.Timeouts._ import org.scalatest.matchers.ShouldMatchers._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf, SparkContext} -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} diff --git a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala index 0b9847174ac84..f6e6a4c77c820 100644 --- a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.util -import java.util import java.lang.ref.WeakReference +import java.util import scala.collection.mutable.{ArrayBuffer, HashMap, Map} import scala.util.Random From 34f436f7d1799a6fd22b745d339734f220108dae Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 28 Mar 2014 13:17:20 -0700 Subject: [PATCH 09/14] Generalize BroadcastBlockId to remove BroadcastHelperBlockId Rather than having a special purpose BroadcastHelperBlockId just for TorrentBroadcast, we now have a single BroadcastBlockId that has a possibly empty field. This simplifies broadcast clean-up because now we only have to look for one type of block. This commit also simplifies BlockId JSON de/serialization in general by parsing the name through regex with apply. 
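For illustration, here is a minimal sketch of the naming scheme this change relies on. The object and method names below are illustrative stand-ins, not the actual classes touched by this patch: a single broadcast block name with an optional field suffix, plus a regex-based parse that recovers the broadcast id, so cleanup only ever has to match one kind of block.

// Illustrative sketch only; a simplified stand-in for the real BlockId hierarchy.
object BroadcastNaming {
  // "broadcast_7" for the value itself, "broadcast_7_meta" / "broadcast_7_piece0" for helpers.
  def name(broadcastId: Long, field: String = ""): String =
    "broadcast_" + broadcastId + (if (field.isEmpty) "" else "_" + field)

  private val Broadcast = "broadcast_([0-9]+)".r
  private val BroadcastField = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r

  // Mirror of the regex-with-apply parsing described above: recover (id, field) from a name.
  def parse(s: String): Option[(Long, String)] = s match {
    case Broadcast(id)             => Some((id.toLong, ""))
    case BroadcastField(id, field) => Some((id.toLong, field))
    case _                         => None
  }
}

// e.g. BroadcastNaming.parse("broadcast_7_piece0") == Some((7L, "piece0")), so removing
// broadcast 7 reduces to "drop every block whose parsed id is 7".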
--- .../spark/broadcast/TorrentBroadcast.scala | 10 +-- .../org/apache/spark/storage/BlockId.scala | 29 ++++--- .../apache/spark/storage/BlockManager.scala | 3 +- .../org/apache/spark/util/JsonProtocol.scala | 77 +------------------ .../org/apache/spark/BroadcastSuite.scala | 61 +++++++-------- .../apache/spark/util/JsonProtocolSuite.scala | 14 ---- 6 files changed, 54 insertions(+), 140 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index ab280fad4e28f..dbe65d88104fb 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -23,7 +23,7 @@ import scala.math import scala.util.Random import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} -import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel} +import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.Utils private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) @@ -54,7 +54,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo hasBlocks = tInfo.totalBlocks // Store meta-info - val metaId = BroadcastHelperBlockId(broadcastId, "meta") + val metaId = BroadcastBlockId(id, "meta") val metaInfo = TorrentInfo(null, totalBlocks, totalBytes) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( @@ -63,7 +63,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo // Store individual pieces for (i <- 0 until totalBlocks) { - val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i) + val pieceId = BroadcastBlockId(id, "piece" + i) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true) @@ -131,7 +131,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo def receiveBroadcast(): Boolean = { // Receive meta-info - val metaId = BroadcastHelperBlockId(broadcastId, "meta") + val metaId = BroadcastBlockId(id, "meta") var attemptId = 10 while (attemptId > 0 && totalBlocks == -1) { TorrentBroadcast.synchronized { @@ -156,7 +156,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo // Receive actual blocks val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList) for (pid <- recvOrder) { - val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid) + val pieceId = BroadcastBlockId(id, "piece" + pid) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.getSingle(pieceId) match { case Some(x) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 301d784b350a3..27e271368ed06 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -34,7 +34,7 @@ private[spark] sealed abstract class BlockId { def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None def isRDD = isInstanceOf[RDDBlockId] def isShuffle = isInstanceOf[ShuffleBlockId] - def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId] + def isBroadcast = isInstanceOf[BroadcastBlockId] override def toString = name override def hashCode = name.hashCode @@ -48,18 +48,15 @@ private[spark] case 
class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockI def name = "rdd_" + rddId + "_" + splitIndex } -private[spark] -case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { +private[spark] case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) + extends BlockId { def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId } +// Leave field as an instance variable to avoid matching on it private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId { - def name = "broadcast_" + broadcastId -} - -private[spark] -case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId { - def name = broadcastId.name + "_" + hType + var field = "" + def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) } private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId { @@ -80,11 +77,19 @@ private[spark] case class TestBlockId(id: String) extends BlockId { def name = "test_" + id } +private[spark] object BroadcastBlockId { + def apply(broadcastId: Long, field: String) = { + val blockId = new BroadcastBlockId(broadcastId) + blockId.field = field + blockId + } +} + private[spark] object BlockId { val RDD = "rdd_([0-9]+)_([0-9]+)".r val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r val BROADCAST = "broadcast_([0-9]+)".r - val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r + val BROADCAST_FIELD = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r val TEST = "test_(.*)".r @@ -97,8 +102,8 @@ private[spark] object BlockId { ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) case BROADCAST(broadcastId) => BroadcastBlockId(broadcastId.toLong) - case BROADCAST_HELPER(broadcastId, hType) => - BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType) + case BROADCAST_FIELD(broadcastId, field) => + BroadcastBlockId(broadcastId.toLong, field) case TASKRESULT(taskId) => TaskResultBlockId(taskId.toLong) case STREAM(streamId, uniqueId) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 24ec8d3ab44bf..a88eb1315a37b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -827,9 +827,8 @@ private[spark] class BlockManager( */ def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) { logInfo("Removing broadcast " + broadcastId) - val blocksToRemove = blockInfo.keys.filter(_.isBroadcast).collect { + val blocksToRemove = blockInfo.keys.collect { case bid: BroadcastBlockId if bid.broadcastId == broadcastId => bid - case bid: BroadcastHelperBlockId if bid.broadcastId.broadcastId == broadcastId => bid } blocksToRemove.foreach { blockId => removeBlock(blockId, removeFromDriver) } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 346f2b7856791..d9a6af61872d1 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -195,7 +195,7 @@ private[spark] object JsonProtocol { taskMetrics.shuffleWriteMetrics.map(shuffleWriteMetricsToJson).getOrElse(JNothing) val updatedBlocks = taskMetrics.updatedBlocks.map { blocks => JArray(blocks.toList.map { case (id, status) => - ("Block ID" -> blockIdToJson(id)) ~ + ("Block ID" -> 
id.toString) ~ ("Status" -> blockStatusToJson(status)) }) }.getOrElse(JNothing) @@ -284,35 +284,6 @@ private[spark] object JsonProtocol { ("Replication" -> storageLevel.replication) } - def blockIdToJson(blockId: BlockId): JValue = { - val blockType = Utils.getFormattedClassName(blockId) - val json: JObject = blockId match { - case rddBlockId: RDDBlockId => - ("RDD ID" -> rddBlockId.rddId) ~ - ("Split Index" -> rddBlockId.splitIndex) - case shuffleBlockId: ShuffleBlockId => - ("Shuffle ID" -> shuffleBlockId.shuffleId) ~ - ("Map ID" -> shuffleBlockId.mapId) ~ - ("Reduce ID" -> shuffleBlockId.reduceId) - case broadcastBlockId: BroadcastBlockId => - "Broadcast ID" -> broadcastBlockId.broadcastId - case broadcastHelperBlockId: BroadcastHelperBlockId => - ("Broadcast Block ID" -> blockIdToJson(broadcastHelperBlockId.broadcastId)) ~ - ("Helper Type" -> broadcastHelperBlockId.hType) - case taskResultBlockId: TaskResultBlockId => - "Task ID" -> taskResultBlockId.taskId - case streamBlockId: StreamBlockId => - ("Stream ID" -> streamBlockId.streamId) ~ - ("Unique ID" -> streamBlockId.uniqueId) - case tempBlockId: TempBlockId => - val uuid = UUIDToJson(tempBlockId.id) - "Temp ID" -> uuid - case testBlockId: TestBlockId => - "Test ID" -> testBlockId.id - } - ("Type" -> blockType) ~ json - } - def blockStatusToJson(blockStatus: BlockStatus): JValue = { val storageLevel = storageLevelToJson(blockStatus.storageLevel) ("Storage Level" -> storageLevel) ~ @@ -513,7 +484,7 @@ private[spark] object JsonProtocol { Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson) metrics.updatedBlocks = Utils.jsonOption(json \ "Updated Blocks").map { value => value.extract[List[JValue]].map { block => - val id = blockIdFromJson(block \ "Block ID") + val id = BlockId((block \ "Block ID").extract[String]) val status = blockStatusFromJson(block \ "Status") (id, status) } @@ -616,50 +587,6 @@ private[spark] object JsonProtocol { StorageLevel(useDisk, useMemory, deserialized, replication) } - def blockIdFromJson(json: JValue): BlockId = { - val rddBlockId = Utils.getFormattedClassName(RDDBlockId) - val shuffleBlockId = Utils.getFormattedClassName(ShuffleBlockId) - val broadcastBlockId = Utils.getFormattedClassName(BroadcastBlockId) - val broadcastHelperBlockId = Utils.getFormattedClassName(BroadcastHelperBlockId) - val taskResultBlockId = Utils.getFormattedClassName(TaskResultBlockId) - val streamBlockId = Utils.getFormattedClassName(StreamBlockId) - val tempBlockId = Utils.getFormattedClassName(TempBlockId) - val testBlockId = Utils.getFormattedClassName(TestBlockId) - - (json \ "Type").extract[String] match { - case `rddBlockId` => - val rddId = (json \ "RDD ID").extract[Int] - val splitIndex = (json \ "Split Index").extract[Int] - new RDDBlockId(rddId, splitIndex) - case `shuffleBlockId` => - val shuffleId = (json \ "Shuffle ID").extract[Int] - val mapId = (json \ "Map ID").extract[Int] - val reduceId = (json \ "Reduce ID").extract[Int] - new ShuffleBlockId(shuffleId, mapId, reduceId) - case `broadcastBlockId` => - val broadcastId = (json \ "Broadcast ID").extract[Long] - new BroadcastBlockId(broadcastId) - case `broadcastHelperBlockId` => - val broadcastBlockId = - blockIdFromJson(json \ "Broadcast Block ID").asInstanceOf[BroadcastBlockId] - val hType = (json \ "Helper Type").extract[String] - new BroadcastHelperBlockId(broadcastBlockId, hType) - case `taskResultBlockId` => - val taskId = (json \ "Task ID").extract[Long] - new TaskResultBlockId(taskId) - case `streamBlockId` => - val streamId = 
(json \ "Stream ID").extract[Int] - val uniqueId = (json \ "Unique ID").extract[Long] - new StreamBlockId(streamId, uniqueId) - case `tempBlockId` => - val tempId = UUIDFromJson(json \ "Temp ID") - new TempBlockId(tempId) - case `testBlockId` => - val testId = (json \ "Test ID").extract[String] - new TestBlockId(testId) - } - } - def blockStatusFromJson(json: JValue): BlockStatus = { val storageLevel = storageLevelFromJson(json \ "Storage Level") val memorySize = (json \ "Memory Size").extract[Long] diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index a462654197ea0..9e600f1e91aa2 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -21,7 +21,7 @@ import org.scalatest.FunSuite import org.apache.spark.storage._ import org.apache.spark.broadcast.HttpBroadcast -import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId} +import org.apache.spark.storage.BroadcastBlockId class BroadcastSuite extends FunSuite with LocalSparkContext { @@ -102,23 +102,22 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { * are present only on the expected nodes. */ private def testUnpersistHttpBroadcast(numSlaves: Int, removeFromDriver: Boolean) { - def getBlockIds(id: Long) = Seq[BlockId](BroadcastBlockId(id)) + def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id)) // Verify that the broadcast file is created, and blocks are persisted only on the driver - def afterCreation(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { assert(blockIds.size === 1) - val broadcastBlockId = blockIds.head.asInstanceOf[BroadcastBlockId] - val levels = bmm.askForStorageLevels(broadcastBlockId, waitTimeMs = 0) + val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0) assert(levels.size === 1) levels.head match { case (bm, level) => assert(bm.executorId === "") assert(level === StorageLevel.MEMORY_AND_DISK) } - assert(HttpBroadcast.getFile(broadcastBlockId.broadcastId).exists) + assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists) } // Verify that blocks are persisted in both the executors and the driver - def afterUsingBroadcast(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { assert(blockIds.size === 1) val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0) assert(levels.size === numSlaves + 1) @@ -129,12 +128,11 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver // is true. In the latter case, also verify that the broadcast file is deleted on the driver. 
- def afterUnpersist(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { assert(blockIds.size === 1) - val broadcastBlockId = blockIds.head.asInstanceOf[BroadcastBlockId] - val levels = bmm.askForStorageLevels(broadcastBlockId, waitTimeMs = 0) + val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0) assert(levels.size === (if (removeFromDriver) 0 else 1)) - assert(removeFromDriver === !HttpBroadcast.getFile(broadcastBlockId.broadcastId).exists) + assert(removeFromDriver === !HttpBroadcast.getFile(blockIds.head.broadcastId).exists) } testUnpersistBroadcast(numSlaves, httpConf, getBlockIds, afterCreation, @@ -151,14 +149,14 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { private def testUnpersistTorrentBroadcast(numSlaves: Int, removeFromDriver: Boolean) { def getBlockIds(id: Long) = { val broadcastBlockId = BroadcastBlockId(id) - val metaBlockId = BroadcastHelperBlockId(broadcastBlockId, "meta") + val metaBlockId = BroadcastBlockId(id, "meta") // Assume broadcast value is small enough to fit into 1 piece - val pieceBlockId = BroadcastHelperBlockId(broadcastBlockId, "piece0") - Seq[BlockId](broadcastBlockId, metaBlockId, pieceBlockId) + val pieceBlockId = BroadcastBlockId(id, "piece0") + Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId) } // Verify that blocks are persisted only on the driver - def afterCreation(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { blockIds.foreach { blockId => val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0) assert(levels.size === 1) @@ -170,27 +168,26 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { } // Verify that blocks are persisted in both the executors and the driver - def afterUsingBroadcast(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { blockIds.foreach { blockId => val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0) - blockId match { - case BroadcastHelperBlockId(_, "meta") => - // Meta data is only on the driver - assert(levels.size === 1) - levels.head match { case (bm, _) => assert(bm.executorId === "") } - case _ => - // Other blocks are on both the executors and the driver - assert(levels.size === numSlaves + 1) - levels.foreach { case (_, level) => - assert(level === StorageLevel.MEMORY_AND_DISK) - } + if (blockId.field == "meta") { + // Meta data is only on the driver + assert(levels.size === 1) + levels.head match { case (bm, _) => assert(bm.executorId === "") } + } else { + // Other blocks are on both the executors and the driver + assert(levels.size === numSlaves + 1) + levels.foreach { case (_, level) => + assert(level === StorageLevel.MEMORY_AND_DISK) + } } } } // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver // is true. 
- def afterUnpersist(blockIds: Seq[BlockId], bmm: BlockManagerMaster) { + def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { val expectedNumBlocks = if (removeFromDriver) 0 else 1 var waitTimeMs = 1000L blockIds.foreach { blockId => @@ -217,10 +214,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { private def testUnpersistBroadcast( numSlaves: Int, broadcastConf: SparkConf, - getBlockIds: Long => Seq[BlockId], - afterCreation: (Seq[BlockId], BlockManagerMaster) => Unit, - afterUsingBroadcast: (Seq[BlockId], BlockManagerMaster) => Unit, - afterUnpersist: (Seq[BlockId], BlockManagerMaster) => Unit, + getBlockIds: Long => Seq[BroadcastBlockId], + afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, + afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, + afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, removeFromDriver: Boolean) { sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 67c0a434c9b52..580ac34f5f0b4 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -104,15 +104,6 @@ class JsonProtocolSuite extends FunSuite { testTaskEndReason(TaskKilled) testTaskEndReason(ExecutorLostFailure) testTaskEndReason(UnknownReason) - - // BlockId - testBlockId(RDDBlockId(1, 2)) - testBlockId(ShuffleBlockId(1, 2, 3)) - testBlockId(BroadcastBlockId(1L)) - testBlockId(BroadcastHelperBlockId(BroadcastBlockId(2L), "Spark")) - testBlockId(TaskResultBlockId(1L)) - testBlockId(StreamBlockId(1, 2L)) - testBlockId(TempBlockId(UUID.randomUUID())) } @@ -167,11 +158,6 @@ class JsonProtocolSuite extends FunSuite { assertEquals(reason, newReason) } - private def testBlockId(blockId: BlockId) { - val newBlockId = JsonProtocol.blockIdFromJson(JsonProtocol.blockIdToJson(blockId)) - blockId == newBlockId - } - /** -------------------------------- * | Util methods for comparing events | From fbfeec80cfb7a1bd86847fa22f641d9b9ad7480f Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 28 Mar 2014 18:33:11 -0700 Subject: [PATCH 10/14] Add functionality to query executors for their local BlockStatuses Not all blocks are reported to the master. In HttpBroadcast and TorrentBroadcast, for instance, most blocks are not reported to master. The lack of a mechanism to get local block statuses on each executor makes it difficult to test the correctness of un/persisting a broadcast. This new functionality, though only used for testing at the moment, is general enough to be used for other things in the future. 
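As a rough usage sketch of the query being added (hedged: the package placement, object and method names here are assumptions for illustration and are not part of this patch; only BlockManagerMaster.getBlockStatus and BlockStatus come from the diff below):

package org.apache.spark

import org.apache.spark.storage.{BlockManagerId, BlockStatus, BroadcastBlockId}

// Illustrative helper: report where a broadcast piece currently lives, per block manager.
// Placed in the org.apache.spark package because SparkContext.env is package-private.
object BlockStatusProbe {
  def broadcastPieceLocations(sc: SparkContext, broadcastId: Long): Unit = {
    val blockId = BroadcastBlockId(broadcastId, "piece0")
    // askSlaves = true makes the master query each executor directly, so blocks stored
    // with tellMaster = false (as HttpBroadcast and TorrentBroadcast mostly do) still appear.
    val statuses: Map[BlockManagerId, BlockStatus] =
      sc.env.blockManager.master.getBlockStatus(blockId, askSlaves = true)
    statuses.foreach { case (bmId, status) =>
      val where = if (bmId.executorId == "") "driver" else "executor " + bmId.executorId
      println("%s is on %s at level %s".format(blockId, where, status.storageLevel))
    }
  }
}

In a test this replaces the old askForStorageLevels pattern: rather than sleeping and asking for storage levels, the caller gets the full BlockStatus map and can also assert on memory and disk sizes.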
--- .../spark/network/ConnectionManager.scala | 1 - .../org/apache/spark/storage/BlockInfo.scala | 2 + .../apache/spark/storage/BlockManager.scala | 15 ++-- .../spark/storage/BlockManagerMaster.scala | 33 +++++---- .../storage/BlockManagerMasterActor.scala | 47 ++++++++----- .../spark/storage/BlockManagerMessages.scala | 11 ++- .../storage/BlockManagerSlaveActor.scala | 4 +- .../org/apache/spark/BroadcastSuite.scala | 69 +++++++++++-------- .../apache/spark/ContextCleanerSuite.scala | 4 +- .../spark/storage/BlockManagerSuite.scala | 40 +++++++++++ 10 files changed, 150 insertions(+), 76 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index a75130cba2a2e..bb3abf1d032d1 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -17,7 +17,6 @@ package org.apache.spark.network -import java.net._ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala index c8f397609a0b4..ef924123a3b11 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala @@ -79,3 +79,5 @@ private object BlockInfo { private val BLOCK_PENDING: Long = -1L private val BLOCK_FAILED: Long = -2L } + +private[spark] case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index a88eb1315a37b..dd2dbd1c8a397 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -209,10 +209,14 @@ private[spark] class BlockManager( } } - /** - * Get storage level of local block. If no info exists for the block, return None. - */ - def getLevel(blockId: BlockId): Option[StorageLevel] = blockInfo.get(blockId).map(_.level) + /** Return the status of the block identified by the given ID, if it exists. */ + def getStatus(blockId: BlockId): Option[BlockStatus] = { + blockInfo.get(blockId).map { info => + val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L + val diskSize = if (diskStore.contains(blockId)) diskStore.getSize(blockId) else 0L + BlockStatus(info.level, memSize, diskSize) + } + } /** * Tell the master about the current storage status of a block. 
This will send a block update @@ -631,10 +635,9 @@ private[spark] class BlockManager( diskStore.putValues(blockId, iterator, level, askForBytes) case ArrayBufferValues(array) => diskStore.putValues(blockId, array, level, askForBytes) - case ByteBufferValues(bytes) => { + case ByteBufferValues(bytes) => bytes.rewind() diskStore.putBytes(blockId, bytes, level) - } } size = res.size res.data match { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 5c9ea88d6b1a4..f61aa1d6bc0fc 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -148,21 +148,30 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log } /** - * Mainly for testing. Ask the driver to query all executors for their storage levels - * regarding this block. This provides an avenue for the driver to learn the storage - * levels of blocks it has not been informed of. + * Return the block's local status on all block managers, if any. * - * WARNING: This could lead to deadlocks if there are any outstanding messages the - * executors are already expecting from the driver. In this case, while the driver is - * waiting for the executors to respond to its GetStorageLevel query, the executors - * are also waiting for a response from the driver to a prior message. + * If askSlaves is true, this invokes the master to query each block manager for the most + * updated block statuses. This is useful when the master is not informed of the given block + * by all block managers. * - * The interim solution is to wait for a brief window of time to pass before asking. - * This should suffice, since this mechanism is largely introduced for testing only. + * To avoid potential deadlocks, the use of Futures is necessary, because the master actor + * should not block on waiting for a block manager, which can in turn be waiting for the + * master actor for a response to a prior message. 
*/ - def askForStorageLevels(blockId: BlockId, waitTimeMs: Long = 1000) = { - Thread.sleep(waitTimeMs) - askDriverWithReply[Map[BlockManagerId, StorageLevel]](AskForStorageLevels(blockId)) + def getBlockStatus( + blockId: BlockId, + askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = { + val msg = GetBlockStatus(blockId, askSlaves) + val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) + val (blockManagerIds, futures) = response.unzip + val result = Await.result(Future.sequence(futures), timeout) + if (result == null) { + throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId) + } + val blockStatus = result.asInstanceOf[Iterable[Option[BlockStatus]]] + blockManagerIds.zip(blockStatus).flatMap { case (blockManagerId, status) => + status.map { s => (blockManagerId, s) } + }.toMap } /** Stop the driver actor, called only on the Spark driver node */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 3271d4f1375ef..2d9445425b879 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -21,7 +21,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.mutable import scala.collection.JavaConversions._ -import scala.concurrent.{Await, Future} +import scala.concurrent.Future import scala.concurrent.duration._ import akka.actor.{Actor, ActorRef, Cancellable} @@ -93,6 +93,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case GetStorageStatus => sender ! storageStatus + case GetBlockStatus(blockId, askSlaves) => + sender ! blockStatus(blockId, askSlaves) + case RemoveRdd(rddId) => sender ! removeRdd(rddId) @@ -126,9 +129,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case HeartBeat(blockManagerId) => sender ! heartBeat(blockManagerId) - case AskForStorageLevels(blockId) => - sender ! askForStorageLevels(blockId) - case other => logWarning("Got unknown message: " + other) } @@ -254,16 +254,30 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus }.toArray } - // For testing. Ask all block managers for the given block's local storage level, if any. - private def askForStorageLevels(blockId: BlockId): Map[BlockManagerId, StorageLevel] = { - val getStorageLevel = GetStorageLevel(blockId) - blockManagerInfo.values.flatMap { info => - val future = info.slaveActor.ask(getStorageLevel)(akkaTimeout) - val result = Await.result(future, akkaTimeout) - if (result != null) { - // If the block does not exist on the slave, the slave replies None - result.asInstanceOf[Option[StorageLevel]].map { reply => (info.blockManagerId, reply) } - } else None + /** + * Return the block's local status for all block managers, if any. + * + * If askSlaves is true, the master queries each block manager for the most updated block + * statuses. This is useful when the master is not informed of the given block by all block + * managers. + * + * Rather than blocking on the block status query, master actor should simply return a + * Future to avoid potential deadlocks. This can arise if there exists a block manager + * that is also waiting for this master actor's response to a previous message. 
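An illustrative sketch (not part of the patch) of the deadlock-avoidance pattern used by getBlockStatus above: the master hands back one future per block manager and the caller, not the master actor, does the blocking. The names collectStatuses, statusFutures and atMost are invented for this example.

import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.FiniteDuration

// Combine one Future[Option[S]] per block manager into a single Map,
// keeping only the managers that actually reported a status.
def collectStatuses[Id, S](
    statusFutures: Map[Id, Future[Option[S]]],
    atMost: FiniteDuration): Map[Id, S] = {
  val (ids, futures) = statusFutures.unzip
  // Blocking happens here, in the caller, never inside the master actor.
  val results = Await.result(Future.sequence(futures), atMost)
  ids.zip(results).collect { case (id, Some(s)) => (id, s) }.toMap
}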
+ */ + private def blockStatus( + blockId: BlockId, + askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = { + import context.dispatcher + val getBlockStatus = GetBlockStatus(blockId) + blockManagerInfo.values.map { info => + val blockStatusFuture = + if (askSlaves) { + info.slaveActor.ask(getBlockStatus)(akkaTimeout).mapTo[Option[BlockStatus]] + } else { + Future { info.getStatus(blockId) } + } + (info.blockManagerId, blockStatusFuture) }.toMap } @@ -352,9 +366,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus } } - -private[spark] case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) - private[spark] class BlockManagerInfo( val blockManagerId: BlockManagerId, timeMs: Long, @@ -371,6 +382,8 @@ private[spark] class BlockManagerInfo( logInfo("Registering block manager %s with %s RAM".format( blockManagerId.hostPort, Utils.bytesToString(maxMem))) + def getStatus(blockId: BlockId) = Option(_blocks.get(blockId)) + def updateLastSeenMs() { _lastSeenMs = System.currentTimeMillis() } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 9a29c39a28ab1..afb2c6a12ce67 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -41,9 +41,6 @@ private[storage] object BlockManagerMessages { case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) extends ToBlockManagerSlave - // For testing. Ask the slave for the block's storage level. - case class GetStorageLevel(blockId: BlockId) extends ToBlockManagerSlave - ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. @@ -113,10 +110,10 @@ private[storage] object BlockManagerMessages { case object GetMemoryStatus extends ToBlockManagerMaster - case object ExpireDeadHosts extends ToBlockManagerMaster - case object GetStorageStatus extends ToBlockManagerMaster - // For testing. Have the master ask all slaves for the given block's storage level. - case class AskForStorageLevels(blockId: BlockId) extends ToBlockManagerMaster + case class GetBlockStatus(blockId: BlockId, askSlaves: Boolean = true) + extends ToBlockManagerMaster + + case object ExpireDeadHosts extends ToBlockManagerMaster } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 85b8ec40c0ea3..016ade428c68f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -50,7 +50,7 @@ class BlockManagerSlaveActor( case RemoveBroadcast(broadcastId, removeFromDriver) => blockManager.removeBroadcast(broadcastId, removeFromDriver) - case GetStorageLevel(blockId) => - sender ! blockManager.getLevel(blockId) + case GetBlockStatus(blockId, _) => + sender ! 
blockManager.getStatus(blockId) } } diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index 9e600f1e91aa2..d28496e316a34 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -107,22 +107,26 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // Verify that the broadcast file is created, and blocks are persisted only on the driver def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { assert(blockIds.size === 1) - val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0) - assert(levels.size === 1) - levels.head match { case (bm, level) => - assert(bm.executorId === "") - assert(level === StorageLevel.MEMORY_AND_DISK) + val statuses = bmm.getBlockStatus(blockIds.head) + assert(statuses.size === 1) + statuses.head match { case (bm, status) => + assert(bm.executorId === "", "Block should only be on the driver") + assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) + assert(status.memSize > 0, "Block should be in memory store on the driver") + assert(status.diskSize === 0, "Block should not be in disk store on the driver") } - assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists) + assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!") } // Verify that blocks are persisted in both the executors and the driver def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { assert(blockIds.size === 1) - val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0) - assert(levels.size === numSlaves + 1) - levels.foreach { case (_, level) => - assert(level === StorageLevel.MEMORY_AND_DISK) + val statuses = bmm.getBlockStatus(blockIds.head) + assert(statuses.size === numSlaves + 1) + statuses.foreach { case (_, status) => + assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) + assert(status.memSize > 0, "Block should be in memory store") + assert(status.diskSize === 0, "Block should not be in disk store") } } @@ -130,9 +134,13 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // is true. In the latter case, also verify that the broadcast file is deleted on the driver. 
def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { assert(blockIds.size === 1) - val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0) - assert(levels.size === (if (removeFromDriver) 0 else 1)) - assert(removeFromDriver === !HttpBroadcast.getFile(blockIds.head.broadcastId).exists) + val statuses = bmm.getBlockStatus(blockIds.head) + val expectedNumBlocks = if (removeFromDriver) 0 else 1 + val possiblyNot = if (removeFromDriver) "" else " not" + assert(statuses.size === expectedNumBlocks, + "Block should%s be unpersisted on the driver".format(possiblyNot)) + assert(removeFromDriver === !HttpBroadcast.getFile(blockIds.head.broadcastId).exists, + "Broadcast file should%s be deleted".format(possiblyNot)) } testUnpersistBroadcast(numSlaves, httpConf, getBlockIds, afterCreation, @@ -158,11 +166,13 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // Verify that blocks are persisted only on the driver def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { blockIds.foreach { blockId => - val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0) - assert(levels.size === 1) - levels.head match { case (bm, level) => - assert(bm.executorId === "") - assert(level === StorageLevel.MEMORY_AND_DISK) + val statuses = bmm.getBlockStatus(blockIds.head) + assert(statuses.size === 1) + statuses.head match { case (bm, status) => + assert(bm.executorId === "", "Block should only be on the driver") + assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) + assert(status.memSize > 0, "Block should be in memory store on the driver") + assert(status.diskSize === 0, "Block should not be in disk store on the driver") } } } @@ -170,16 +180,18 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // Verify that blocks are persisted in both the executors and the driver def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { blockIds.foreach { blockId => - val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0) + val statuses = bmm.getBlockStatus(blockId) if (blockId.field == "meta") { // Meta data is only on the driver - assert(levels.size === 1) - levels.head match { case (bm, _) => assert(bm.executorId === "") } + assert(statuses.size === 1) + statuses.head match { case (bm, _) => assert(bm.executorId === "") } } else { // Other blocks are on both the executors and the driver - assert(levels.size === numSlaves + 1) - levels.foreach { case (_, level) => - assert(level === StorageLevel.MEMORY_AND_DISK) + assert(statuses.size === numSlaves + 1) + statuses.foreach { case (_, status) => + assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) + assert(status.memSize > 0, "Block should be in memory store") + assert(status.diskSize === 0, "Block should not be in disk store") } } } @@ -189,12 +201,11 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { // is true. 
def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { val expectedNumBlocks = if (removeFromDriver) 0 else 1 - var waitTimeMs = 1000L + val possiblyNot = if (removeFromDriver) "" else " not" blockIds.foreach { blockId => - // Allow a second for the messages triggered by unpersist to propagate to prevent deadlocks - val levels = bmm.askForStorageLevels(blockId, waitTimeMs) - assert(levels.size === expectedNumBlocks) - waitTimeMs = 0L + val statuses = bmm.getBlockStatus(blockId) + assert(statuses.size === expectedNumBlocks, + "Block should%s be unpersisted on the driver".format(possiblyNot)) } } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 6a12cb6603700..3d95547b20fc1 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -267,7 +267,7 @@ class CleanerTester( "One or more shuffles' blocks cannot be found in disk manager, cannot start cleaner test") // Verify that the broadcast is in the driver's block manager - assert(broadcastIds.forall(bid => blockManager.getLevel(broadcastBlockId(bid)).isDefined), + assert(broadcastIds.forall(bid => blockManager.getStatus(broadcastBlockId(bid)).isDefined), "One ore more broadcasts have not been persisted in the driver's block manager") } @@ -291,7 +291,7 @@ class CleanerTester( // Verify all broadcasts have been unpersisted assert(broadcastIds.forall { bid => - blockManager.master.askForStorageLevels(broadcastBlockId(bid)).isEmpty + blockManager.master.getBlockStatus(broadcastBlockId(bid)).isEmpty }) return diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 1f5bcca64fc39..bddbd381c2665 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -745,6 +745,46 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(!store.get("list5").isDefined, "list5 was in store") } + test("query block statuses") { + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) + val list = List.fill(2)(new Array[Byte](200)) + + // Tell master. By LRU, only list2 and list3 remains. + store.put("list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.put("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.put("list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + + // getLocations and getBlockStatus should yield the same locations + assert(store.master.getLocations("list1").size === 0) + assert(store.master.getLocations("list2").size === 1) + assert(store.master.getLocations("list3").size === 1) + assert(store.master.getBlockStatus("list1", askSlaves = false).size === 0) + assert(store.master.getBlockStatus("list2", askSlaves = false).size === 1) + assert(store.master.getBlockStatus("list3", askSlaves = false).size === 1) + assert(store.master.getBlockStatus("list1", askSlaves = true).size === 0) + assert(store.master.getBlockStatus("list2", askSlaves = true).size === 1) + assert(store.master.getBlockStatus("list3", askSlaves = true).size === 1) + + // This time don't tell master and see what happens. By LRU, only list5 and list6 remains. 
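An illustrative sketch (not part of the patch) of how the getBlockStatus API exercised by this test can be driven from driver-side code; the printBlockStatuses helper and its arguments are invented for the example.

import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster, BlockStatus}

// Report where a block currently lives and how much memory/disk it occupies,
// querying every block manager rather than only what the master already knows.
def printBlockStatuses(master: BlockManagerMaster, blockId: BlockId): Unit = {
  val statuses: Map[BlockManagerId, BlockStatus] = master.getBlockStatus(blockId, askSlaves = true)
  statuses.foreach { case (bmId, status) =>
    println(s"$bmId: level=${status.storageLevel}, mem=${status.memSize}B, disk=${status.diskSize}B")
  }
}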
+ store.put("list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) + store.put("list5", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + store.put("list6", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) + + // getLocations should return nothing because the master is not informed + // getBlockStatus without asking slaves should have the same result + // getBlockStatus with asking slaves, however, should present the actual block statuses + assert(store.master.getLocations("list4").size === 0) + assert(store.master.getLocations("list5").size === 0) + assert(store.master.getLocations("list6").size === 0) + assert(store.master.getBlockStatus("list4", askSlaves = false).size === 0) + assert(store.master.getBlockStatus("list5", askSlaves = false).size === 0) + assert(store.master.getBlockStatus("list6", askSlaves = false).size === 0) + assert(store.master.getBlockStatus("list4", askSlaves = true).size === 0) + assert(store.master.getBlockStatus("list5", askSlaves = true).size === 1) + assert(store.master.getBlockStatus("list6", askSlaves = true).size === 1) + } + test("SPARK-1194 regression: fix the same-RDD rule for cache replacement") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr, mapOutputTracker) From 88904a3659fe4a81bdfb2a6b615894d926af3fe1 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 28 Mar 2014 23:02:11 -0700 Subject: [PATCH 11/14] Make TimeStampedWeakValueHashMap a wrapper of TimeStampedHashMap This allows us to get rid of WrappedJavaHashMap without much duplicate code. --- .../scala/org/apache/spark/SparkContext.scala | 1 - .../apache/spark/storage/BlockManager.scala | 7 +- .../spark/util/TimeStampedHashMap.scala | 117 +++++++--- .../util/TimeStampedWeakValueHashMap.scala | 164 +++++++------- .../spark/util/WrappedJavaHashMap.scala | 152 ------------- .../spark/util/WrappedJavaHashMapSuite.scala | 206 ------------------ 6 files changed, 168 insertions(+), 479 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala delete mode 100644 core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 79574c271cfb6..13fba1e0dfe5d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -35,7 +35,6 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.mesos.MesosNativeLibrary -import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index dd2dbd1c8a397..991881b00c0eb 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -209,7 +209,7 @@ private[spark] class BlockManager( } } - /** Return the status of the block identified by the given ID, if it exists. 
*/ + /** Get the BlockStatus for the block identified by the given ID, if it exists.*/ def getStatus(blockId: BlockId): Option[BlockStatus] = { blockInfo.get(blockId).map { info => val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L @@ -635,9 +635,10 @@ private[spark] class BlockManager( diskStore.putValues(blockId, iterator, level, askForBytes) case ArrayBufferValues(array) => diskStore.putValues(blockId, array, level, askForBytes) - case ByteBufferValues(bytes) => + case ByteBufferValues(bytes) => { bytes.rewind() diskStore.putBytes(blockId, bytes, level) + } } size = res.size res.data match { @@ -872,7 +873,7 @@ private[spark] class BlockManager( } private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)) { - val iterator = blockInfo.internalMap.entrySet().iterator() + val iterator = blockInfo.getEntrySet.iterator while (iterator.hasNext) { val entry = iterator.next() val (id, info, time) = (entry.getKey, entry.getValue.value, entry.getValue.timestamp) diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index c4d770fecdf74..1721818c212f9 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -17,64 +17,108 @@ package org.apache.spark.util +import java.util.Set +import java.util.Map.Entry import java.util.concurrent.ConcurrentHashMap +import scala.collection.{immutable, JavaConversions, mutable} + import org.apache.spark.Logging -private[util] case class TimeStampedValue[T](timestamp: Long, value: T) +private[spark] case class TimeStampedValue[V](value: V, timestamp: Long) /** - * A map that stores the timestamp of when a key was inserted along with the value. If specified, - * the timestamp of each pair can be updated every time it is accessed. - * Key-value pairs whose timestamps are older than a particular - * threshold time can then be removed using the clearOldValues method. It exposes a - * scala.collection.mutable.Map interface to allow it to be a drop-in replacement for Scala - * HashMaps. - * - * Internally, it uses a Java ConcurrentHashMap, so all operations on this HashMap are thread-safe. + * This is a custom implementation of scala.collection.mutable.Map which stores the insertion + * timestamp along with each key-value pair. If specified, the timestamp of each pair can be + * updated every time it is accessed. Key-value pairs whose timestamp are older than a particular + * threshold time can then be removed using the clearOldValues method. This is intended to + * be a drop-in replacement of scala.collection.mutable.HashMap. 
* - * @param updateTimeStampOnGet When enabled, the timestamp of a pair will be - * updated when it is accessed + * @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed */ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false) - extends WrappedJavaHashMap[A, B, A, TimeStampedValue[B]] with Logging { + extends mutable.Map[A, B]() with Logging { - private[util] val internalJavaMap = new ConcurrentHashMap[A, TimeStampedValue[B]]() + private val internalMap = new ConcurrentHashMap[A, TimeStampedValue[B]]() - private[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { - new TimeStampedHashMap[K1, V1]() + def get(key: A): Option[B] = { + val value = internalMap.get(key) + if (value != null && updateTimeStampOnGet) { + internalMap.replace(key, value, TimeStampedValue(value.value, currentTime)) + } + Option(value).map(_.value) } - def internalMap = internalJavaMap + def iterator: Iterator[(A, B)] = { + val jIterator = getEntrySet.iterator() + JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue.value)) + } - override def get(key: A): Option[B] = { - val timeStampedValue = internalMap.get(key) - if (updateTimeStampOnGet && timeStampedValue != null) { - internalJavaMap.replace(key, timeStampedValue, - TimeStampedValue(currentTime, timeStampedValue.value)) - } - Option(timeStampedValue).map(_.value) + def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet() + + override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = { + val newMap = new TimeStampedHashMap[A, B1] + val oldInternalMap = this.internalMap.asInstanceOf[ConcurrentHashMap[A, TimeStampedValue[B1]]] + newMap.internalMap.putAll(oldInternalMap) + kv match { case (a, b) => newMap.internalMap.put(a, TimeStampedValue(b, currentTime)) } + newMap } - @inline override protected def externalValueToInternalValue(v: B): TimeStampedValue[B] = { - new TimeStampedValue(currentTime, v) + + override def - (key: A): mutable.Map[A, B] = { + val newMap = new TimeStampedHashMap[A, B] + newMap.internalMap.putAll(this.internalMap) + newMap.internalMap.remove(key) + newMap + } + + override def += (kv: (A, B)): this.type = { + kv match { case (a, b) => internalMap.put(a, TimeStampedValue(b, currentTime)) } + this + } + + override def -= (key: A): this.type = { + internalMap.remove(key) + this + } + + override def update(key: A, value: B) { + this += ((key, value)) } - @inline override protected def internalValueToExternalValue(iv: TimeStampedValue[B]): B = { - iv.value + override def apply(key: A): B = { + val value = internalMap.get(key) + Option(value).map(_.value).getOrElse { throw new NoSuchElementException() } } - /** Atomically put if a key is absent. This exposes the existing API of ConcurrentHashMap. */ + override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = { + JavaConversions.mapAsScalaConcurrentMap(internalMap) + .map { case (k, TimeStampedValue(v, t)) => (k, v) } + .filter(p) + } + + override def empty: mutable.Map[A, B] = new TimeStampedHashMap[A, B]() + + override def size: Int = internalMap.size + + override def foreach[U](f: ((A, B)) => U) { + val iterator = getEntrySet.iterator() + while(iterator.hasNext) { + val entry = iterator.next() + val kv = (entry.getKey, entry.getValue.value) + f(kv) + } + } + + // Should we return previous value directly or as Option? 
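A minimal usage sketch (not part of the patch) for the rewritten TimeStampedHashMap, assuming string keys and values; the keys and the one-hour threshold are invented for the example.

// Track entries and later drop everything older than a threshold.
val metadata = new TimeStampedHashMap[String, String](updateTimeStampOnGet = true)
metadata("block_1") = "cached"             // timestamp recorded on insert
metadata.putIfAbsent("block_2", "cached")  // atomic insert-if-missing

val threshTime = System.currentTimeMillis() - 60 * 60 * 1000  // one hour ago
// Remove stale pairs, logging each key before it is dropped.
metadata.clearOldValues(threshTime, (k: String, v: String) => println("dropping " + k))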
def putIfAbsent(key: A, value: B): Option[B] = { - val prev = internalJavaMap.putIfAbsent(key, TimeStampedValue(currentTime, value)) + val prev = internalMap.putIfAbsent(key, TimeStampedValue(value, currentTime)) Option(prev).map(_.value) } - /** - * Removes old key-value pairs that have timestamp earlier than `threshTime`, - * calling the supplied function on each such entry before removing. - */ + def toMap: immutable.Map[A, B] = iterator.toMap + def clearOldValues(threshTime: Long, f: (A, B) => Unit) { - val iterator = internalJavaMap.entrySet().iterator() + val iterator = getEntrySet.iterator() while (iterator.hasNext) { val entry = iterator.next() if (entry.getValue.timestamp < threshTime) { @@ -86,11 +130,12 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa } /** - * Removes old key-value pairs that have timestamp earlier than `threshTime` + * Removes old key-value pairs that have timestamp earlier than `threshTime`. */ def clearOldValues(threshTime: Long) { clearOldValues(threshTime, (_, _) => ()) } - private def currentTime: Long = System.currentTimeMillis() + private def currentTime: Long = System.currentTimeMillis + } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala index 9f3247a27ba38..f814f58261bf3 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala @@ -18,113 +18,115 @@ package org.apache.spark.util import java.lang.ref.WeakReference -import java.util -import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.atomic.AtomicInteger -import scala.collection.JavaConversions - -import org.apache.spark.Logging - -private[util] case class TimeStampedWeakValue[T](timestamp: Long, weakValue: WeakReference[T]) { - def this(timestamp: Long, value: T) = this(timestamp, new WeakReference[T](value)) -} +import scala.collection.{immutable, mutable} /** - * A map that stores the timestamp of when a key was inserted along with the value, - * while ensuring that the values are weakly referenced. If the value is garbage collected and - * the weak reference is null, get() operation returns the key be non-existent. However, - * the key is actually not removed in the current implementation. Key-value pairs whose - * timestamps are older than a particular threshold time can then be removed using the - * clearOldValues method. It exposes a scala.collection.mutable.Map interface to allow it to be a - * drop-in replacement for Scala HashMaps. + * A wrapper of TimeStampedHashMap that ensures the values are weakly referenced and timestamped. + * + * If the value is garbage collected and the weak reference is null, get() operation returns + * a non-existent value. However, the corresponding key is actually not removed in the current + * implementation. Key-value pairs whose timestamps are older than a particular threshold time + * can then be removed using the clearOldValues method. It exposes a scala.collection.mutable.Map + * interface to allow it to be a drop-in replacement for Scala HashMaps. * * Internally, it uses a Java ConcurrentHashMap, so all operations on this HashMap are thread-safe. + * + * @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed. 
*/ +private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boolean = false) + extends mutable.Map[A, B]() { + + import TimeStampedWeakValueHashMap._ -private[spark] class TimeStampedWeakValueHashMap[A, B]() - extends WrappedJavaHashMap[A, B, A, TimeStampedWeakValue[B]] with Logging { + private val internalMap = new TimeStampedHashMap[A, WeakReference[B]](updateTimeStampOnGet) - /** Number of inserts after which keys whose weak ref values are null will be cleaned */ - private val CLEANUP_INTERVAL = 1000 + def get(key: A): Option[B] = internalMap.get(key) - /** Counter for counting the number of inserts */ - private val insertCounts = new AtomicInteger(0) + def iterator: Iterator[(A, B)] = internalMap.iterator + + override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = { + val newMap = new TimeStampedWeakValueHashMap[A, B1] + newMap.internalMap += kv + newMap + } - private[util] val internalJavaMap: util.Map[A, TimeStampedWeakValue[B]] = { - new ConcurrentHashMap[A, TimeStampedWeakValue[B]]() + override def - (key: A): mutable.Map[A, B] = { + val newMap = new TimeStampedWeakValueHashMap[A, B] + newMap.internalMap -= key + newMap } - private[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { - new TimeStampedWeakValueHashMap[K1, V1]() + override def += (kv: (A, B)): this.type = { + internalMap += kv + this } - override def +=(kv: (A, B)): this.type = { - // Cleanup null value at certain intervals - if (insertCounts.incrementAndGet() % CLEANUP_INTERVAL == 0) { - cleanNullValues() - } - super.+=(kv) + override def -= (key: A): this.type = { + internalMap -= key + this } - override def get(key: A): Option[B] = { - Option(internalJavaMap.get(key)).flatMap { weakValue => - val value = weakValue.weakValue.get - if (value == null) { - internalJavaMap.remove(key) - } - Option(value) - } + override def update(key: A, value: B) = this += ((key, value)) + + override def apply(key: A): B = internalMap.apply(key) + + override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = internalMap.filter(p) + + override def empty: mutable.Map[A, B] = new TimeStampedWeakValueHashMap[A, B]() + + override def size: Int = internalMap.size + + override def foreach[U](f: ((A, B)) => U) = internalMap.foreach(f) + + def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value) + + def toMap: immutable.Map[A, B] = iterator.toMap + + /** + * Remove old key-value pairs that have timestamp earlier than `threshTime`. + */ + def clearOldValues(threshTime: Long) = internalMap.clearOldValues(threshTime) + +} + +/** + * Helper methods for converting to and from WeakReferences. 
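A rough sketch (not part of the patch) of the wrapper in use, relying on the implicit WeakReference conversions declared in the companion object below; the cache name and keys are invented, and the garbage-collection remark describes best-effort JVM behaviour only.

val cache = new TimeStampedWeakValueHashMap[String, Array[Byte]]()
cache("k1") = new Array[Byte](1024)             // stored internally as WeakReference(value)
val hit: Option[Array[Byte]] = cache.get("k1")  // unwrapped on the way out
// Once no strong reference to the array remains, a GC may clear the weak reference;
// the stale key itself is only removed when clearOldValues prunes old entries.
cache.clearOldValues(System.currentTimeMillis() - 60 * 60 * 1000)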
+ */ +private[spark] object TimeStampedWeakValueHashMap { + + /* Implicit conversion methods to WeakReferences */ + + implicit def toWeakReference[V](v: V): WeakReference[V] = new WeakReference[V](v) + + implicit def toWeakReferenceTuple[K, V](kv: (K, V)): (K, WeakReference[V]) = { + kv match { case (k, v) => (k, toWeakReference(v)) } } - @inline override protected def externalValueToInternalValue(v: B): TimeStampedWeakValue[B] = { - new TimeStampedWeakValue(currentTime, v) + implicit def toWeakReferenceFunction[K, V, R](p: ((K, V)) => R): ((K, WeakReference[V])) => R = { + (kv: (K, WeakReference[V])) => p(kv) } - @inline override protected def internalValueToExternalValue(iv: TimeStampedWeakValue[B]): B = { - iv.weakValue.get + /* Implicit conversion methods from WeakReferences */ + + implicit def fromWeakReference[V](ref: WeakReference[V]): V = ref.get + + implicit def fromWeakReferenceOption[V](v: Option[WeakReference[V]]): Option[V] = { + v.map(fromWeakReference) } - override def iterator: Iterator[(A, B)] = { - val iterator = internalJavaMap.entrySet().iterator() - JavaConversions.asScalaIterator(iterator).flatMap(kv => { - val (key, value) = (kv.getKey, kv.getValue.weakValue.get) - if (value != null) Seq((key, value)) else Seq.empty - }) + implicit def fromWeakReferenceTuple[K, V](kv: (K, WeakReference[V])): (K, V) = { + kv match { case (k, v) => (k, fromWeakReference(v)) } } - /** - * Removes old key-value pairs that have timestamp earlier than `threshTime`, - * calling the supplied function on each such entry before removing. - */ - def clearOldValues(threshTime: Long, f: (A, B) => Unit = null) { - val iterator = internalJavaMap.entrySet().iterator() - while (iterator.hasNext) { - val entry = iterator.next() - if (entry.getValue.timestamp < threshTime) { - val value = entry.getValue.weakValue.get - if (f != null && value != null) { - f(entry.getKey, value) - } - logDebug("Removing key " + entry.getKey) - iterator.remove() - } - } + implicit def fromWeakReferenceIterator[K, V]( + it: Iterator[(K, WeakReference[V])]): Iterator[(K, V)] = { + it.map(fromWeakReferenceTuple) } - /** - * Removes keys whose weak referenced values have become null. - */ - private def cleanNullValues() { - val iterator = internalJavaMap.entrySet().iterator() - while (iterator.hasNext) { - val entry = iterator.next() - if (entry.getValue.weakValue.get == null) { - logDebug("Removing key " + entry.getKey) - iterator.remove() - } - } + implicit def fromWeakReferenceMap[K, V]( + map: mutable.Map[K, WeakReference[V]]) : mutable.Map[K, V] = { + mutable.Map(map.mapValues(fromWeakReference).toSeq: _*) } - private def currentTime = System.currentTimeMillis() } diff --git a/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala b/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala deleted file mode 100644 index 6cc3007f5d7ac..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import scala.collection.mutable.Map -import java.util.{Map => JMap} -import java.util.Map.{Entry => JMapEntry} -import scala.collection.{immutable, JavaConversions} -import scala.reflect.ClassTag - -/** - * Convenient wrapper class for exposing Java HashMaps as Scala Maps even if the - * exposed key-value type is different from the internal type. This allows these - * implementations of WrappedJavaHashMap to be drop-in replacements for Scala HashMaps. - * - * While Java <-> Scala conversion methods exists, its hard to understand the performance - * implications and thread safety of the Scala wrapper. This class allows you to convert - * between types and applying the necessary overridden methods to take care of performance. - * - * Note that the threading behavior of an implementation of WrappedJavaHashMap is tied to that of - * the internal Java HashMap used in the implementation. Each implementation must use - * necessary traits (e.g, scala.collection.mutable.SynchronizedMap), etc. to achieve the - * desired thread safety. - * - * @tparam K External key type - * @tparam V External value type - * @tparam IK Internal key type - * @tparam IV Internal value type - */ -private[spark] abstract class WrappedJavaHashMap[K, V, IK, IV] extends Map[K, V] { - - /* Methods that must be defined. */ - - /** - * Internal Java HashMap that is being wrapped. - * Scoped private[util] so that rest of Spark code cannot - * directly access the internal map. - */ - private[util] val internalJavaMap: JMap[IK, IV] - - /** Method to get a new instance of the internal Java HashMap. */ - private[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] - - /* - Methods that convert between internal and external types. These implementations - optimistically assume that the internal types are same as external types. These must - be overridden if the internal and external types are different. Otherwise there will be - runtime exceptions. - */ - - @inline protected def externalKeyToInternalKey(k: K): IK = { - k.asInstanceOf[IK] // works only if K is same or subclass of K - } - - @inline protected def externalValueToInternalValue(v: V): IV = { - v.asInstanceOf[IV] // works only if V is same or subclass of - } - - @inline protected def internalKeyToExternalKey(ik: IK): K = { - ik.asInstanceOf[K] - } - - @inline protected def internalValueToExternalValue(iv: IV): V = { - iv.asInstanceOf[V] - } - - @inline protected def internalPairToExternalPair(ip: JMapEntry[IK, IV]): (K, V) = { - (internalKeyToExternalKey(ip.getKey), internalValueToExternalValue(ip.getValue) ) - } - - /* Implicit methods to convert the types. 
*/ - - @inline implicit private def convExtKeyToIntKey(k: K) = externalKeyToInternalKey(k) - - @inline implicit private def convExtValueToIntValue(v: V) = externalValueToInternalValue(v) - - @inline implicit private def convIntKeyToExtKey(ia: IK) = internalKeyToExternalKey(ia) - - @inline implicit private def convIntValueToExtValue(ib: IV) = internalValueToExternalValue(ib) - - @inline implicit private def convIntPairToExtPair(ip: JMapEntry[IK, IV]) = { - internalPairToExternalPair(ip) - } - - /* Methods that must be implemented for a scala.collection.mutable.Map */ - - def get(key: K): Option[V] = { - Option(internalJavaMap.get(key)) - } - - def iterator: Iterator[(K, V)] = { - val jIterator = internalJavaMap.entrySet().iterator() - JavaConversions.asScalaIterator(jIterator).map(kv => convIntPairToExtPair(kv)) - } - - /* Other methods that are implemented to ensure performance. */ - - def +=(kv: (K, V)): this.type = { - internalJavaMap.put(kv._1, kv._2) - this - } - - def -=(key: K): this.type = { - internalJavaMap.remove(key) - this - } - - override def + [V1 >: V](kv: (K, V1)): Map[K, V1] = { - val newMap = newInstance[K, V1]() - newMap.internalJavaMap.asInstanceOf[JMap[IK, IV]].putAll(this.internalJavaMap) - newMap += kv - newMap - } - - override def - (key: K): Map[K, V] = { - val newMap = newInstance[K, V]() - newMap.internalJavaMap.asInstanceOf[JMap[IK, IV]].putAll(this.internalJavaMap) - newMap -= key - } - - override def foreach[U](f: ((K, V)) => U) { - val jIterator = internalJavaMap.entrySet().iterator() - while(jIterator.hasNext) { - f(jIterator.next()) - } - } - - override def empty: Map[K, V] = newInstance[K, V]() - - override def size: Int = internalJavaMap.size - - override def filter(p: ((K, V)) => Boolean): Map[K, V] = { - newInstance[K, V]() ++= iterator.filter(p) - } - - def toMap: immutable.Map[K, V] = iterator.toMap -} diff --git a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala deleted file mode 100644 index f6e6a4c77c820..0000000000000 --- a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala +++ /dev/null @@ -1,206 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.util - -import java.lang.ref.WeakReference -import java.util - -import scala.collection.mutable.{ArrayBuffer, HashMap, Map} -import scala.util.Random - -import org.scalatest.FunSuite - -class WrappedJavaHashMapSuite extends FunSuite { - - // Test the testMap function - a Scala HashMap should obviously pass - testMap(new HashMap[String, String]()) - - // Test a simple WrappedJavaHashMap - testMap(new TestMap[String, String]()) - - // Test TimeStampedHashMap - testMap(new TimeStampedHashMap[String, String]) - - testMapThreadSafety(new TimeStampedHashMap[String, String]) - - test("TimeStampedHashMap - clearing by timestamp") { - // clearing by insertion time - val map = new TimeStampedHashMap[String, String](false) - map("k1") = "v1" - assert(map("k1") === "v1") - Thread.sleep(10) - val threshTime = System.currentTimeMillis() - assert(map.internalMap.get("k1").timestamp < threshTime) - map.clearOldValues(threshTime) - assert(map.get("k1") === None) - - // clearing by modification time - val map1 = new TimeStampedHashMap[String, String](true) - map1("k1") = "v1" - map1("k2") = "v2" - assert(map1("k1") === "v1") - Thread.sleep(10) - val threshTime1 = System.currentTimeMillis() - Thread.sleep(10) - assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime - assert(map1.internalMap.get("k1").timestamp < threshTime1) - assert(map1.internalMap.get("k2").timestamp >= threshTime1) - map1.clearOldValues(threshTime1) //should only clear k1 - assert(map1.get("k1") === None) - assert(map1.get("k2").isDefined) - } - - // Test TimeStampedHashMap - testMap(new TimeStampedWeakValueHashMap[String, String]) - - testMapThreadSafety(new TimeStampedWeakValueHashMap[String, String]) - - test("TimeStampedWeakValueHashMap - clearing by timestamp") { - // clearing by insertion time - val map = new TimeStampedWeakValueHashMap[String, String]() - map("k1") = "v1" - assert(map("k1") === "v1") - Thread.sleep(10) - val threshTime = System.currentTimeMillis() - assert(map.internalJavaMap.get("k1").timestamp < threshTime) - map.clearOldValues(threshTime) - assert(map.get("k1") === None) - } - - - test("TimeStampedWeakValueHashMap - get not returning null when weak reference is cleared") { - var strongRef = new Object - val weakRef = new WeakReference(strongRef) - val map = new TimeStampedWeakValueHashMap[String, Object] - - map("k1") = strongRef - assert(map("k1") === strongRef) - - strongRef = null - val startTime = System.currentTimeMillis - System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. - System.runFinalization() // Make a best effort to call finalizer on all cleaned objects. 
- while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { - System.gc() - System.runFinalization() - Thread.sleep(100) - } - assert(map.internalJavaMap.get("k1").weakValue.get == null) - assert(map.get("k1") === None) - - // TODO (TD): Test clearing of null-value pairs - } - - def testMap(hashMapConstructor: => Map[String, String]) { - def newMap() = hashMapConstructor - - val name = newMap().getClass.getSimpleName - - test(name + " - basic test") { - val testMap1 = newMap() - - // put and get - testMap1 += (("k1", "v1")) - assert(testMap1.get("k1").get === "v1") - testMap1("k2") = "v2" - assert(testMap1.get("k2").get === "v2") - assert(testMap1("k2") === "v2") - - // remove - testMap1.remove("k1") - assert(testMap1.get("k1").isEmpty) - testMap1.remove("k2") - intercept[Exception] { - testMap1("k2") // Map.apply() causes exception - } - - // multi put - val keys = (1 to 100).map(_.toString) - val pairs = keys.map(x => (x, x * 2)) - val testMap2 = newMap() - assert((testMap2 ++ pairs).iterator.toSet === pairs.toSet) - testMap2 ++= pairs - - // iterator - assert(testMap2.iterator.toSet === pairs.toSet) - testMap2("k1") = "v1" - - // foreach - val buffer = new ArrayBuffer[(String, String)] - testMap2.foreach(x => buffer += x) - assert(testMap2.toSet === buffer.toSet) - - // multi remove - testMap2 --= keys - assert(testMap2.size === 1) - assert(testMap2.iterator.toSeq.head === ("k1", "v1")) - } - } - - def testMapThreadSafety(hashMapConstructor: => Map[String, String]) { - def newMap() = hashMapConstructor - - val name = newMap().getClass.getSimpleName - val testMap = newMap() - @volatile var error = false - - def getRandomKey(m: Map[String, String]): Option[String] = { - val keys = testMap.keysIterator.toSeq - if (keys.nonEmpty) { - Some(keys(Random.nextInt(keys.size))) - } else { - None - } - } - - val threads = (1 to 100).map(i => new Thread() { - override def run() { - try { - for (j <- 1 to 1000) { - Random.nextInt(3) match { - case 0 => - testMap(Random.nextString(10)) = Random.nextDouble.toString // put - case 1 => - getRandomKey(testMap).map(testMap.get) // get - case 2 => - getRandomKey(testMap).map(testMap.remove) // remove - } - } - } catch { - case t : Throwable => - error = true - throw t - } - } - }) - - test(name + " - threading safety test") { - threads.map(_.start) - threads.map(_.join) - assert(!error) - } - } -} - -class TestMap[A, B] extends WrappedJavaHashMap[A, B, A, B] { - private[util] val internalJavaMap: util.Map[A, B] = new util.HashMap[A, B]() - - private[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { - new TestMap[K1, V1] - } -} From 7ed72fbbef4be653bce83ce75ad9929d29b36fcf Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 31 Mar 2014 14:24:18 -0700 Subject: [PATCH 12/14] Fix style test fail + remove verbose test message regarding broadcast --- .../org/apache/spark/broadcast/HttpBroadcast.scala | 12 ++++++------ .../org/apache/spark/storage/BlockManager.scala | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index d8981bb42e684..79216bd2b8404 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -219,12 +219,12 @@ private[spark] object HttpBroadcast extends Logging { private def deleteBroadcastFile(file: File) { try { - if (!file.exists) { - logWarning("Broadcast 
file to be deleted does not exist: %s".format(file)) - } else if (file.delete()) { - logInfo("Deleted broadcast file: %s".format(file)) - } else { - logWarning("Could not delete broadcast file: %s".format(file)) + if (file.exists) { + if (file.delete()) { + logInfo("Deleted broadcast file: %s".format(file)) + } else { + logWarning("Could not delete broadcast file: %s".format(file)) + } } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 991881b00c0eb..c90abb187bdb1 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -209,7 +209,7 @@ private[spark] class BlockManager( } } - /** Get the BlockStatus for the block identified by the given ID, if it exists.*/ + /** Get the BlockStatus for the block identified by the given ID, if it exists. */ def getStatus(blockId: BlockId): Option[BlockStatus] = { blockInfo.get(blockId).map { info => val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L From 5016375fb32c0de8df0529467a3c5a57fe73a18f Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 1 Apr 2014 13:16:11 -0700 Subject: [PATCH 13/14] Address TD's comments --- .../org/apache/spark/ContextCleaner.scala | 7 +- .../apache/spark/broadcast/Broadcast.scala | 31 +++++--- .../spark/broadcast/BroadcastManager.scala | 3 +- .../spark/broadcast/HttpBroadcast.scala | 25 ++++--- .../spark/broadcast/TorrentBroadcast.scala | 35 ++++++---- .../main/scala/org/apache/spark/rdd/RDD.scala | 1 - .../org/apache/spark/storage/BlockId.scala | 21 ++---- .../apache/spark/storage/BlockManager.scala | 4 +- .../spark/storage/BlockManagerMaster.scala | 23 +++--- .../storage/BlockManagerMasterActor.scala | 11 +-- .../org/apache/spark/util/JsonProtocol.scala | 70 ++++++++++++++++++- .../org/apache/spark/BroadcastSuite.scala | 17 +++-- .../spark/storage/BlockManagerSuite.scala | 2 +- .../apache/spark/util/JsonProtocolSuite.scala | 15 +++- 14 files changed, 181 insertions(+), 84 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index f856a13f84dec..b71b7fa517fd2 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -169,18 +169,17 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { // Used for testing - private[spark] def cleanupRDD(rdd: RDD[_]) { + def cleanupRDD(rdd: RDD[_]) { doCleanupRDD(rdd.id) } - private[spark] def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) { + def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) { doCleanupShuffle(shuffleDependency.shuffleId) } - private[spark] def cleanupBroadcast[T](broadcast: Broadcast[T]) { + def cleanupBroadcast[T](broadcast: Broadcast[T]) { doCleanupBroadcast(broadcast.id) } - } private object ContextCleaner { diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index 3a2fef05861e6..81e0e5297683b 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -19,6 +19,8 @@ package org.apache.spark.broadcast import java.io.Serializable +import org.apache.spark.SparkException + /** * A broadcast variable. 
Broadcast variables allow the programmer to keep a read-only variable * cached on each machine rather than shipping a copy of it with tasks. They can be used, for @@ -49,25 +51,36 @@ import java.io.Serializable */ abstract class Broadcast[T](val id: Long) extends Serializable { + protected var _isValid: Boolean = true + /** * Whether this Broadcast is actually usable. This should be false once persisted state is * removed from the driver. */ - protected var isValid: Boolean = true + def isValid: Boolean = _isValid def value: T /** - * Remove all persisted state associated with this broadcast. Overriding implementations - * should set isValid to false if persisted state is also removed from the driver. - * - * @param removeFromDriver Whether to remove state from the driver. - * If true, the resulting broadcast should no longer be valid. + * Remove all persisted state associated with this broadcast on the executors. The next use + * of this broadcast on the executors will trigger a remote fetch. */ - def unpersist(removeFromDriver: Boolean) + def unpersist() - // We cannot define abstract readObject and writeObject here due to some weird issues - // with these methods having to be 'private' in sub-classes. + /** + * Remove all persisted state associated with this broadcast on both the executors and the + * driver. Overriding implementations should set isValid to false. + */ + private[spark] def destroy() + + /** + * If this broadcast is no longer valid, throw an exception. + */ + protected def assertValid() { + if (!_isValid) { + throw new SparkException("Attempted to use %s when is no longer valid!".format(toString)) + } + } override def toString = "Broadcast(" + id + ")" } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index 85d62aae03959..c3ea16ff9eb5e 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -25,7 +25,7 @@ private[spark] class BroadcastManager( val isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager) - extends Logging with Serializable { + extends Logging { private var initialized = false private var broadcastFactory: BroadcastFactory = null @@ -63,5 +63,4 @@ private[spark] class BroadcastManager( def unbroadcast(id: Long, removeFromDriver: Boolean) { broadcastFactory.unbroadcast(id, removeFromDriver) } - } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 79216bd2b8404..ec5acf5f23f5f 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -31,7 +31,10 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { - override def value = value_ + def value: T = { + assertValid() + value_ + } val blockId = BroadcastBlockId(id) @@ -45,17 +48,24 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea } /** - * Remove all persisted state associated with this HTTP broadcast. - * @param removeFromDriver Whether to remove state from the driver. + * Remove all persisted state associated with this HTTP broadcast on the executors. 
+ */ + def unpersist() { + HttpBroadcast.unpersist(id, removeFromDriver = false) + } + + /** + * Remove all persisted state associated with this HTTP Broadcast on both the executors + * and the driver. */ - override def unpersist(removeFromDriver: Boolean) { - isValid = !removeFromDriver - HttpBroadcast.unpersist(id, removeFromDriver) + private[spark] def destroy() { + _isValid = false + HttpBroadcast.unpersist(id, removeFromDriver = true) } // Used by the JVM when serializing this object private def writeObject(out: ObjectOutputStream) { - assert(isValid, "Attempted to serialize a broadcast variable that has been destroyed!") + assertValid() out.defaultWriteObject() } @@ -231,5 +241,4 @@ private[spark] object HttpBroadcast extends Logging { logError("Exception while deleting broadcast file: %s".format(file), e) } } - } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index dbe65d88104fb..590caa9699dd3 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -29,7 +29,10 @@ import org.apache.spark.util.Utils private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { - override def value = value_ + def value = { + assertValid() + value_ + } val broadcastId = BroadcastBlockId(id) @@ -47,7 +50,23 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo sendBroadcast() } - def sendBroadcast() { + /** + * Remove all persisted state associated with this Torrent broadcast on the executors. + */ + def unpersist() { + TorrentBroadcast.unpersist(id, removeFromDriver = false) + } + + /** + * Remove all persisted state associated with this Torrent broadcast on both the executors + * and the driver. + */ + private[spark] def destroy() { + _isValid = false + TorrentBroadcast.unpersist(id, removeFromDriver = true) + } + + private def sendBroadcast() { val tInfo = TorrentBroadcast.blockifyObject(value_) totalBlocks = tInfo.totalBlocks totalBytes = tInfo.totalBytes @@ -71,18 +90,9 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo } } - /** - * Remove all persisted state associated with this Torrent broadcast. - * @param removeFromDriver Whether to remove state from the driver. 
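An illustrative sketch (not part of the patch) of how the user-facing API reads after this change: unpersist() frees only executor-side copies, leaving the driver copy available for re-fetch, while destroy() is private[spark] and therefore not part of the public API in this patch. The sc and rdd values below are assumed to exist.

val lookup = sc.broadcast(Map(1 -> "a", 2 -> "b"))
val tagged = rdd.map(x => lookup.value.getOrElse(x, "unknown"))
tagged.count()

// Drop the executor-side blocks; the next job using `lookup` re-fetches them
// from the driver because the driver-side copy is untouched.
lookup.unpersist()
tagged.count()   // still works, at the cost of one more fetch per executor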
- */ - override def unpersist(removeFromDriver: Boolean) { - isValid = !removeFromDriver - TorrentBroadcast.unpersist(id, removeFromDriver) - } - // Used by the JVM when serializing this object private def writeObject(out: ObjectOutputStream) { - assert(isValid, "Attempted to serialize a broadcast variable that has been destroyed!") + assertValid() out.defaultWriteObject() } @@ -240,7 +250,6 @@ private[spark] object TorrentBroadcast extends Logging { def unpersist(id: Long, removeFromDriver: Boolean) = synchronized { SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) } - } private[spark] case class TorrentBlock( diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e8d36e6bfc810..ea22ad29bc885 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1128,5 +1128,4 @@ abstract class RDD[T: ClassTag]( def toJavaRDD() : JavaRDD[T] = { new JavaRDD(this)(elementClassTag) } - } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 27e271368ed06..cffea28fbf794 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -53,9 +53,7 @@ private[spark] case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: I def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId } -// Leave field as an instance variable to avoid matching on it -private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId { - var field = "" +private[spark] case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId { def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) } @@ -77,19 +75,10 @@ private[spark] case class TestBlockId(id: String) extends BlockId { def name = "test_" + id } -private[spark] object BroadcastBlockId { - def apply(broadcastId: Long, field: String) = { - val blockId = new BroadcastBlockId(broadcastId) - blockId.field = field - blockId - } -} - private[spark] object BlockId { val RDD = "rdd_([0-9]+)_([0-9]+)".r val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r - val BROADCAST = "broadcast_([0-9]+)".r - val BROADCAST_FIELD = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r + val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r val TEST = "test_(.*)".r @@ -100,10 +89,8 @@ private[spark] object BlockId { RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) - case BROADCAST(broadcastId) => - BroadcastBlockId(broadcastId.toLong) - case BROADCAST_FIELD(broadcastId, field) => - BroadcastBlockId(broadcastId.toLong, field) + case BROADCAST(broadcastId, field) => + BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_")) case TASKRESULT(taskId) => TaskResultBlockId(taskId.toLong) case STREAM(streamId, uniqueId) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index c90abb187bdb1..925cee1eb6be7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -832,7 +832,7 @@ private[spark] class BlockManager( def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) { 
logInfo("Removing broadcast " + broadcastId) val blocksToRemove = blockInfo.keys.collect { - case bid: BroadcastBlockId if bid.broadcastId == broadcastId => bid + case bid @ BroadcastBlockId(`broadcastId`, _) => bid } blocksToRemove.foreach { blockId => removeBlock(blockId, removeFromDriver) } } @@ -897,7 +897,7 @@ private[spark] class BlockManager( def shouldCompress(blockId: BlockId): Boolean = blockId match { case ShuffleBlockId(_, _, _) => compressShuffle - case BroadcastBlockId(_) => compressBroadcast + case BroadcastBlockId(_, _) => compressBroadcast case RDDBlockId(_, _) => compressRdds case TempBlockId(_) => compressShuffleSpill case _ => false diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index f61aa1d6bc0fc..4e45bb8452fd8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -106,9 +106,7 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log askDriverWithReply(RemoveBlock(blockId)) } - /** - * Remove all blocks belonging to the given RDD. - */ + /** Remove all blocks belonging to the given RDD. */ def removeRdd(rddId: Int, blocking: Boolean) { val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) future onFailure { @@ -119,16 +117,12 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log } } - /** - * Remove all blocks belonging to the given shuffle. - */ + /** Remove all blocks belonging to the given shuffle. */ def removeShuffle(shuffleId: Int) { askDriverWithReply(RemoveShuffle(shuffleId)) } - /** - * Remove all blocks belonging to the given broadcast. - */ + /** Remove all blocks belonging to the given broadcast. */ def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean) { askDriverWithReply(RemoveBroadcast(broadcastId, removeFromMaster)) } @@ -148,20 +142,21 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log } /** - * Return the block's local status on all block managers, if any. + * Return the block's status on all block managers, if any. * * If askSlaves is true, this invokes the master to query each block manager for the most * updated block statuses. This is useful when the master is not informed of the given block * by all block managers. - * - * To avoid potential deadlocks, the use of Futures is necessary, because the master actor - * should not block on waiting for a block manager, which can in turn be waiting for the - * master actor for a response to a prior message. */ def getBlockStatus( blockId: BlockId, askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = { val msg = GetBlockStatus(blockId, askSlaves) + /* + * To avoid potential deadlocks, the use of Futures is necessary, because the master actor + * should not block on waiting for a block manager, which can in turn be waiting for the + * master actor for a response to a prior message. 
+ */ val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) val (blockManagerIds, futures) = response.unzip val result = Await.result(Future.sequence(futures), timeout) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 2d9445425b879..4159fc733a566 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -255,21 +255,22 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus } /** - * Return the block's local status for all block managers, if any. + * Return the block's status for all block managers, if any. * * If askSlaves is true, the master queries each block manager for the most updated block * statuses. This is useful when the master is not informed of the given block by all block * managers. - * - * Rather than blocking on the block status query, master actor should simply return a - * Future to avoid potential deadlocks. This can arise if there exists a block manager - * that is also waiting for this master actor's response to a previous message. */ private def blockStatus( blockId: BlockId, askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = { import context.dispatcher val getBlockStatus = GetBlockStatus(blockId) + /* + * Rather than blocking on the block status query, master actor should simply return + * Futures to avoid potential deadlocks. This can arise if there exists a block manager + * that is also waiting for this master actor's response to a previous message. + */ blockManagerInfo.values.map { info => val blockStatusFuture = if (askSlaves) { diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index d9a6af61872d1..c23b6b3944ba0 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -195,7 +195,7 @@ private[spark] object JsonProtocol { taskMetrics.shuffleWriteMetrics.map(shuffleWriteMetricsToJson).getOrElse(JNothing) val updatedBlocks = taskMetrics.updatedBlocks.map { blocks => JArray(blocks.toList.map { case (id, status) => - ("Block ID" -> id.toString) ~ + ("Block ID" -> blockIdToJson(id)) ~ ("Status" -> blockStatusToJson(status)) }) }.getOrElse(JNothing) @@ -284,6 +284,33 @@ private[spark] object JsonProtocol { ("Replication" -> storageLevel.replication) } + def blockIdToJson(blockId: BlockId): JValue = { + val blockType = Utils.getFormattedClassName(blockId) + val json: JObject = blockId match { + case rddBlockId: RDDBlockId => + ("RDD ID" -> rddBlockId.rddId) ~ + ("Split Index" -> rddBlockId.splitIndex) + case shuffleBlockId: ShuffleBlockId => + ("Shuffle ID" -> shuffleBlockId.shuffleId) ~ + ("Map ID" -> shuffleBlockId.mapId) ~ + ("Reduce ID" -> shuffleBlockId.reduceId) + case broadcastBlockId: BroadcastBlockId => + ("Broadcast ID" -> broadcastBlockId.broadcastId) ~ + ("Field" -> broadcastBlockId.field) + case taskResultBlockId: TaskResultBlockId => + "Task ID" -> taskResultBlockId.taskId + case streamBlockId: StreamBlockId => + ("Stream ID" -> streamBlockId.streamId) ~ + ("Unique ID" -> streamBlockId.uniqueId) + case tempBlockId: TempBlockId => + val uuid = UUIDToJson(tempBlockId.id) + "Temp ID" -> uuid + case testBlockId: TestBlockId => + "Test ID" -> testBlockId.id + } + ("Type" -> 
blockType) ~ json + } + def blockStatusToJson(blockStatus: BlockStatus): JValue = { val storageLevel = storageLevelToJson(blockStatus.storageLevel) ("Storage Level" -> storageLevel) ~ @@ -484,7 +511,7 @@ private[spark] object JsonProtocol { Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson) metrics.updatedBlocks = Utils.jsonOption(json \ "Updated Blocks").map { value => value.extract[List[JValue]].map { block => - val id = BlockId((block \ "Block ID").extract[String]) + val id = blockIdFromJson(block \ "Block ID") val status = blockStatusFromJson(block \ "Status") (id, status) } @@ -587,6 +614,45 @@ private[spark] object JsonProtocol { StorageLevel(useDisk, useMemory, deserialized, replication) } + def blockIdFromJson(json: JValue): BlockId = { + val rddBlockId = Utils.getFormattedClassName(RDDBlockId) + val shuffleBlockId = Utils.getFormattedClassName(ShuffleBlockId) + val broadcastBlockId = Utils.getFormattedClassName(BroadcastBlockId) + val taskResultBlockId = Utils.getFormattedClassName(TaskResultBlockId) + val streamBlockId = Utils.getFormattedClassName(StreamBlockId) + val tempBlockId = Utils.getFormattedClassName(TempBlockId) + val testBlockId = Utils.getFormattedClassName(TestBlockId) + + (json \ "Type").extract[String] match { + case `rddBlockId` => + val rddId = (json \ "RDD ID").extract[Int] + val splitIndex = (json \ "Split Index").extract[Int] + new RDDBlockId(rddId, splitIndex) + case `shuffleBlockId` => + val shuffleId = (json \ "Shuffle ID").extract[Int] + val mapId = (json \ "Map ID").extract[Int] + val reduceId = (json \ "Reduce ID").extract[Int] + new ShuffleBlockId(shuffleId, mapId, reduceId) + case `broadcastBlockId` => + val broadcastId = (json \ "Broadcast ID").extract[Long] + val field = (json \ "Field").extract[String] + new BroadcastBlockId(broadcastId, field) + case `taskResultBlockId` => + val taskId = (json \ "Task ID").extract[Long] + new TaskResultBlockId(taskId) + case `streamBlockId` => + val streamId = (json \ "Stream ID").extract[Int] + val uniqueId = (json \ "Unique ID").extract[Long] + new StreamBlockId(streamId, uniqueId) + case `tempBlockId` => + val tempId = UUIDFromJson(json \ "Temp ID") + new TempBlockId(tempId) + case `testBlockId` => + val testId = (json \ "Test ID").extract[String] + new TestBlockId(testId) + } + } + def blockStatusFromJson(json: JValue): BlockStatus = { val storageLevel = storageLevelFromJson(json \ "Storage Level") val memorySize = (json \ "Memory Size").extract[Long] diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index d28496e316a34..f1bfb6666ddda 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -246,12 +246,20 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { afterUsingBroadcast(blocks, blockManagerMaster) // Unpersist broadcast - broadcast.unpersist(removeFromDriver) + if (removeFromDriver) { + broadcast.destroy() + } else { + broadcast.unpersist() + } afterUnpersist(blocks, blockManagerMaster) - if (!removeFromDriver) { - // The broadcast variable is not completely destroyed (i.e. state still exists on driver) - // Using the variable again should yield the same answer as before. + // If the broadcast is removed from driver, all subsequent uses of the broadcast variable + // should throw SparkExceptions. Otherwise, the result should be the same as before. 
+ if (removeFromDriver) { + // Using this variable on the executors crashes them, which hangs the test. + // Instead, crash the driver by directly accessing the broadcast value. + intercept[SparkException] { broadcast.value } + } else { val results = sc.parallelize(1 to numSlaves, numSlaves).map(x => (x, broadcast.value.sum)) assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) } @@ -263,5 +271,4 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.%s".format(factoryName)) conf } - } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index bddbd381c2665..b47de5eab95a4 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -773,7 +773,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT // getLocations should return nothing because the master is not informed // getBlockStatus without asking slaves should have the same result - // getBlockStatus with asking slaves, however, should present the actual block statuses + // getBlockStatus with asking slaves, however, should return the actual block statuses assert(store.master.getLocations("list4").size === 0) assert(store.master.getLocations("list5").size === 0) assert(store.master.getLocations("list6").size === 0) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 580ac34f5f0b4..6bc8bcc036cb3 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -104,6 +104,14 @@ class JsonProtocolSuite extends FunSuite { testTaskEndReason(TaskKilled) testTaskEndReason(ExecutorLostFailure) testTaskEndReason(UnknownReason) + + // BlockId + testBlockId(RDDBlockId(1, 2)) + testBlockId(ShuffleBlockId(1, 2, 3)) + testBlockId(BroadcastBlockId(1L, "")) + testBlockId(TaskResultBlockId(1L)) + testBlockId(StreamBlockId(1, 2L)) + testBlockId(TempBlockId(UUID.randomUUID())) } @@ -158,6 +166,11 @@ class JsonProtocolSuite extends FunSuite { assertEquals(reason, newReason) } + private def testBlockId(blockId: BlockId) { + val newBlockId = JsonProtocol.blockIdFromJson(JsonProtocol.blockIdToJson(blockId)) + blockId == newBlockId + } + /** -------------------------------- * | Util methods for comparing events | @@ -542,4 +555,4 @@ class JsonProtocolSuite extends FunSuite { {"Event":"SparkListenerUnpersistRDD","RDD ID":12345} """ - } +} From f0aabb1c8496dc79daeb6d090fb36ceef310622b Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 1 Apr 2014 19:55:45 -0700 Subject: [PATCH 14/14] Correct semantics for TimeStampedWeakValueHashMap + add tests This largely accounts for the cases when WeakReference becomes no longer strongly reachable, in which case the map should return None for all get() operations, and should skip the entry for all listing operations. 
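
For reference, the behaviour this commit is after reads roughly as follows. The Scala sketch below is illustrative only: it is written against the TimeStampedWeakValueHashMap API added in this patch (get, iterator, clearNullValues) and mirrors the GC-retry loop used by the new test suite; the object name is invented, and because the map is private[spark] the sketch assumes it lives under the org.apache.spark.util package.

    package org.apache.spark.util

    import java.lang.ref.WeakReference

    // Sketch only: once the value stored under "k1" is no longer strongly
    // reachable, get() returns None and listing operations skip the entry,
    // even before the periodic clearNullValues() pass removes the key itself.
    object WeakValueSemanticsSketch {
      def main(args: Array[String]) {
        val map = new TimeStampedWeakValueHashMap[String, Object]
        var strongRef = new Object
        val weakRef = new WeakReference(strongRef)
        map("k1") = strongRef
        map("k2") = "v2"

        // Drop the only strong reference and nudge the collector, retrying for a
        // bounded time because a single System.gc() call is only a best effort.
        strongRef = null
        val start = System.currentTimeMillis
        while (System.currentTimeMillis - start < 10000 && weakRef.get != null) {
          System.gc()
          System.runFinalization()
          Thread.sleep(100)
        }

        assert(map.get("k1").isEmpty)                             // None after collection
        assert(map.iterator.forall { case (k, _) => k != "k1" })  // skipped when listing
        assert(map.get("k2") == Some("v2"))                       // unrelated entry unaffected

        map.clearNullValues()                                     // drops the stale key as well
      }
    }
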
--- .../apache/spark/broadcast/Broadcast.scala | 2 +- .../spark/util/TimeStampedHashMap.scala | 43 +-- .../util/TimeStampedWeakValueHashMap.scala | 78 ++++-- .../spark/util/TimeStampedHashMapSuite.scala | 264 ++++++++++++++++++ 4 files changed, 350 insertions(+), 37 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index 81e0e5297683b..b28e15a6840d9 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -78,7 +78,7 @@ abstract class Broadcast[T](val id: Long) extends Serializable { */ protected def assertValid() { if (!_isValid) { - throw new SparkException("Attempted to use %s when is no longer valid!".format(toString)) + throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString)) } } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index 1721818c212f9..5c239329588d8 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -21,7 +21,7 @@ import java.util.Set import java.util.Map.Entry import java.util.concurrent.ConcurrentHashMap -import scala.collection.{immutable, JavaConversions, mutable} +import scala.collection.{JavaConversions, mutable} import org.apache.spark.Logging @@ -50,11 +50,11 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa } def iterator: Iterator[(A, B)] = { - val jIterator = getEntrySet.iterator() + val jIterator = getEntrySet.iterator JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue.value)) } - def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet() + def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = { val newMap = new TimeStampedHashMap[A, B1] @@ -86,8 +86,7 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa } override def apply(key: A): B = { - val value = internalMap.get(key) - Option(value).map(_.value).getOrElse { throw new NoSuchElementException() } + get(key).getOrElse { throw new NoSuchElementException() } } override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = { @@ -101,9 +100,9 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa override def size: Int = internalMap.size override def foreach[U](f: ((A, B)) => U) { - val iterator = getEntrySet.iterator() - while(iterator.hasNext) { - val entry = iterator.next() + val it = getEntrySet.iterator + while(it.hasNext) { + val entry = it.next() val kv = (entry.getKey, entry.getValue.value) f(kv) } @@ -115,27 +114,39 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa Option(prev).map(_.value) } - def toMap: immutable.Map[A, B] = iterator.toMap + def putAll(map: Map[A, B]) { + map.foreach { case (k, v) => update(k, v) } + } + + def toMap: Map[A, B] = iterator.toMap def clearOldValues(threshTime: Long, f: (A, B) => Unit) { - val iterator = getEntrySet.iterator() - while (iterator.hasNext) { - val entry = iterator.next() + val it = getEntrySet.iterator + while (it.hasNext) { + val entry = it.next() if (entry.getValue.timestamp < 
threshTime) { f(entry.getKey, entry.getValue.value) logDebug("Removing key " + entry.getKey) - iterator.remove() + it.remove() } } } - /** - * Removes old key-value pairs that have timestamp earlier than `threshTime`. - */ + /** Removes old key-value pairs that have timestamp earlier than `threshTime`. */ def clearOldValues(threshTime: Long) { clearOldValues(threshTime, (_, _) => ()) } private def currentTime: Long = System.currentTimeMillis + // For testing + + def getTimeStampedValue(key: A): Option[TimeStampedValue[B]] = { + Option(internalMap.get(key)) + } + + def getTimestamp(key: A): Option[Long] = { + getTimeStampedValue(key).map(_.timestamp) + } + } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala index f814f58261bf3..b65017d6806c6 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala @@ -18,47 +18,61 @@ package org.apache.spark.util import java.lang.ref.WeakReference +import java.util.concurrent.atomic.AtomicInteger -import scala.collection.{immutable, mutable} +import scala.collection.mutable + +import org.apache.spark.Logging /** * A wrapper of TimeStampedHashMap that ensures the values are weakly referenced and timestamped. * - * If the value is garbage collected and the weak reference is null, get() operation returns - * a non-existent value. However, the corresponding key is actually not removed in the current - * implementation. Key-value pairs whose timestamps are older than a particular threshold time - * can then be removed using the clearOldValues method. It exposes a scala.collection.mutable.Map - * interface to allow it to be a drop-in replacement for Scala HashMaps. + * If the value is garbage collected and the weak reference is null, get() will return a + * non-existent value. These entries are removed from the map periodically (every N inserts), as + * their values are no longer strongly reachable. Further, key-value pairs whose timestamps are + * older than a particular threshold can be removed using the clearOldValues method. * - * Internally, it uses a Java ConcurrentHashMap, so all operations on this HashMap are thread-safe. + * TimeStampedWeakValueHashMap exposes a scala.collection.mutable.Map interface, which allows it + * to be a drop-in replacement for Scala HashMaps. Internally, it uses a Java ConcurrentHashMap, + * so all operations on this HashMap are thread-safe. * * @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed. */ private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boolean = false) - extends mutable.Map[A, B]() { + extends mutable.Map[A, B]() with Logging { import TimeStampedWeakValueHashMap._ private val internalMap = new TimeStampedHashMap[A, WeakReference[B]](updateTimeStampOnGet) + private val insertCount = new AtomicInteger(0) + + /** Return a map consisting only of entries whose values are still strongly reachable. 
*/ + private def nonNullReferenceMap = internalMap.filter { case (_, ref) => ref.get != null } def get(key: A): Option[B] = internalMap.get(key) - def iterator: Iterator[(A, B)] = internalMap.iterator + def iterator: Iterator[(A, B)] = nonNullReferenceMap.iterator override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = { val newMap = new TimeStampedWeakValueHashMap[A, B1] + val oldMap = nonNullReferenceMap.asInstanceOf[mutable.Map[A, WeakReference[B1]]] + newMap.internalMap.putAll(oldMap.toMap) newMap.internalMap += kv newMap } override def - (key: A): mutable.Map[A, B] = { val newMap = new TimeStampedWeakValueHashMap[A, B] + newMap.internalMap.putAll(nonNullReferenceMap.toMap) newMap.internalMap -= key newMap } override def += (kv: (A, B)): this.type = { internalMap += kv + if (insertCount.incrementAndGet() % CLEAR_NULL_VALUES_INTERVAL == 0) { + clearNullValues() + } this } @@ -71,31 +85,53 @@ private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boo override def apply(key: A): B = internalMap.apply(key) - override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = internalMap.filter(p) + override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = nonNullReferenceMap.filter(p) override def empty: mutable.Map[A, B] = new TimeStampedWeakValueHashMap[A, B]() override def size: Int = internalMap.size - override def foreach[U](f: ((A, B)) => U) = internalMap.foreach(f) + override def foreach[U](f: ((A, B)) => U) = nonNullReferenceMap.foreach(f) def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value) - def toMap: immutable.Map[A, B] = iterator.toMap + def toMap: Map[A, B] = iterator.toMap - /** - * Remove old key-value pairs that have timestamp earlier than `threshTime`. - */ + /** Remove old key-value pairs with timestamps earlier than `threshTime`. */ def clearOldValues(threshTime: Long) = internalMap.clearOldValues(threshTime) + /** Remove entries with values that are no longer strongly reachable. */ + def clearNullValues() { + val it = internalMap.getEntrySet.iterator + while (it.hasNext) { + val entry = it.next() + if (entry.getValue.value.get == null) { + logDebug("Removing key " + entry.getKey + " because it is no longer strongly reachable.") + it.remove() + } + } + } + + // For testing + + def getTimestamp(key: A): Option[Long] = { + internalMap.getTimeStampedValue(key).map(_.timestamp) + } + + def getReference(key: A): Option[WeakReference[B]] = { + internalMap.getTimeStampedValue(key).map(_.value) + } } /** * Helper methods for converting to and from WeakReferences. */ -private[spark] object TimeStampedWeakValueHashMap { +private object TimeStampedWeakValueHashMap { - /* Implicit conversion methods to WeakReferences */ + // Number of inserts after which entries with null references are removed + val CLEAR_NULL_VALUES_INTERVAL = 100 + + /* Implicit conversion methods to WeakReferences. */ implicit def toWeakReference[V](v: V): WeakReference[V] = new WeakReference[V](v) @@ -107,12 +143,15 @@ private[spark] object TimeStampedWeakValueHashMap { (kv: (K, WeakReference[V])) => p(kv) } - /* Implicit conversion methods from WeakReferences */ + /* Implicit conversion methods from WeakReferences. 
*/ implicit def fromWeakReference[V](ref: WeakReference[V]): V = ref.get implicit def fromWeakReferenceOption[V](v: Option[WeakReference[V]]): Option[V] = { - v.map(fromWeakReference) + v match { + case Some(ref) => Option(fromWeakReference(ref)) + case None => None + } } implicit def fromWeakReferenceTuple[K, V](kv: (K, WeakReference[V])): (K, V) = { @@ -128,5 +167,4 @@ private[spark] object TimeStampedWeakValueHashMap { map: mutable.Map[K, WeakReference[V]]) : mutable.Map[K, V] = { mutable.Map(map.mapValues(fromWeakReference).toSeq: _*) } - } diff --git a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala new file mode 100644 index 0000000000000..6a5653ed2fb54 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.lang.ref.WeakReference + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +import org.scalatest.FunSuite + +class TimeStampedHashMapSuite extends FunSuite { + + // Test the testMap function - a Scala HashMap should obviously pass + testMap(new mutable.HashMap[String, String]()) + + // Test TimeStampedHashMap basic functionality + testMap(new TimeStampedHashMap[String, String]()) + testMapThreadSafety(new TimeStampedHashMap[String, String]()) + + // Test TimeStampedWeakValueHashMap basic functionality + testMap(new TimeStampedWeakValueHashMap[String, String]()) + testMapThreadSafety(new TimeStampedWeakValueHashMap[String, String]()) + + test("TimeStampedHashMap - clearing by timestamp") { + // clearing by insertion time + val map = new TimeStampedHashMap[String, String](updateTimeStampOnGet = false) + map("k1") = "v1" + assert(map("k1") === "v1") + Thread.sleep(10) + val threshTime = System.currentTimeMillis + assert(map.getTimestamp("k1").isDefined) + assert(map.getTimestamp("k1").get < threshTime) + map.clearOldValues(threshTime) + assert(map.get("k1") === None) + + // clearing by modification time + val map1 = new TimeStampedHashMap[String, String](updateTimeStampOnGet = true) + map1("k1") = "v1" + map1("k2") = "v2" + assert(map1("k1") === "v1") + Thread.sleep(10) + val threshTime1 = System.currentTimeMillis + Thread.sleep(10) + assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime + assert(map1.getTimestamp("k1").isDefined) + assert(map1.getTimestamp("k1").get < threshTime1) + assert(map1.getTimestamp("k2").isDefined) + assert(map1.getTimestamp("k2").get >= threshTime1) + map1.clearOldValues(threshTime1) //should only clear k1 + assert(map1.get("k1") === None) + assert(map1.get("k2").isDefined) + } + 
+ test("TimeStampedWeakValueHashMap - clearing by timestamp") { + // clearing by insertion time + val map = new TimeStampedWeakValueHashMap[String, String](updateTimeStampOnGet = false) + map("k1") = "v1" + assert(map("k1") === "v1") + Thread.sleep(10) + val threshTime = System.currentTimeMillis + assert(map.getTimestamp("k1").isDefined) + assert(map.getTimestamp("k1").get < threshTime) + map.clearOldValues(threshTime) + assert(map.get("k1") === None) + + // clearing by modification time + val map1 = new TimeStampedWeakValueHashMap[String, String](updateTimeStampOnGet = true) + map1("k1") = "v1" + map1("k2") = "v2" + assert(map1("k1") === "v1") + Thread.sleep(10) + val threshTime1 = System.currentTimeMillis + Thread.sleep(10) + assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime + assert(map1.getTimestamp("k1").isDefined) + assert(map1.getTimestamp("k1").get < threshTime1) + assert(map1.getTimestamp("k2").isDefined) + assert(map1.getTimestamp("k2").get >= threshTime1) + map1.clearOldValues(threshTime1) //should only clear k1 + assert(map1.get("k1") === None) + assert(map1.get("k2").isDefined) + } + + test("TimeStampedWeakValueHashMap - clearing weak references") { + var strongRef = new Object + val weakRef = new WeakReference(strongRef) + val map = new TimeStampedWeakValueHashMap[String, Object] + map("k1") = strongRef + map("k2") = "v2" + map("k3") = "v3" + assert(map("k1") === strongRef) + + // clear strong reference to "k1" + strongRef = null + val startTime = System.currentTimeMillis + System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. + System.runFinalization() // Make a best effort to call finalizer on all cleaned objects. + while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { + System.gc() + System.runFinalization() + Thread.sleep(100) + } + assert(map.getReference("k1").isDefined) + val ref = map.getReference("k1").get + assert(ref.get === null) + assert(map.get("k1") === None) + + // operations should only display non-null entries + assert(map.iterator.forall { case (k, v) => k != "k1" }) + assert(map.filter { case (k, v) => k != "k2" }.size === 1) + assert(map.filter { case (k, v) => k != "k2" }.head._1 === "k3") + assert(map.toMap.size === 2) + assert(map.toMap.forall { case (k, v) => k != "k1" }) + val buffer = new ArrayBuffer[String] + map.foreach { case (k, v) => buffer += v.toString } + assert(buffer.size === 2) + assert(buffer.forall(_ != "k1")) + val plusMap = map + (("k4", "v4")) + assert(plusMap.size === 3) + assert(plusMap.forall { case (k, v) => k != "k1" }) + val minusMap = map - "k2" + assert(minusMap.size === 1) + assert(minusMap.head._1 == "k3") + + // clear null values - should only clear k1 + map.clearNullValues() + assert(map.getReference("k1") === None) + assert(map.get("k1") === None) + assert(map.get("k2").isDefined) + assert(map.get("k2").get === "v2") + } + + /** Test basic operations of a Scala mutable Map. 
*/ + def testMap(hashMapConstructor: => mutable.Map[String, String]) { + def newMap() = hashMapConstructor + val testMap1 = newMap() + val testMap2 = newMap() + val name = testMap1.getClass.getSimpleName + + test(name + " - basic test") { + // put, get, and apply + testMap1 += (("k1", "v1")) + assert(testMap1.get("k1").isDefined) + assert(testMap1.get("k1").get === "v1") + testMap1("k2") = "v2" + assert(testMap1.get("k2").isDefined) + assert(testMap1.get("k2").get === "v2") + assert(testMap1("k2") === "v2") + testMap1.update("k3", "v3") + assert(testMap1.get("k3").isDefined) + assert(testMap1.get("k3").get === "v3") + + // remove + testMap1.remove("k1") + assert(testMap1.get("k1").isEmpty) + testMap1.remove("k2") + intercept[NoSuchElementException] { + testMap1("k2") // Map.apply() causes exception + } + testMap1 -= "k3" + assert(testMap1.get("k3").isEmpty) + + // multi put + val keys = (1 to 100).map(_.toString) + val pairs = keys.map(x => (x, x * 2)) + assert((testMap2 ++ pairs).iterator.toSet === pairs.toSet) + testMap2 ++= pairs + + // iterator + assert(testMap2.iterator.toSet === pairs.toSet) + + // filter + val filtered = testMap2.filter { case (_, v) => v.toInt % 2 == 0 } + val evenPairs = pairs.filter { case (_, v) => v.toInt % 2 == 0 } + assert(filtered.iterator.toSet === evenPairs.toSet) + + // foreach + val buffer = new ArrayBuffer[(String, String)] + testMap2.foreach(x => buffer += x) + assert(testMap2.toSet === buffer.toSet) + + // multi remove + testMap2("k1") = "v1" + testMap2 --= keys + assert(testMap2.size === 1) + assert(testMap2.iterator.toSeq.head === ("k1", "v1")) + + // + + val testMap3 = testMap2 + (("k0", "v0")) + assert(testMap3.size === 2) + assert(testMap3.get("k1").isDefined) + assert(testMap3.get("k1").get === "v1") + assert(testMap3.get("k0").isDefined) + assert(testMap3.get("k0").get === "v0") + + // - + val testMap4 = testMap3 - "k0" + assert(testMap4.size === 1) + assert(testMap4.get("k1").isDefined) + assert(testMap4.get("k1").get === "v1") + } + } + + /** Test thread safety of a Scala mutable map. */ + def testMapThreadSafety(hashMapConstructor: => mutable.Map[String, String]) { + def newMap() = hashMapConstructor + val name = newMap().getClass.getSimpleName + val testMap = newMap() + @volatile var error = false + + def getRandomKey(m: mutable.Map[String, String]): Option[String] = { + val keys = testMap.keysIterator.toSeq + if (keys.nonEmpty) { + Some(keys(Random.nextInt(keys.size))) + } else { + None + } + } + + val threads = (1 to 25).map(i => new Thread() { + override def run() { + try { + for (j <- 1 to 1000) { + Random.nextInt(3) match { + case 0 => + testMap(Random.nextString(10)) = Random.nextDouble().toString // put + case 1 => + getRandomKey(testMap).map(testMap.get) // get + case 2 => + getRandomKey(testMap).map(testMap.remove) // remove + } + } + } catch { + case t: Throwable => + error = true + throw t + } + } + }) + + test(name + " - threading safety test") { + threads.map(_.start) + threads.map(_.join) + assert(!error) + } + } +}
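
Taken together, the series leaves Broadcast with a public unpersist(), which drops the cached blocks on the executors, and an internal private[spark] destroy(), which also removes the driver-side state and makes any further use fail with a SparkException ("Attempted to use ... after it has been destroyed!"). The driver program below is a hypothetical sketch of how the public call is meant to be used; the object name, app name, and master URL are invented and the job itself is only illustrative.

    import org.apache.spark.{SparkConf, SparkContext}

    object BroadcastUnpersistSketch {
      def main(args: Array[String]) {
        val conf = new SparkConf().setAppName("broadcast-unpersist-sketch").setMaster("local[2]")
        val sc = new SparkContext(conf)

        // Broadcast a lookup table once; executors cache its blocks locally.
        val lookup = sc.broadcast((1 to 1000).map(i => (i, i * i)).toMap)
        val hits = sc.parallelize(1 to 1000).filter(i => lookup.value.contains(i)).count()
        println("hits = " + hits)

        // Drop the cached copies on the executors. The driver keeps its copy, so a
        // later job that uses `lookup` simply re-fetches the value on demand.
        lookup.unpersist()
        val again = sc.parallelize(1 to 10).map(i => lookup.value(i)).collect()
        println(again.mkString(","))

        sc.stop()
      }
    }
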