From 34f436f7d1799a6fd22b745d339734f220108dae Mon Sep 17 00:00:00 2001
From: Andrew Or
Date: Fri, 28 Mar 2014 13:17:20 -0700
Subject: [PATCH] Generalize BroadcastBlockId to remove BroadcastHelperBlockId

Rather than having a special-purpose BroadcastHelperBlockId just for
TorrentBroadcast, we now have a single BroadcastBlockId with a possibly
empty field. This simplifies broadcast clean-up, because we only have to
look for one type of block. This commit also simplifies BlockId JSON
de/serialization in general: a block ID is now serialized as its name
string and parsed back through the existing regex-based BlockId.apply.
---
 .../spark/broadcast/TorrentBroadcast.scala    | 10 +--
 .../org/apache/spark/storage/BlockId.scala    | 29 ++++---
 .../apache/spark/storage/BlockManager.scala   |  3 +-
 .../org/apache/spark/util/JsonProtocol.scala  | 77 +------------------
 .../org/apache/spark/BroadcastSuite.scala     | 61 +++++++--------
 .../apache/spark/util/JsonProtocolSuite.scala | 14 ----
 6 files changed, 54 insertions(+), 140 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index ab280fad4e28f..dbe65d88104fb 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -23,7 +23,7 @@ import scala.math
 import scala.util.Random
 
 import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
-import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
+import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
 import org.apache.spark.util.Utils
 
 private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
@@ -54,7 +54,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
     hasBlocks = tInfo.totalBlocks
 
     // Store meta-info
-    val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+    val metaId = BroadcastBlockId(id, "meta")
     val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
     TorrentBroadcast.synchronized {
       SparkEnv.get.blockManager.putSingle(
@@ -63,7 +63,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
 
     // Store individual pieces
     for (i <- 0 until totalBlocks) {
-      val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
+      val pieceId = BroadcastBlockId(id, "piece" + i)
       TorrentBroadcast.synchronized {
         SparkEnv.get.blockManager.putSingle(
           pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
@@ -131,7 +131,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
 
   def receiveBroadcast(): Boolean = {
     // Receive meta-info
-    val metaId = BroadcastHelperBlockId(broadcastId, "meta")
+    val metaId = BroadcastBlockId(id, "meta")
     var attemptId = 10
     while (attemptId > 0 && totalBlocks == -1) {
       TorrentBroadcast.synchronized {
@@ -156,7 +156,7 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
     // Receive actual blocks
     val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
     for (pid <- recvOrder) {
-      val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
+      val pieceId = BroadcastBlockId(id, "piece" + pid)
      TorrentBroadcast.synchronized {
         SparkEnv.get.blockManager.getSingle(pieceId) match {
           case Some(x) =>
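Every block TorrentBroadcast stores now goes through the one BroadcastBlockId
constructor, and the optional field only changes the name suffix. A minimal
sketch of the resulting block names (the id 42 is illustrative; BroadcastBlockId
itself is defined in the next file):

    val dataId  = BroadcastBlockId(42L)            // name: "broadcast_42"
    val metaId  = BroadcastBlockId(42L, "meta")    // name: "broadcast_42_meta"
    val pieceId = BroadcastBlockId(42L, "piece0")  // name: "broadcast_42_piece0"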
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 301d784b350a3..27e271368ed06 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -34,7 +34,7 @@ private[spark] sealed abstract class BlockId {
   def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None
   def isRDD = isInstanceOf[RDDBlockId]
   def isShuffle = isInstanceOf[ShuffleBlockId]
-  def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId]
+  def isBroadcast = isInstanceOf[BroadcastBlockId]
 
   override def toString = name
   override def hashCode = name.hashCode
@@ -48,18 +48,15 @@ private[spark] case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockI
   def name = "rdd_" + rddId + "_" + splitIndex
 }
 
-private[spark]
-case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
+private[spark] case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)
+  extends BlockId {
   def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
 }
 
+// Leave field as an instance variable to avoid matching on it
 private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
-  def name = "broadcast_" + broadcastId
-}
-
-private[spark]
-case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId {
-  def name = broadcastId.name + "_" + hType
+  var field = ""
+  def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field)
 }
 
 private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId {
@@ -80,11 +77,19 @@ private[spark] case class TestBlockId(id: String) extends BlockId {
   def name = "test_" + id
 }
 
+private[spark] object BroadcastBlockId {
+  def apply(broadcastId: Long, field: String) = {
+    val blockId = new BroadcastBlockId(broadcastId)
+    blockId.field = field
+    blockId
+  }
+}
+
 private[spark] object BlockId {
   val RDD = "rdd_([0-9]+)_([0-9]+)".r
   val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
   val BROADCAST = "broadcast_([0-9]+)".r
-  val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r
+  val BROADCAST_FIELD = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r
   val TASKRESULT = "taskresult_([0-9]+)".r
   val STREAM = "input-([0-9]+)-([0-9]+)".r
   val TEST = "test_(.*)".r
@@ -97,8 +102,8 @@ private[spark] object BlockId {
       ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
     case BROADCAST(broadcastId) =>
       BroadcastBlockId(broadcastId.toLong)
-    case BROADCAST_HELPER(broadcastId, hType) =>
-      BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType)
+    case BROADCAST_FIELD(broadcastId, field) =>
+      BroadcastBlockId(broadcastId.toLong, field)
     case TASKRESULT(taskId) =>
       TaskResultBlockId(taskId.toLong)
     case STREAM(streamId, uniqueId) =>
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 24ec8d3ab44bf..a88eb1315a37b 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -827,9 +827,8 @@ private[spark] class BlockManager(
    */
  def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) {
     logInfo("Removing broadcast " + broadcastId)
-    val blocksToRemove = blockInfo.keys.filter(_.isBroadcast).collect {
+    val blocksToRemove = blockInfo.keys.collect {
       case bid: BroadcastBlockId if bid.broadcastId == broadcastId => bid
-      case bid: BroadcastHelperBlockId if bid.broadcastId.broadcastId == broadcastId => bid
     }
     blocksToRemove.foreach { blockId => removeBlock(blockId, removeFromDriver) }
   }
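With the helper type gone, clean-up needs no special cases: every block stored
for a broadcast is a BroadcastBlockId, and each name parses back through the
regex-based BlockId.apply above. A short sketch of both properties (the values
are illustrative, not from the test suite):

    // One type now covers the data block, the meta block, and every piece,
    // so the single collect in removeBroadcast catches all of them:
    val ids = Seq(BroadcastBlockId(42L),
                  BroadcastBlockId(42L, "meta"),
                  BroadcastBlockId(42L, "piece0"))
    assert(ids.forall(id => id.isBroadcast && id.broadcastId == 42L))

    // Round-trip a name through the parser. Regex patterns match the whole
    // string, so "broadcast_42_piece0" misses BROADCAST and hits BROADCAST_FIELD:
    BlockId("broadcast_42_piece0") match {
      case b: BroadcastBlockId => assert(b.broadcastId == 42L && b.field == "piece0")
      case other => throw new IllegalStateException("Unexpected block id: " + other)
    }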
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index 346f2b7856791..d9a6af61872d1 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -195,7 +195,7 @@ private[spark] object JsonProtocol {
       taskMetrics.shuffleWriteMetrics.map(shuffleWriteMetricsToJson).getOrElse(JNothing)
     val updatedBlocks = taskMetrics.updatedBlocks.map { blocks =>
       JArray(blocks.toList.map { case (id, status) =>
-        ("Block ID" -> blockIdToJson(id)) ~
+        ("Block ID" -> id.toString) ~
         ("Status" -> blockStatusToJson(status))
       })
     }.getOrElse(JNothing)
@@ -284,35 +284,6 @@ private[spark] object JsonProtocol {
     ("Replication" -> storageLevel.replication)
   }
 
-  def blockIdToJson(blockId: BlockId): JValue = {
-    val blockType = Utils.getFormattedClassName(blockId)
-    val json: JObject = blockId match {
-      case rddBlockId: RDDBlockId =>
-        ("RDD ID" -> rddBlockId.rddId) ~
-        ("Split Index" -> rddBlockId.splitIndex)
-      case shuffleBlockId: ShuffleBlockId =>
-        ("Shuffle ID" -> shuffleBlockId.shuffleId) ~
-        ("Map ID" -> shuffleBlockId.mapId) ~
-        ("Reduce ID" -> shuffleBlockId.reduceId)
-      case broadcastBlockId: BroadcastBlockId =>
-        "Broadcast ID" -> broadcastBlockId.broadcastId
-      case broadcastHelperBlockId: BroadcastHelperBlockId =>
-        ("Broadcast Block ID" -> blockIdToJson(broadcastHelperBlockId.broadcastId)) ~
-        ("Helper Type" -> broadcastHelperBlockId.hType)
-      case taskResultBlockId: TaskResultBlockId =>
-        "Task ID" -> taskResultBlockId.taskId
-      case streamBlockId: StreamBlockId =>
-        ("Stream ID" -> streamBlockId.streamId) ~
-        ("Unique ID" -> streamBlockId.uniqueId)
-      case tempBlockId: TempBlockId =>
-        val uuid = UUIDToJson(tempBlockId.id)
-        "Temp ID" -> uuid
-      case testBlockId: TestBlockId =>
-        "Test ID" -> testBlockId.id
-    }
-    ("Type" -> blockType) ~ json
-  }
-
   def blockStatusToJson(blockStatus: BlockStatus): JValue = {
     val storageLevel = storageLevelToJson(blockStatus.storageLevel)
     ("Storage Level" -> storageLevel) ~
@@ -513,7 +484,7 @@ private[spark] object JsonProtocol {
       Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson)
     metrics.updatedBlocks = Utils.jsonOption(json \ "Updated Blocks").map { value =>
       value.extract[List[JValue]].map { block =>
-        val id = blockIdFromJson(block \ "Block ID")
+        val id = BlockId((block \ "Block ID").extract[String])
         val status = blockStatusFromJson(block \ "Status")
         (id, status)
       }
@@ -616,50 +587,6 @@ private[spark] object JsonProtocol {
     StorageLevel(useDisk, useMemory, deserialized, replication)
   }
 
-  def blockIdFromJson(json: JValue): BlockId = {
-    val rddBlockId = Utils.getFormattedClassName(RDDBlockId)
-    val shuffleBlockId = Utils.getFormattedClassName(ShuffleBlockId)
-    val broadcastBlockId = Utils.getFormattedClassName(BroadcastBlockId)
-    val broadcastHelperBlockId = Utils.getFormattedClassName(BroadcastHelperBlockId)
-    val taskResultBlockId = Utils.getFormattedClassName(TaskResultBlockId)
-    val streamBlockId = Utils.getFormattedClassName(StreamBlockId)
-    val tempBlockId = Utils.getFormattedClassName(TempBlockId)
-    val testBlockId = Utils.getFormattedClassName(TestBlockId)
-
-    (json \ "Type").extract[String] match {
-      case `rddBlockId` =>
-        val rddId = (json \ "RDD ID").extract[Int]
-        val splitIndex = (json \ "Split Index").extract[Int]
-        new RDDBlockId(rddId, splitIndex)
-      case `shuffleBlockId` =>
-        val shuffleId = (json \ "Shuffle ID").extract[Int]
-        val mapId = (json \ "Map ID").extract[Int]
-        val reduceId = (json \ "Reduce ID").extract[Int]
-        new ShuffleBlockId(shuffleId, mapId, reduceId)
-      case `broadcastBlockId` =>
-        val broadcastId = (json \ "Broadcast ID").extract[Long]
-        new BroadcastBlockId(broadcastId)
-      case `broadcastHelperBlockId` =>
-        val broadcastBlockId =
-          blockIdFromJson(json \ "Broadcast Block ID").asInstanceOf[BroadcastBlockId]
-        val hType = (json \ "Helper Type").extract[String]
-        new BroadcastHelperBlockId(broadcastBlockId, hType)
-      case `taskResultBlockId` =>
-        val taskId = (json \ "Task ID").extract[Long]
-        new TaskResultBlockId(taskId)
-      case `streamBlockId` =>
-        val streamId = (json \ "Stream ID").extract[Int]
-        val uniqueId = (json \ "Unique ID").extract[Long]
-        new StreamBlockId(streamId, uniqueId)
-      case `tempBlockId` =>
-        val tempId = UUIDFromJson(json \ "Temp ID")
-        new TempBlockId(tempId)
-      case `testBlockId` =>
-        val testId = (json \ "Test ID").extract[String]
-        new TestBlockId(testId)
-    }
-  }
-
   def blockStatusFromJson(json: JValue): BlockStatus = {
     val storageLevel = storageLevelFromJson(json \ "Storage Level")
     val memorySize = (json \ "Memory Size").extract[Long]
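The JSON protocol now leans entirely on the name round-trip: a block ID is
written as its name string (toString is the name) and read back through
BlockId.apply. A sketch of the round-trip under json4s, which is what
JsonProtocol uses (the implicit DefaultFormats mirrors the one already in
scope there; BlockId and RDDBlockId come from org.apache.spark.storage):

    import org.json4s._
    import org.json4s.JsonDSL._

    implicit val formats = DefaultFormats
    val json: JValue = ("Block ID" -> RDDBlockId(1, 2).toString)    // {"Block ID":"rdd_1_2"}
    val id: BlockId = BlockId((json \ "Block ID").extract[String])  // RDDBlockId(1, 2)

(Only block types with a parsing regex in BlockId round-trip this way;
TempBlockId, for instance, is not in the regex list above.)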
diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
index a462654197ea0..9e600f1e91aa2 100644
--- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala
@@ -21,7 +21,7 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.storage._
 import org.apache.spark.broadcast.HttpBroadcast
-import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId}
+import org.apache.spark.storage.BroadcastBlockId
 
 class BroadcastSuite extends FunSuite with LocalSparkContext {
 
@@ -102,23 +102,22 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
    * are present only on the expected nodes.
    */
   private def testUnpersistHttpBroadcast(numSlaves: Int, removeFromDriver: Boolean) {
-    def getBlockIds(id: Long) = Seq[BlockId](BroadcastBlockId(id))
+    def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id))
 
     // Verify that the broadcast file is created, and blocks are persisted only on the driver
-    def afterCreation(blockIds: Seq[BlockId], bmm: BlockManagerMaster) {
+    def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
       assert(blockIds.size === 1)
-      val broadcastBlockId = blockIds.head.asInstanceOf[BroadcastBlockId]
-      val levels = bmm.askForStorageLevels(broadcastBlockId, waitTimeMs = 0)
+      val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0)
       assert(levels.size === 1)
       levels.head match { case (bm, level) =>
         assert(bm.executorId === "<driver>")
         assert(level === StorageLevel.MEMORY_AND_DISK)
       }
-      assert(HttpBroadcast.getFile(broadcastBlockId.broadcastId).exists)
+      assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists)
     }
 
     // Verify that blocks are persisted in both the executors and the driver
-    def afterUsingBroadcast(blockIds: Seq[BlockId], bmm: BlockManagerMaster) {
+    def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
       assert(blockIds.size === 1)
       val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0)
       assert(levels.size === numSlaves + 1)
@@ -129,12 +128,11 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
 
     // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
     // is true. In the latter case, also verify that the broadcast file is deleted on the driver.
-    def afterUnpersist(blockIds: Seq[BlockId], bmm: BlockManagerMaster) {
+    def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
       assert(blockIds.size === 1)
-      val broadcastBlockId = blockIds.head.asInstanceOf[BroadcastBlockId]
-      val levels = bmm.askForStorageLevels(broadcastBlockId, waitTimeMs = 0)
+      val levels = bmm.askForStorageLevels(blockIds.head, waitTimeMs = 0)
       assert(levels.size === (if (removeFromDriver) 0 else 1))
-      assert(removeFromDriver === !HttpBroadcast.getFile(broadcastBlockId.broadcastId).exists)
+      assert(removeFromDriver === !HttpBroadcast.getFile(blockIds.head.broadcastId).exists)
     }
 
     testUnpersistBroadcast(numSlaves, httpConf, getBlockIds, afterCreation,
@@ -151,14 +149,14 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
   private def testUnpersistTorrentBroadcast(numSlaves: Int, removeFromDriver: Boolean) {
     def getBlockIds(id: Long) = {
       val broadcastBlockId = BroadcastBlockId(id)
-      val metaBlockId = BroadcastHelperBlockId(broadcastBlockId, "meta")
+      val metaBlockId = BroadcastBlockId(id, "meta")
       // Assume broadcast value is small enough to fit into 1 piece
-      val pieceBlockId = BroadcastHelperBlockId(broadcastBlockId, "piece0")
-      Seq[BlockId](broadcastBlockId, metaBlockId, pieceBlockId)
+      val pieceBlockId = BroadcastBlockId(id, "piece0")
+      Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId)
     }
 
     // Verify that blocks are persisted only on the driver
-    def afterCreation(blockIds: Seq[BlockId], bmm: BlockManagerMaster) {
+    def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
       blockIds.foreach { blockId =>
         val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0)
         assert(levels.size === 1)
@@ -170,27 +168,26 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
     }
 
     // Verify that blocks are persisted in both the executors and the driver
-    def afterUsingBroadcast(blockIds: Seq[BlockId], bmm: BlockManagerMaster) {
+    def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
       blockIds.foreach { blockId =>
         val levels = bmm.askForStorageLevels(blockId, waitTimeMs = 0)
-        blockId match {
-          case BroadcastHelperBlockId(_, "meta") =>
-            // Meta data is only on the driver
-            assert(levels.size === 1)
-            levels.head match { case (bm, _) => assert(bm.executorId === "<driver>") }
-          case _ =>
-            // Other blocks are on both the executors and the driver
-            assert(levels.size === numSlaves + 1)
-            levels.foreach { case (_, level) =>
-              assert(level === StorageLevel.MEMORY_AND_DISK)
-            }
+        if (blockId.field == "meta") {
+          // Meta data is only on the driver
+          assert(levels.size === 1)
+          levels.head match { case (bm, _) => assert(bm.executorId === "<driver>") }
+        } else {
+          // Other blocks are on both the executors and the driver
+          assert(levels.size === numSlaves + 1)
+          levels.foreach { case (_, level) =>
+            assert(level === StorageLevel.MEMORY_AND_DISK)
+          }
         }
       }
     }
 
     // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
     // is true.
-    def afterUnpersist(blockIds: Seq[BlockId], bmm: BlockManagerMaster) {
+    def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
       val expectedNumBlocks = if (removeFromDriver) 0 else 1
       var waitTimeMs = 1000L
       blockIds.foreach { blockId =>
@@ -217,10 +214,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
   private def testUnpersistBroadcast(
       numSlaves: Int,
       broadcastConf: SparkConf,
-      getBlockIds: Long => Seq[BlockId],
-      afterCreation: (Seq[BlockId], BlockManagerMaster) => Unit,
-      afterUsingBroadcast: (Seq[BlockId], BlockManagerMaster) => Unit,
-      afterUnpersist: (Seq[BlockId], BlockManagerMaster) => Unit,
+      getBlockIds: Long => Seq[BroadcastBlockId],
+      afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+      afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+      afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
       removeFromDriver: Boolean) {
 
     sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf)
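Because getBlockIds now returns Seq[BroadcastBlockId], the suite can pick the
meta block out by its field instead of pattern-matching on a helper type. A
minimal standalone sketch of that selection (the id 7 is illustrative):

    val blockIds = Seq(BroadcastBlockId(7L),
                       BroadcastBlockId(7L, "meta"),
                       BroadcastBlockId(7L, "piece0"))
    val (metaBlocks, otherBlocks) = blockIds.partition(_.field == "meta")
    assert(metaBlocks.map(_.name) == Seq("broadcast_7_meta"))
    assert(otherBlocks.map(_.name) == Seq("broadcast_7", "broadcast_7_piece0"))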
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index 67c0a434c9b52..580ac34f5f0b4 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -104,15 +104,6 @@ class JsonProtocolSuite extends FunSuite {
     testTaskEndReason(TaskKilled)
     testTaskEndReason(ExecutorLostFailure)
     testTaskEndReason(UnknownReason)
-
-    // BlockId
-    testBlockId(RDDBlockId(1, 2))
-    testBlockId(ShuffleBlockId(1, 2, 3))
-    testBlockId(BroadcastBlockId(1L))
-    testBlockId(BroadcastHelperBlockId(BroadcastBlockId(2L), "Spark"))
-    testBlockId(TaskResultBlockId(1L))
-    testBlockId(StreamBlockId(1, 2L))
-    testBlockId(TempBlockId(UUID.randomUUID()))
   }
 
@@ -167,11 +158,6 @@ class JsonProtocolSuite extends FunSuite {
     assertEquals(reason, newReason)
   }
 
-  private def testBlockId(blockId: BlockId) {
-    val newBlockId = JsonProtocol.blockIdFromJson(JsonProtocol.blockIdToJson(blockId))
-    blockId == newBlockId
-  }
-
   /** -------------------------------- *
    | Util methods for comparing events |