From 1e752f1a5c5f3887df2ca20d63a9d30f1d32f9d1 Mon Sep 17 00:00:00 2001
From: Roman Pastukhov
Date: Wed, 5 Feb 2014 20:11:56 +0400
Subject: [PATCH] Added unpersist method to Broadcast.

---
 .../scala/org/apache/spark/SparkContext.scala |  7 ++-
 .../apache/spark/broadcast/Broadcast.scala    | 13 ++++-
 .../spark/broadcast/BroadcastFactory.scala    |  2 +-
 .../spark/broadcast/HttpBroadcast.scala       | 45 ++++++++++++-----
 .../spark/broadcast/TorrentBroadcast.scala    | 43 +++++++++++++----
 .../apache/spark/storage/BlockManager.scala   | 12 +++++
 .../apache/spark/storage/MemoryStore.scala    | 31 +++++++-----
 .../org/apache/spark/BroadcastSuite.scala     | 48 +++++++++++++++++++
 8 files changed, 163 insertions(+), 38 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 566472e597958..f42589c3900d0 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -613,8 +613,13 @@ class SparkContext(
    * Broadcast a read-only variable to the cluster, returning a
    * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
    * The variable will be sent to each cluster only once.
+   *
+   * If `registerBlocks` is true, workers will notify the driver about blocks they create,
+   * and these blocks will be dropped when the `unpersist` method of the broadcast variable is called.
    */
-  def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
+  def broadcast[T](value: T, registerBlocks: Boolean = false) = {
+    env.broadcastManager.newBroadcast[T](value, isLocal, registerBlocks)
+  }
 
   /**
    * Add a file to be downloaded with this Spark job on every node.
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
index d113d4040594d..076d98f8de991 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -53,6 +53,15 @@ import org.apache.spark._
 abstract class Broadcast[T](val id: Long) extends Serializable {
   def value: T
 
+  /**
+   * Removes all blocks of this broadcast from memory (and from disk if removeSource is true).
+   *
+   * @param removeSource Whether to remove data from disk as well.
+   *                     Will cause errors if the broadcast is accessed on workers afterwards
+   *                     (e.g. in case of RDD re-computation due to executor failure).
+   */
+  def unpersist(removeSource: Boolean = false)
+
   // We cannot have an abstract readObject here due to some weird issues with
   // readObject having to be 'private' in sub-classes.
@@ -91,8 +100,8 @@ class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging private val nextBroadcastId = new AtomicLong(0) - def newBroadcast[T](value_ : T, isLocal: Boolean) = - broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) + def newBroadcast[T](value_ : T, isLocal: Boolean, registerBlocks: Boolean) = + broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement(), registerBlocks) def isDriver = _isDriver } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 940e5ab805100..e38283f244ea1 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -27,6 +27,6 @@ import org.apache.spark.SparkConf */ trait BroadcastFactory { def initialize(isDriver: Boolean, conf: SparkConf): Unit - def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] + def newBroadcast[T](value: T, isLocal: Boolean, id: Long, registerBlocks: Boolean): Broadcast[T] def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 39ee0dbb92841..53fcc2748b4e0 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -29,11 +29,20 @@ import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils} -private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) +private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) extends Broadcast[T](id) with Logging with Serializable { def value = value_ + def unpersist(removeSource: Boolean) { + SparkEnv.get.blockManager.master.removeBlock(blockId) + SparkEnv.get.blockManager.removeBlock(blockId) + + if (removeSource) { + HttpBroadcast.cleanupById(id) + } + } + def blockId = BroadcastBlockId(id) HttpBroadcast.synchronized { @@ -54,7 +63,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea logInfo("Started reading broadcast variable " + id) val start = System.nanoTime value_ = HttpBroadcast.read[T](id) - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks) val time = (System.nanoTime - start) / 1e9 logInfo("Reading broadcast variable " + id + " took " + time + " s") } @@ -69,8 +78,8 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea class HttpBroadcastFactory extends BroadcastFactory { def initialize(isDriver: Boolean, conf: SparkConf) { HttpBroadcast.initialize(isDriver, conf) } - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new HttpBroadcast[T](value_, isLocal, id) + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) = + new HttpBroadcast[T](value_, isLocal, id, registerBlocks) def stop() { HttpBroadcast.stop() } } @@ -132,8 +141,10 @@ private object HttpBroadcast extends Logging { logInfo("Broadcast server started at " + serverUri) } + def getFile(id: Long) = new File(broadcastDir, 
BroadcastBlockId(id).name) + def write(id: Long, value: Any) { - val file = new File(broadcastDir, BroadcastBlockId(id).name) + val file = getFile(id) val out: OutputStream = { if (compress) { compressionCodec.compressedOutputStream(new FileOutputStream(file)) @@ -167,20 +178,30 @@ private object HttpBroadcast extends Logging { obj } + def deleteFile(fileName: String) { + try { + new File(fileName).delete() + logInfo("Deleted broadcast file '" + fileName + "'") + } catch { + case e: Exception => logWarning("Could not delete broadcast file '" + fileName + "'", e) + } + } + def cleanup(cleanupTime: Long) { val iterator = files.internalMap.entrySet().iterator() while(iterator.hasNext) { val entry = iterator.next() val (file, time) = (entry.getKey, entry.getValue) if (time < cleanupTime) { - try { - iterator.remove() - new File(file.toString).delete() - logInfo("Deleted broadcast file '" + file + "'") - } catch { - case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e) - } + iterator.remove() + deleteFile(file) } } } + + def cleanupById(id: Long) { + val file = getFile(id).getAbsolutePath + files.internalMap.remove(file) + deleteFile(file) + } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index d351dfc1f56a2..11e74675491c6 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -23,16 +23,36 @@ import scala.math import scala.util.Random import org.apache.spark._ -import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel} +import org.apache.spark.storage.{BlockId, BroadcastBlockId, BroadcastHelperBlockId, StorageLevel} import org.apache.spark.util.Utils -private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) +private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) extends Broadcast[T](id) with Logging with Serializable { def value = value_ + def unpersist(removeSource: Boolean) { + SparkEnv.get.blockManager.master.removeBlock(broadcastId) + SparkEnv.get.blockManager.removeBlock(broadcastId) + + if (removeSource) { + for (pid <- pieceIds) { + SparkEnv.get.blockManager.removeBlock(pieceBlockId(pid)) + } + SparkEnv.get.blockManager.removeBlock(metaId) + } else { + for (pid <- pieceIds) { + SparkEnv.get.blockManager.dropFromMemory(pieceBlockId(pid)) + } + SparkEnv.get.blockManager.dropFromMemory(metaId) + } + } + def broadcastId = BroadcastBlockId(id) + private def metaId = BroadcastHelperBlockId(broadcastId, "meta") + private def pieceBlockId(pid: Int) = BroadcastHelperBlockId(broadcastId, "piece" + pid) + private def pieceIds = Array.iterate(0, totalBlocks)(_ + 1).toList TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) @@ -55,7 +75,6 @@ extends Broadcast[T](id) with Logging with Serializable { hasBlocks = tInfo.totalBlocks // Store meta-info - val metaId = BroadcastHelperBlockId(broadcastId, "meta") val metaInfo = TorrentInfo(null, totalBlocks, totalBytes) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( @@ -64,7 +83,7 @@ extends Broadcast[T](id) with Logging with Serializable { // Store individual pieces for (i <- 0 until totalBlocks) { - val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i) + val pieceId = 
pieceBlockId(i) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true) @@ -94,7 +113,7 @@ extends Broadcast[T](id) with Logging with Serializable { // This creates a tradeoff between memory usage and latency. // Storing copy doubles the memory footprint; not storing doubles deserialization cost. SparkEnv.get.blockManager.putSingle( - broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks) // Remove arrayOfBlocks from memory once value_ is on local cache resetWorkerVariables() @@ -109,6 +128,11 @@ extends Broadcast[T](id) with Logging with Serializable { } private def resetWorkerVariables() { + if (arrayOfBlocks != null) { + for (pid <- pieceIds) { + SparkEnv.get.blockManager.removeBlock(pieceBlockId(pid)) + } + } arrayOfBlocks = null totalBytes = -1 totalBlocks = -1 @@ -117,7 +141,6 @@ extends Broadcast[T](id) with Logging with Serializable { def receiveBroadcast(variableID: Long): Boolean = { // Receive meta-info - val metaId = BroadcastHelperBlockId(broadcastId, "meta") var attemptId = 10 while (attemptId > 0 && totalBlocks == -1) { TorrentBroadcast.synchronized { @@ -140,9 +163,9 @@ extends Broadcast[T](id) with Logging with Serializable { } // Receive actual blocks - val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList) + val recvOrder = new Random().shuffle(pieceIds) for (pid <- recvOrder) { - val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid) + val pieceId = pieceBlockId(pid) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.getSingle(pieceId) match { case Some(x) => @@ -243,8 +266,8 @@ class TorrentBroadcastFactory extends BroadcastFactory { def initialize(isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.initialize(isDriver, conf) } - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new TorrentBroadcast[T](value_, isLocal, id) + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) = + new TorrentBroadcast[T](value_, isLocal, id, registerBlocks) def stop() { TorrentBroadcast.stop() } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index ed53558566edf..f8c121615567f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -196,6 +196,11 @@ private[spark] class BlockManager( } } + /** + * For testing. Returns number of blocks BlockManager knows about that are in memory. + */ + def numberOfBlocksInMemory() = blockInfo.keys.count(memoryStore.contains(_)) + /** * Get storage level of local block. If no info exists for the block, then returns null. */ @@ -720,6 +725,13 @@ private[spark] class BlockManager( } /** + * Drop a block from memory, possibly putting it on disk if applicable. + */ + def dropFromMemory(blockId: BlockId) { + memoryStore.asInstanceOf[MemoryStore].dropFromMemory(blockId) + } + + /** * Remove all blocks belonging to the given RDD. * @return The number of blocks removed. 
*/ diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index eb5a18521683e..4e47a06c1fed2 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -182,6 +182,24 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } + /** + * Drop a block from memory, possibly putting it on disk if applicable. + */ + def dropFromMemory(blockId: BlockId) { + val entry = entries.synchronized { entries.get(blockId) } + // This should never be null as only one thread should be dropping + // blocks and removing entries. However the check is still here for + // future safety. + if (entry != null) { + val data = if (entry.deserialized) { + Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) + } else { + Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) + } + blockManager.dropFromMemory(blockId, data) + } + } + /** * Tries to free up a given amount of space to store a particular block, but can fail and return * false if either the block is bigger than our memory or it would require replacing another @@ -227,18 +245,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) if (maxMemory - (currentMemory - selectedMemory) >= space) { logInfo(selectedBlocks.size + " blocks selected for dropping") for (blockId <- selectedBlocks) { - val entry = entries.synchronized { entries.get(blockId) } - // This should never be null as only one thread should be dropping - // blocks and removing entries. However the check is still here for - // future safety. - if (entry != null) { - val data = if (entry.deserialized) { - Left(entry.value.asInstanceOf[ArrayBuffer[Any]]) - } else { - Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) - } - blockManager.dropFromMemory(blockId, data) - } + dropFromMemory(blockId) } return true } else { diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index e022accee6d08..a657753144b24 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -18,6 +18,11 @@ package org.apache.spark import org.scalatest.FunSuite +import org.scalatest.concurrent.Timeouts._ +import org.scalatest.time.{Millis, Span} +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ +import org.scalatest.matchers.ShouldMatchers._ class BroadcastSuite extends FunSuite with LocalSparkContext { @@ -82,4 +87,47 @@ class BroadcastSuite extends FunSuite with LocalSparkContext { assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) } + def blocksExist(sc: SparkContext, numSlaves: Int) = { + val rdd = sc.parallelize(1 to numSlaves, numSlaves) + val workerBlocks = rdd.mapPartitions(_ => { + val blocks = SparkEnv.get.blockManager.numberOfBlocksInMemory() + Seq(blocks).iterator + }) + val totalKnown = workerBlocks.reduce(_ + _) + sc.env.blockManager.numberOfBlocksInMemory() + + totalKnown > 0 + } + + def testUnpersist(bcFactory: String, removeSource: Boolean) { + test("Broadcast unpersist(" + removeSource + ") with " + bcFactory) { + val numSlaves = 2 + System.setProperty("spark.broadcast.factory", bcFactory) + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") + val list = List(1, 2, 3, 4) + + assert(!blocksExist(sc, numSlaves)) + + val listBroadcast = sc.broadcast(list, true) + val 
results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) + assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + + assert(blocksExist(sc, numSlaves)) + + listBroadcast.unpersist(removeSource) + + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + blocksExist(sc, numSlaves) should be (false) + } + + if (!removeSource) { + val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) + assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + } + } + } + + for (removeSource <- Seq(true, false)) { + testUnpersist("org.apache.spark.broadcast.HttpBroadcastFactory", removeSource) + testUnpersist("org.apache.spark.broadcast.TorrentBroadcastFactory", removeSource) + } }
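
For context, here is a minimal usage sketch of the API this patch introduces, mirroring the new test above. It assumes an existing SparkContext named `sc` (e.g. the local-cluster context used in BroadcastSuite); the `registerBlocks` parameter of SparkContext.broadcast and the `removeSource` parameter of Broadcast.unpersist are the ones added in this diff, and the variable names are illustrative only.

// Sketch only: exercises the broadcast/unpersist API added by this patch.
val data = List(1, 2, 3, 4)

// registerBlocks = true makes workers report the blocks they cache to the driver,
// so that unpersist() can later drop them cluster-wide.
val bc = sc.broadcast(data, registerBlocks = true)

// Use the broadcast variable in a job as usual.
val sums = sc.parallelize(1 to 4).map(x => (x, bc.value.sum)).collect()

// Drop the cached broadcast blocks from memory on the driver and the workers.
// With the default removeSource = false the source data (HTTP file or torrent
// pieces) is kept, so tasks that still reference `bc` can re-fetch the value.
bc.unpersist()

// removeSource = true would also delete the source data; accessing the broadcast
// on workers afterwards (e.g. on RDD re-computation) may then fail.
// bc.unpersist(removeSource = true)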
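
And a sketch of how the cleanup can be observed, using the numberOfBlocksInMemory helper this patch adds to BlockManager; it restates the blocksExist pattern from the suite above. `sc`, `numSlaves`, and the broadcast `bc` from the previous sketch are assumed, as is running inside the org.apache.spark test package (sc.env is package-private) with the ScalaTest Eventually/SpanSugar imports the suite already uses.

import org.apache.spark.{SparkContext, SparkEnv}
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

// Counts blocks currently held in memory on the driver plus all workers.
def blocksInMemory(sc: SparkContext, numSlaves: Int): Int = {
  val onWorkers = sc.parallelize(1 to numSlaves, numSlaves).mapPartitions { _ =>
    Iterator(SparkEnv.get.blockManager.numberOfBlocksInMemory())
  }.reduce(_ + _)
  onWorkers + sc.env.blockManager.numberOfBlocksInMemory()
}

bc.unpersist()

// Removal may not be visible everywhere immediately, so poll briefly, as the
// test above does. In the test setup no other blocks are cached, so reaching
// zero means the broadcast blocks are gone.
eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
  assert(blocksInMemory(sc, numSlaves) == 0)
}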