
Added unpersist method to Broadcast.
Roman Pastukhov committed Feb 5, 2014
1 parent 9209287 commit 1e752f1
Showing 8 changed files with 163 additions and 38 deletions.
7 changes: 6 additions & 1 deletion in core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -613,8 +613,13 @@ class SparkContext(
* Broadcast a read-only variable to the cluster, returning a
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
* The variable will be sent to each cluster only once.
*
* If `registerBlocks` is true, workers will notify the driver about the blocks they create,
* and these blocks will be dropped when the broadcast variable's `unpersist` method is called.
*/
def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
def broadcast[T](value: T, registerBlocks: Boolean = false) = {
env.broadcastManager.newBroadcast[T](value, isLocal, registerBlocks)
}

/**
* Add a file to be downloaded with this Spark job on every node.
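For context, here is a minimal usage sketch of the new parameter (assuming a running SparkContext named `sc`; the variable names are illustrative, not part of the commit):

val data = List(1, 2, 3, 4)
// Ask workers to register the blocks they cache so they can be dropped later.
val bc = sc.broadcast(data, registerBlocks = true)
sc.parallelize(1 to 10).map(_ => bc.value.sum).collect()
// Drop the cached copies; pass removeSource = true to also delete the broadcast's
// source data (unsafe if tasks might still need to re-read the value).
bc.unpersist()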
13 changes: 11 additions & 2 deletions in core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -53,6 +53,15 @@ import org.apache.spark._
abstract class Broadcast[T](val id: Long) extends Serializable {
def value: T

/**
* Removes all blocks of this broadcast from memory (and from disk if `removeSource` is true).
*
* @param removeSource Whether to remove the data from disk as well. Doing so will cause
* errors if the broadcast is accessed on workers afterwards
* (e.g. if an RDD is re-computed due to an executor failure).
*/
def unpersist(removeSource: Boolean = false)

// We cannot have an abstract readObject here due to some weird issues with
// readObject having to be 'private' in sub-classes.

@@ -91,8 +100,8 @@ class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging

private val nextBroadcastId = new AtomicLong(0)

def newBroadcast[T](value_ : T, isLocal: Boolean) =
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
def newBroadcast[T](value_ : T, isLocal: Boolean, registerBlocks: Boolean) =
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement(), registerBlocks)

def isDriver = _isDriver
}
core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -27,6 +27,6 @@ import org.apache.spark.SparkConf
*/
trait BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf): Unit
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def newBroadcast[T](value: T, isLocal: Boolean, id: Long, registerBlocks: Boolean): Broadcast[T]
def stop(): Unit
}
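To illustrate what the widened trait now asks of implementors, here is a hypothetical factory that simply forwards to HttpBroadcastFactory (the class name and the forwarding design are invented for illustration, not part of this commit):

import org.apache.spark.SparkConf
import org.apache.spark.broadcast.{Broadcast, BroadcastFactory, HttpBroadcastFactory}

class ForwardingBroadcastFactory extends BroadcastFactory {
  private val delegate = new HttpBroadcastFactory

  def initialize(isDriver: Boolean, conf: SparkConf): Unit =
    delegate.initialize(isDriver, conf)

  // registerBlocks is now part of the factory contract and must be threaded through.
  def newBroadcast[T](value: T, isLocal: Boolean, id: Long, registerBlocks: Boolean): Broadcast[T] =
    delegate.newBroadcast[T](value, isLocal, id, registerBlocks)

  def stop(): Unit = delegate.stop()
}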
45 changes: 33 additions & 12 deletions in core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -29,11 +29,20 @@ import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}

private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean)
extends Broadcast[T](id) with Logging with Serializable {

def value = value_

def unpersist(removeSource: Boolean) {
SparkEnv.get.blockManager.master.removeBlock(blockId)
SparkEnv.get.blockManager.removeBlock(blockId)

if (removeSource) {
HttpBroadcast.cleanupById(id)
}
}

def blockId = BroadcastBlockId(id)

HttpBroadcast.synchronized {
@@ -54,7 +63,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
logInfo("Started reading broadcast variable " + id)
val start = System.nanoTime
value_ = HttpBroadcast.read[T](id)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks)
val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
@@ -69,8 +78,8 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
class HttpBroadcastFactory extends BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf) { HttpBroadcast.initialize(isDriver, conf) }

def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) =
new HttpBroadcast[T](value_, isLocal, id, registerBlocks)

def stop() { HttpBroadcast.stop() }
}
@@ -132,8 +141,10 @@ private object HttpBroadcast extends Logging {
logInfo("Broadcast server started at " + serverUri)
}

def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name)

def write(id: Long, value: Any) {
val file = new File(broadcastDir, BroadcastBlockId(id).name)
val file = getFile(id)
val out: OutputStream = {
if (compress) {
compressionCodec.compressedOutputStream(new FileOutputStream(file))
@@ -167,20 +178,30 @@ private object HttpBroadcast extends Logging {
obj
}

def deleteFile(fileName: String) {
try {
new File(fileName).delete()
logInfo("Deleted broadcast file '" + fileName + "'")
} catch {
case e: Exception => logWarning("Could not delete broadcast file '" + fileName + "'", e)
}
}

def cleanup(cleanupTime: Long) {
val iterator = files.internalMap.entrySet().iterator()
while(iterator.hasNext) {
val entry = iterator.next()
val (file, time) = (entry.getKey, entry.getValue)
if (time < cleanupTime) {
try {
iterator.remove()
new File(file.toString).delete()
logInfo("Deleted broadcast file '" + file + "'")
} catch {
case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e)
}
iterator.remove()
deleteFile(file)
}
}
}

def cleanupById(id: Long) {
val file = getFile(id).getAbsolutePath
files.internalMap.remove(file)
deleteFile(file)
}
}
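As a sketch of how the new helpers fit together (the rendered file name assumes BroadcastBlockId(7).name follows the usual "broadcast_7" pattern, which this commit does not itself define):

// getFile(7)       -> new File(broadcastDir, "broadcast_7"), the same path write(7, value) produced
// deleteFile(path) -> best-effort delete with a log message, shared by cleanup() and cleanupById()
// cleanupById(7)   -> drops that path from the timestamped `files` set and deletes it immediately,
//                     so unpersist(removeSource = true) does not have to wait for the periodic
//                     cleanup(cleanupTime) pass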
core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -23,16 +23,36 @@ import scala.math
import scala.util.Random

import org.apache.spark._
import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
import org.apache.spark.storage.{BlockId, BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
import org.apache.spark.util.Utils


private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean)
extends Broadcast[T](id) with Logging with Serializable {

def value = value_

def unpersist(removeSource: Boolean) {
SparkEnv.get.blockManager.master.removeBlock(broadcastId)
SparkEnv.get.blockManager.removeBlock(broadcastId)

if (removeSource) {
for (pid <- pieceIds) {
SparkEnv.get.blockManager.removeBlock(pieceBlockId(pid))
}
SparkEnv.get.blockManager.removeBlock(metaId)
} else {
for (pid <- pieceIds) {
SparkEnv.get.blockManager.dropFromMemory(pieceBlockId(pid))
}
SparkEnv.get.blockManager.dropFromMemory(metaId)
}
}

def broadcastId = BroadcastBlockId(id)
private def metaId = BroadcastHelperBlockId(broadcastId, "meta")
private def pieceBlockId(pid: Int) = BroadcastHelperBlockId(broadcastId, "piece" + pid)
private def pieceIds = Array.iterate(0, totalBlocks)(_ + 1).toList

TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
@@ -55,7 +75,6 @@ extends Broadcast[T](id) with Logging with Serializable {
hasBlocks = tInfo.totalBlocks

// Store meta-info
val metaId = BroadcastHelperBlockId(broadcastId, "meta")
val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
@@ -64,7 +83,7 @@ extends Broadcast[T](id) with Logging with Serializable {

// Store individual pieces
for (i <- 0 until totalBlocks) {
val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
val pieceId = pieceBlockId(i)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)
@@ -94,7 +113,7 @@ extends Broadcast[T](id) with Logging with Serializable {
// This creates a tradeoff between memory usage and latency.
// Storing copy doubles the memory footprint; not storing doubles deserialization cost.
SparkEnv.get.blockManager.putSingle(
broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
broadcastId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks)

// Remove arrayOfBlocks from memory once value_ is on local cache
resetWorkerVariables()
@@ -109,6 +128,11 @@ extends Broadcast[T](id) with Logging with Serializable {
}

private def resetWorkerVariables() {
if (arrayOfBlocks != null) {
for (pid <- pieceIds) {
SparkEnv.get.blockManager.removeBlock(pieceBlockId(pid))
}
}
arrayOfBlocks = null
totalBytes = -1
totalBlocks = -1
@@ -117,7 +141,6 @@ extends Broadcast[T](id) with Logging with Serializable {

def receiveBroadcast(variableID: Long): Boolean = {
// Receive meta-info
val metaId = BroadcastHelperBlockId(broadcastId, "meta")
var attemptId = 10
while (attemptId > 0 && totalBlocks == -1) {
TorrentBroadcast.synchronized {
@@ -140,9 +163,9 @@ extends Broadcast[T](id) with Logging with Serializable {
}

// Receive actual blocks
val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
val recvOrder = new Random().shuffle(pieceIds)
for (pid <- recvOrder) {
val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
val pieceId = pieceBlockId(pid)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(pieceId) match {
case Some(x) =>
@@ -243,8 +266,8 @@ class TorrentBroadcastFactory extends BroadcastFactory {

def initialize(isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.initialize(isDriver, conf) }

def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TorrentBroadcast[T](value_, isLocal, id)
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) =
new TorrentBroadcast[T](value_, isLocal, id, registerBlocks)

def stop() { TorrentBroadcast.stop() }
}
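To make the block bookkeeping concrete, here is a sketch of the IDs involved for a broadcast with id = 7 split into totalBlocks = 3 pieces (the rendered names assume the usual BlockId naming and are illustrative):

// broadcastId     = BroadcastBlockId(7)                              // e.g. "broadcast_7", the reassembled value
// metaId          = BroadcastHelperBlockId(broadcastId, "meta")      // e.g. "broadcast_7_meta"
// pieceBlockId(i) = BroadcastHelperBlockId(broadcastId, "piece" + i) // e.g. "broadcast_7_piece0" .. "piece2"
// pieceIds        = List(0, 1, 2)                                    // same as (0 until totalBlocks).toList
//
// unpersist(removeSource = false): removes the value block (locally and via the block manager
//   master) and drops the meta and piece blocks from memory only, leaving any disk copies.
// unpersist(removeSource = true):  removes the value block and deletes the meta and piece
//   blocks outright.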
12 changes: 12 additions & 0 deletions in core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -196,6 +196,11 @@ private[spark] class BlockManager(
}
}

/**
* For testing. Returns the number of blocks this BlockManager knows about that are in memory.
*/
def numberOfBlocksInMemory() = blockInfo.keys.count(memoryStore.contains(_))

/**
* Get storage level of local block. If no info exists for the block, then returns null.
*/
@@ -720,6 +725,13 @@
}

/**
* Drop a block from memory, possibly putting it on disk if applicable.
*/
def dropFromMemory(blockId: BlockId) {
memoryStore.asInstanceOf[MemoryStore].dropFromMemory(blockId)
}

/**
* Remove all blocks belonging to the given RDD.
* @return The number of blocks removed.
*/
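A short sketch of the distinction the new one-argument helper introduces (assuming a BlockManager instance `bm` and a block cached with StorageLevel.MEMORY_AND_DISK):

// bm.removeBlock(id)     // deletes the block from the memory and disk stores outright
// bm.dropFromMemory(id)  // evicts only the in-memory copy; if the block's storage level
//                        // allows disk, it is written out before being dropped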
31 changes: 19 additions & 12 deletions in core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -182,6 +182,24 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}

/**
* Drop a block from memory, possibly putting it on disk if applicable.
*/
def dropFromMemory(blockId: BlockId) {
val entry = entries.synchronized { entries.get(blockId) }
// This should never be null as only one thread should be dropping
// blocks and removing entries. However the check is still here for
// future safety.
if (entry != null) {
val data = if (entry.deserialized) {
Left(entry.value.asInstanceOf[ArrayBuffer[Any]])
} else {
Right(entry.value.asInstanceOf[ByteBuffer].duplicate())
}
blockManager.dropFromMemory(blockId, data)
}
}

/**
* Tries to free up a given amount of space to store a particular block, but can fail and return
* false if either the block is bigger than our memory or it would require replacing another
@@ -227,18 +245,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
if (maxMemory - (currentMemory - selectedMemory) >= space) {
logInfo(selectedBlocks.size + " blocks selected for dropping")
for (blockId <- selectedBlocks) {
val entry = entries.synchronized { entries.get(blockId) }
// This should never be null as only one thread should be dropping
// blocks and removing entries. However the check is still here for
// future safety.
if (entry != null) {
val data = if (entry.deserialized) {
Left(entry.value.asInstanceOf[ArrayBuffer[Any]])
} else {
Right(entry.value.asInstanceOf[ByteBuffer].duplicate())
}
blockManager.dropFromMemory(blockId, data)
}
dropFromMemory(blockId)
}
return true
} else {
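As a note on the extracted helper, the payload it hands to blockManager.dropFromMemory depends on how the entry was cached (a sketch of the two branches above):

// deserialized entry -> Left(values)  where values is the ArrayBuffer[Any] of live objects
// serialized entry   -> Right(bytes)  where bytes is entry.value.duplicate(), a ByteBuffer view
//                       sharing the content but with its own position and limit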
48 changes: 48 additions & 0 deletions in core/src/test/scala/org/apache/spark/BroadcastSuite.scala
@@ -18,6 +18,11 @@
package org.apache.spark

import org.scalatest.FunSuite
import org.scalatest.concurrent.Timeouts._
import org.scalatest.time.{Millis, Span}
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
import org.scalatest.matchers.ShouldMatchers._

class BroadcastSuite extends FunSuite with LocalSparkContext {

@@ -82,4 +87,47 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
}

def blocksExist(sc: SparkContext, numSlaves: Int) = {
val rdd = sc.parallelize(1 to numSlaves, numSlaves)
val workerBlocks = rdd.mapPartitions(_ => {
val blocks = SparkEnv.get.blockManager.numberOfBlocksInMemory()
Seq(blocks).iterator
})
val totalKnown = workerBlocks.reduce(_ + _) + sc.env.blockManager.numberOfBlocksInMemory()

totalKnown > 0
}

def testUnpersist(bcFactory: String, removeSource: Boolean) {
test("Broadcast unpersist(" + removeSource + ") with " + bcFactory) {
val numSlaves = 2
System.setProperty("spark.broadcast.factory", bcFactory)
sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
val list = List(1, 2, 3, 4)

assert(!blocksExist(sc, numSlaves))

val listBroadcast = sc.broadcast(list, true)
val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)

assert(blocksExist(sc, numSlaves))

listBroadcast.unpersist(removeSource)

eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
blocksExist(sc, numSlaves) should be (false)
}

if (!removeSource) {
val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
}
}
}

for (removeSource <- Seq(true, false)) {
testUnpersist("org.apache.spark.broadcast.HttpBroadcastFactory", removeSource)
testUnpersist("org.apache.spark.broadcast.TorrentBroadcastFactory", removeSource)
}
}
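For clarity, the loop above registers four ScalaTest cases, one per (removeSource, factory) pair:

  Broadcast unpersist(true) with org.apache.spark.broadcast.HttpBroadcastFactory
  Broadcast unpersist(true) with org.apache.spark.broadcast.TorrentBroadcastFactory
  Broadcast unpersist(false) with org.apache.spark.broadcast.HttpBroadcastFactory
  Broadcast unpersist(false) with org.apache.spark.broadcast.TorrentBroadcastFactory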
