Generalize BroadcastBlockId to remove BroadcastHelperBlockId
Rather than having a special-purpose BroadcastHelperBlockId just for TorrentBroadcast,
we now have a single BroadcastBlockId with a possibly empty field. This simplifies
broadcast clean-up, because we now only have to look for one type of block.

This commit also simplifies BlockId JSON de/serialization in general: a BlockId is
written out as its name, and read back by parsing that name with the regex-based
BlockId.apply.
andrewor14 committed Mar 28, 2014
1 parent 0d17060 commit 34f436f
Showing 6 changed files with 54 additions and 140 deletions.
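
To make the shape of the change concrete, here is a minimal sketch (not part of the
commit, and ignoring the private[spark] visibility of these classes) of how the
generalized id behaves, given the new definitions in BlockId.scala below:

    import org.apache.spark.storage.{BlockId, BroadcastBlockId}

    val dataId  = BroadcastBlockId(5L)            // plain broadcast block; field stays ""
    val pieceId = BroadcastBlockId(5L, "piece0")  // companion apply sets field = "piece0"

    assert(dataId.name == "broadcast_5")
    assert(pieceId.name == "broadcast_5_piece0")

    // One type now covers every broadcast block, so a single check suffices:
    assert(Seq[BlockId](dataId, pieceId).forall(_.isBroadcast))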
core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -23,7 +23,7 @@ import scala.math
import scala.util.Random

import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
-import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
+import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.Utils

private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
@@ -54,7 +54,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
hasBlocks = tInfo.totalBlocks

// Store meta-info
-val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+val metaId = BroadcastBlockId(id, "meta")
val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
@@ -63,7 +63,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo

// Store individual pieces
for (i <- 0 until totalBlocks) {
-val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
+val pieceId = BroadcastBlockId(id, "piece" + i)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
@@ -131,7 +131,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo

def receiveBroadcast(): Boolean = {
// Receive meta-info
-val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+val metaId = BroadcastBlockId(id, "meta")
var attemptId = 10
while (attemptId > 0 && totalBlocks == -1) {
TorrentBroadcast.synchronized {
@@ -156,7 +156,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
// Receive actual blocks
val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
for (pid <- recvOrder) {
-val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
+val pieceId = BroadcastBlockId(id, "piece" + pid)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(pieceId) match {
case Some(x) =>
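
The net effect on TorrentBroadcast: the value, its meta-data, and its pieces now all
carry the same id type, distinguished only by the field suffix. A sketch of the ids a
broadcast with id = 0 and two pieces would register (illustrative values, not from the
commit):

    import org.apache.spark.storage.BroadcastBlockId

    val ids = Seq(
      BroadcastBlockId(0L),            // the broadcast value itself
      BroadcastBlockId(0L, "meta"),    // TorrentInfo meta-data
      BroadcastBlockId(0L, "piece0"),  // serialized piece 0
      BroadcastBlockId(0L, "piece1"))  // serialized piece 1

    // ids.map(_.name) == Seq("broadcast_0", "broadcast_0_meta",
    //                        "broadcast_0_piece0", "broadcast_0_piece1")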
29 changes: 17 additions & 12 deletions core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -34,7 +34,7 @@ private[spark] sealed abstract class BlockId {
def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None
def isRDD = isInstanceOf[RDDBlockId]
def isShuffle = isInstanceOf[ShuffleBlockId]
-def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId]
+def isBroadcast = isInstanceOf[BroadcastBlockId]

override def toString = name
override def hashCode = name.hashCode
@@ -48,18 +48,15 @@ private[spark] case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockI
def name = "rdd_" + rddId + "_" + splitIndex
}

-private[spark]
-case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
+private[spark] case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)
+extends BlockId {
def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
}

+// Leave field as an instance variable to avoid matching on it
private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
-def name = "broadcast_" + broadcastId
-}
-
-private[spark]
-case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId {
-def name = broadcastId.name + "_" + hType
+var field = ""
+def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field)
}

private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId {
@@ -80,11 +77,19 @@ private[spark] case class TestBlockId(id: String) extends BlockId {
def name = "test_" + id
}

+private[spark] object BroadcastBlockId {
+def apply(broadcastId: Long, field: String) = {
+val blockId = new BroadcastBlockId(broadcastId)
+blockId.field = field
+blockId
+}
+}

private[spark] object BlockId {
val RDD = "rdd_([0-9]+)_([0-9]+)".r
val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
val BROADCAST = "broadcast_([0-9]+)".r
-val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r
+val BROADCAST_FIELD = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r
val TASKRESULT = "taskresult_([0-9]+)".r
val STREAM = "input-([0-9]+)-([0-9]+)".r
val TEST = "test_(.*)".r
@@ -97,8 +102,8 @@ private[spark] object BlockId {
ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
case BROADCAST(broadcastId) =>
BroadcastBlockId(broadcastId.toLong)
-case BROADCAST_HELPER(broadcastId, hType) =>
-BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType)
+case BROADCAST_FIELD(broadcastId, field) =>
+BroadcastBlockId(broadcastId.toLong, field)
case TASKRESULT(taskId) =>
TaskResultBlockId(taskId.toLong)
case STREAM(streamId, uniqueId) =>
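
Parsing relies on the fact that a regex pattern in a Scala match must cover the entire
string: "broadcast_5" matches BROADCAST, while "broadcast_5_piece0" fails it and falls
through to BROADCAST_FIELD. A minimal sketch of the round trip (again ignoring the
private[spark] visibility):

    import org.apache.spark.storage.{BlockId, BroadcastBlockId}

    BlockId("broadcast_5") match {
      case b: BroadcastBlockId => assert(b.broadcastId == 5L && b.field == "")
      case other => sys.error(s"unexpected: $other")
    }
    BlockId("broadcast_5_piece0") match {
      case b: BroadcastBlockId => assert(b.field == "piece0")
      case other => sys.error(s"unexpected: $other")
    }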
core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -827,9 +827,8 @@ private[spark] class BlockManager(
*/
def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) {
logInfo("Removing broadcast " + broadcastId)
-val blocksToRemove = blockInfo.keys.filter(_.isBroadcast).collect {
+val blocksToRemove = blockInfo.keys.collect {
case bid: BroadcastBlockId if bid.broadcastId == broadcastId => bid
-case bid: BroadcastHelperBlockId if bid.broadcastId.broadcastId == broadcastId => bid
}
blocksToRemove.foreach { blockId => removeBlock(blockId, removeFromDriver) }
}
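
This is the clean-up simplification promised in the commit message: with helper blocks
folded into BroadcastBlockId, one guarded case matches data, meta, and piece blocks
alike. A sketch with hypothetical block ids:

    import org.apache.spark.storage.{BlockId, BroadcastBlockId, RDDBlockId}

    // Hypothetical contents of blockInfo.keys on some executor:
    val held: Seq[BlockId] = Seq(
      RDDBlockId(1, 0),                // unrelated block; must survive
      BroadcastBlockId(7L),            // broadcast value
      BroadcastBlockId(7L, "meta"),    // torrent meta-data
      BroadcastBlockId(7L, "piece0"))  // torrent piece

    val toRemove = held.collect { case bid: BroadcastBlockId if bid.broadcastId == 7L => bid }
    assert(toRemove.size == 3)  // everything belonging to broadcast 7, nothing else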
77 changes: 2 additions & 75 deletions core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -195,7 +195,7 @@ private[spark] object JsonProtocol {
taskMetrics.shuffleWriteMetrics.map(shuffleWriteMetricsToJson).getOrElse(JNothing)
val updatedBlocks = taskMetrics.updatedBlocks.map { blocks =>
JArray(blocks.toList.map { case (id, status) =>
("Block ID" -> blockIdToJson(id)) ~
("Block ID" -> id.toString) ~
("Status" -> blockStatusToJson(status))
})
}.getOrElse(JNothing)
@@ -284,35 +284,6 @@ private[spark] object JsonProtocol {
("Replication" -> storageLevel.replication)
}

-def blockIdToJson(blockId: BlockId): JValue = {
-val blockType = Utils.getFormattedClassName(blockId)
-val json: JObject = blockId match {
-case rddBlockId: RDDBlockId =>
-("RDD ID" -> rddBlockId.rddId) ~
-("Split Index" -> rddBlockId.splitIndex)
-case shuffleBlockId: ShuffleBlockId =>
-("Shuffle ID" -> shuffleBlockId.shuffleId) ~
-("Map ID" -> shuffleBlockId.mapId) ~
-("Reduce ID" -> shuffleBlockId.reduceId)
-case broadcastBlockId: BroadcastBlockId =>
-"Broadcast ID" -> broadcastBlockId.broadcastId
-case broadcastHelperBlockId: BroadcastHelperBlockId =>
-("Broadcast Block ID" -> blockIdToJson(broadcastHelperBlockId.broadcastId)) ~
-("Helper Type" -> broadcastHelperBlockId.hType)
-case taskResultBlockId: TaskResultBlockId =>
-"Task ID" -> taskResultBlockId.taskId
-case streamBlockId: StreamBlockId =>
-("Stream ID" -> streamBlockId.streamId) ~
-("Unique ID" -> streamBlockId.uniqueId)
-case tempBlockId: TempBlockId =>
-val uuid = UUIDToJson(tempBlockId.id)
-"Temp ID" -> uuid
-case testBlockId: TestBlockId =>
-"Test ID" -> testBlockId.id
-}
-("Type" -> blockType) ~ json
-}

def blockStatusToJson(blockStatus: BlockStatus): JValue = {
val storageLevel = storageLevelToJson(blockStatus.storageLevel)
("Storage Level" -> storageLevel) ~
@@ -513,7 +484,7 @@ private[spark] object JsonProtocol {
Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson)
metrics.updatedBlocks = Utils.jsonOption(json \ "Updated Blocks").map { value =>
value.extract[List[JValue]].map { block =>
-val id = blockIdFromJson(block \ "Block ID")
+val id = BlockId((block \ "Block ID").extract[String])
val status = blockStatusFromJson(block \ "Status")
(id, status)
}
@@ -616,50 +587,6 @@ private[spark] object JsonProtocol {
StorageLevel(useDisk, useMemory, deserialized, replication)
}

-def blockIdFromJson(json: JValue): BlockId = {
-val rddBlockId = Utils.getFormattedClassName(RDDBlockId)
-val shuffleBlockId = Utils.getFormattedClassName(ShuffleBlockId)
-val broadcastBlockId = Utils.getFormattedClassName(BroadcastBlockId)
-val broadcastHelperBlockId = Utils.getFormattedClassName(BroadcastHelperBlockId)
-val taskResultBlockId = Utils.getFormattedClassName(TaskResultBlockId)
-val streamBlockId = Utils.getFormattedClassName(StreamBlockId)
-val tempBlockId = Utils.getFormattedClassName(TempBlockId)
-val testBlockId = Utils.getFormattedClassName(TestBlockId)
-
-(json \ "Type").extract[String] match {
-case `rddBlockId` =>
-val rddId = (json \ "RDD ID").extract[Int]
-val splitIndex = (json \ "Split Index").extract[Int]
-new RDDBlockId(rddId, splitIndex)
-case `shuffleBlockId` =>
-val shuffleId = (json \ "Shuffle ID").extract[Int]
-val mapId = (json \ "Map ID").extract[Int]
-val reduceId = (json \ "Reduce ID").extract[Int]
-new ShuffleBlockId(shuffleId, mapId, reduceId)
-case `broadcastBlockId` =>
-val broadcastId = (json \ "Broadcast ID").extract[Long]
-new BroadcastBlockId(broadcastId)
-case `broadcastHelperBlockId` =>
-val broadcastBlockId =
-blockIdFromJson(json \ "Broadcast Block ID").asInstanceOf[BroadcastBlockId]
-val hType = (json \ "Helper Type").extract[String]
-new BroadcastHelperBlockId(broadcastBlockId, hType)
-case `taskResultBlockId` =>
-val taskId = (json \ "Task ID").extract[Long]
-new TaskResultBlockId(taskId)
-case `streamBlockId` =>
-val streamId = (json \ "Stream ID").extract[Int]
-val uniqueId = (json \ "Unique ID").extract[Long]
-new StreamBlockId(streamId, uniqueId)
-case `tempBlockId` =>
-val tempId = UUIDFromJson(json \ "Temp ID")
-new TempBlockId(tempId)
-case `testBlockId` =>
-val testId = (json \ "Test ID").extract[String]
-new TestBlockId(testId)
-}
-}

def blockStatusFromJson(json: JValue): BlockStatus = {
val storageLevel = storageLevelFromJson(json \ "Storage Level")
val memorySize = (json \ "Memory Size").extract[Long]
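
With blockIdToJson and blockIdFromJson gone, a BlockId crosses the JSON boundary as its
name string and is rebuilt by BlockId.apply. A sketch of the round trip under that
scheme:

    import org.apache.spark.storage.{BlockId, BroadcastBlockId}

    val id: BlockId = BroadcastBlockId(5L, "meta")

    // Write side: ("Block ID" -> id.toString) stores the string "broadcast_5_meta".
    val serialized = id.toString

    // Read side: BlockId.apply re-parses the name into the right subclass.
    val parsed = BlockId(serialized)
    assert(parsed == id)  // both sides are BroadcastBlockId(5L) with field "meta"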
61 changes: 29 additions & 32 deletions core/src/test/scala/org/apache/spark/BroadcastSuite.scala
@@ -21,7 +21,7 @@ import org.scalatest.FunSuite

import org.apache.spark.storage._
import org.apache.spark.broadcast.HttpBroadcast
-import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId}
+import org.apache.spark.storage.BroadcastBlockId

class BroadcastSuite extends FunSuite with LocalSparkContext {

Expand Down Expand Up @@ -102,23 +102,22 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
* are present only on the expected nodes.
*/
private def testUnpersistHttpBroadcast(numSlaves: Int, removeFromDriver: Boolean) {
-def getBlockIds(id: Long) = Seq[BlockId](BroadcastBlockId(id))
+def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id))

// Verify that the broadcast file is created, and blocks are persisted only on the driver
-def afterCreation(blockIds: Seq[BlockId], bmm: BlockManagerMaster) {
+def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
assert(blockIds.size === 1)
-val broadcastBlockId = blockIds.head.asInstanceOf[BroadcastBlockId]
-val levels = bmm.askForStorageLevels(broadcastBlockId, waitTimeMs = 0)
+val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0)
assert(levels.size === 1)
levels.head match { case (bm, level) =>
assert(bm.executorId === "<driver>")
assert(level === StorageLevel.MEMORY_AND_DISK)
}
-assert(HttpBroadcast.getFile(broadcastBlockId.broadcastId).exists)
+assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists)
}

// Verify that blocks are persisted in both the executors and the driver
-def afterUsingBroadcast(blockIds: Seq[BlockId], bmm: BlockManagerMaster) {
+def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
assert(blockIds.size === 1)
val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0)
assert(levels.size === numSlaves + 1)
@@ -129,12 +128,11 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {

// Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
// is true. In the latter case, also verify that the broadcast file is deleted on the driver.
-def afterUnpersist(blockIds: Seq[BlockId], bmm: BlockManagerMaster) {
+def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
assert(blockIds.size === 1)
-val broadcastBlockId = blockIds.head.asInstanceOf[BroadcastBlockId]
-val levels = bmm.askForStorageLevels(broadcastBlockId, waitTimeMs = 0)
+val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0)
assert(levels.size === (if (removeFromDriver) 0 else 1))
-assert(removeFromDriver === !HttpBroadcast.getFile(broadcastBlockId.broadcastId).exists)
+assert(removeFromDriver === !HttpBroadcast.getFile(blockIds.head.broadcastId).exists)
}

testUnpersistBroadcast(numSlaves, httpConf, getBlockIds, afterCreation,
@@ -151,14 +149,14 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
private def testUnpersistTorrentBroadcast(numSlaves: Int, removeFromDriver: Boolean) {
def getBlockIds(id: Long) = {
val broadcastBlockId = BroadcastBlockId(id)
-val metaBlockId = BroadcastHelperBlockId(broadcastBlockId, "meta")
+val metaBlockId = BroadcastBlockId(id, "meta")
// Assume broadcast value is small enough to fit into 1 piece
-val pieceBlockId = BroadcastHelperBlockId(broadcastBlockId, "piece0")
-Seq[BlockId](broadcastBlockId, metaBlockId, pieceBlockId)
+val pieceBlockId = BroadcastBlockId(id, "piece0")
+Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId)
}

// Verify that blocks are persisted only on the driver
-def afterCreation(blockIds: Seq[BlockId], bmm: BlockManagerMaster) {
+def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
blockIds.foreach { blockId =>
val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0)
assert(levels.size === 1)
@@ -170,27 +168,26 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
}

// Verify that blocks are persisted in both the executors and the driver
-def afterUsingBroadcast(blockIds: Seq[BlockId], bmm: BlockManagerMaster) {
+def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
blockIds.foreach { blockId =>
val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0)
-blockId match {
-case BroadcastHelperBlockId(_, "meta") =>
-// Meta data is only on the driver
-assert(levels.size === 1)
-levels.head match { case (bm, _) => assert(bm.executorId === "<driver>") }
-case _ =>
-// Other blocks are on both the executors and the driver
-assert(levels.size === numSlaves + 1)
-levels.foreach { case (_, level) =>
-assert(level === StorageLevel.MEMORY_AND_DISK)
-}
+if (blockId.field == "meta") {
+// Meta data is only on the driver
+assert(levels.size === 1)
+levels.head match { case (bm, _) => assert(bm.executorId === "<driver>") }
+} else {
+// Other blocks are on both the executors and the driver
+assert(levels.size === numSlaves + 1)
+levels.foreach { case (_, level) =>
+assert(level === StorageLevel.MEMORY_AND_DISK)
+}
}
}
}

// Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
// is true.
-def afterUnpersist(blockIds: Seq[BlockId], bmm: BlockManagerMaster) {
+def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
val expectedNumBlocks = if (removeFromDriver) 0 else 1
var waitTimeMs = 1000L
blockIds.foreach { blockId =>
@@ -217,10 +214,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
private def testUnpersistBroadcast(
numSlaves: Int,
broadcastConf: SparkConf,
-getBlockIds: Long => Seq[BlockId],
-afterCreation: (Seq[BlockId], BlockManagerMaster) => Unit,
-afterUsingBroadcast: (Seq[BlockId], BlockManagerMaster) => Unit,
-afterUnpersist: (Seq[BlockId], BlockManagerMaster) => Unit,
+getBlockIds: Long => Seq[BroadcastBlockId],
+afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
removeFromDriver: Boolean) {

sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf)
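
Branching on the public field replaces the pattern match on the old helper type. The
driver-only expectation for meta blocks could be condensed as below; expectedReplicas
is a hypothetical helper, not part of the suite:

    import org.apache.spark.storage.BroadcastBlockId

    def expectedReplicas(id: BroadcastBlockId, numSlaves: Int): Int =
      if (id.field == "meta") 1   // meta-data lives only on the driver
      else numSlaves + 1          // value and pieces reach every executor

    assert(expectedReplicas(BroadcastBlockId(1L, "meta"), numSlaves = 2) == 1)
    assert(expectedReplicas(BroadcastBlockId(1L, "piece0"), numSlaves = 2) == 3)
    assert(expectedReplicas(BroadcastBlockId(1L), numSlaves = 2) == 3)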
14 changes: 0 additions & 14 deletions core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -104,15 +104,6 @@ class JsonProtocolSuite extends FunSuite {
testTaskEndReason(TaskKilled)
testTaskEndReason(ExecutorLostFailure)
testTaskEndReason(UnknownReason)

-// BlockId
-testBlockId(RDDBlockId(1, 2))
-testBlockId(ShuffleBlockId(1, 2, 3))
-testBlockId(BroadcastBlockId(1L))
-testBlockId(BroadcastHelperBlockId(BroadcastBlockId(2L), "Spark"))
-testBlockId(TaskResultBlockId(1L))
-testBlockId(StreamBlockId(1, 2L))
-testBlockId(TempBlockId(UUID.randomUUID()))
}


@@ -167,11 +158,6 @@ class JsonProtocolSuite extends FunSuite {
assertEquals(reason, newReason)
}

-private def testBlockId(blockId: BlockId) {
-val newBlockId = JsonProtocol.blockIdFromJson(JsonProtocol.blockIdToJson(blockId))
-blockId == newBlockId
-}


/** -------------------------------- *
| Util methods for comparing events |
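
Note the removed helper never asserted anything: the result of blockId == newBlockId
was discarded. Under the new name-based scheme the equivalent check shrinks to a
one-line property per id; a sketch, not a test added by this commit:

    import org.apache.spark.storage._

    Seq[BlockId](
      RDDBlockId(1, 2),
      ShuffleBlockId(1, 2, 3),
      BroadcastBlockId(1L),
      BroadcastBlockId(2L, "piece0"),
      TaskResultBlockId(1L),
      StreamBlockId(1, 2L)
    ).foreach { id => assert(BlockId(id.name) == id) }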
