diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8c0236db69f60..4b55981dcd901 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,7 @@
 ## 2.2.0.3 (upcoming)
 
 * Unify Vault variables
+* Secret Broadcast variables (Experimental)
 
 ## 2.2.0.2 (December 26, 2017)
 
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index ef2b8aca6e1ed..fa1a9aacea32d 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1492,6 +1492,26 @@ class SparkContext(config: SparkConf) extends Logging {
     bc
   }
 
+  /**
+   * Broadcast a reference to a secret kept in Vault, returning a
+   * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
+   * Only the Vault path and the secret key go through the broadcast mechanism; the secret
+   * itself is requested from Vault when the value of the broadcast is read.
+   *
+   * @param secretVaultPath path of the secret, relative to the configured Vault URI
+   * @param idJson key used to pick the secret out of the Vault response for that path
+   * @return `Broadcast` object whose value is the secret read from Vault
+   */
+  def secretBroadcast(secretVaultPath: String,
+                      idJson: String): Broadcast[String] = {
+    assertNotStopped()
+    val bc = env.broadcastManager.newSecretBroadcast(secretVaultPath, idJson, isLocal)
+    val callSite = getCallSite
+    logInfo("Created secret broadcast " + bc.id + " from " + callSite.shortForm)
+    cleaner.foreach(_.registerBroadcastForCleanup(bc))
+    bc
+  }
+
   /**
    * Add a file to be downloaded with this Spark job on every node.
* diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index ece4ae6ab0310..65f708398d477 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -40,6 +40,18 @@ private[spark] trait BroadcastFactory { */ def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T] + /** + * Creates a new broadcast variable. + * + * @param secretRepositoryValue secret repository access variable to broadcast + * @param isLocal whether we are in local mode (single JVM process) + * @param id unique id representing this broadcast variable + */ + def newSecretBroadcast(secretVaultPath: String, + idJson: String, + isLocal: Boolean, + id: Long): Broadcast[String] + def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit def stop(): Unit diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index e88988fe03b2e..37bbc4ab6452f 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -55,6 +55,12 @@ private[spark] class BroadcastManager( def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = { broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) } + def newSecretBroadcast(secretVaultPath: String, + idJson: String, + isLocal: Boolean): Broadcast[String] = { + broadcastFactory.newSecretBroadcast(secretVaultPath, idJson, + isLocal, nextBroadcastId.getAndIncrement()) + } def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { broadcastFactory.unbroadcast(id, removeFromDriver, blocking) diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala 
b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 039df75ce74fd..73edc13059192 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -203,7 +203,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
     out.defaultWriteObject()
   }
 
-  private def readBroadcastBlock(): T = Utils.tryOrIOException {
+  private[spark] def readBroadcastBlock(): T = Utils.tryOrIOException {
     TorrentBroadcast.synchronized {
       setConf(SparkEnv.get.conf)
       val blockManager = SparkEnv.get.blockManager
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
index b11f9ba171b84..bd7740a487874 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
@@ -34,6 +34,21 @@ private[spark] class TorrentBroadcastFactory extends BroadcastFactory {
     new TorrentBroadcast[T](value_, id)
   }
 
+  /**
+   * Creates a new secret broadcast variable.
+   *
+   * @param secretVaultPath path of the secret in Vault
+   * @param idJson key identifying the secret at that path
+   * @param isLocal whether we are in local mode (single JVM process)
+   * @param id unique id representing this broadcast variable
+   */
+  override def newSecretBroadcast(secretVaultPath: String,
+                                  idJson: String,
+                                  isLocal: Boolean,
+                                  id: Long): Broadcast[String] = {
+    new TorrentSecretBroadcast(secretVaultPath, idJson, isLocal, id)
+  }
+
   override def stop() { }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentSecretBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentSecretBroadcast.scala
new file mode 100644
index 0000000000000..d7044f590c277
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentSecretBroadcast.scala
@@ -0,0 +1,324 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.broadcast
+
+import java.io.{InputStream, ObjectOutputStream, SequenceInputStream}
+import java.nio.ByteBuffer
+import java.util.zip.Adler32
+
+import scala.collection.JavaConverters._
+import scala.reflect.ClassTag
+import scala.util.Random
+
+import org.apache.spark._
+import org.apache.spark.internal.Logging
+import org.apache.spark.io.CompressionCodec
+import org.apache.spark.security.VaultHelper
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.storage._
+import org.apache.spark.util.Utils
+import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
+
+
+/**
+ * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]] for secrets.
+ *
+ * The mechanism is as follows:
+ *
+ * The driver divides the serialized object into small chunks and
+ * stores those chunks in the BlockManager of the driver.
+ *
+ * On each executor, the executor first attempts to fetch the object from its BlockManager. If
+ * it does not exist, it then uses remote fetches to fetch the small chunks from the driver and/or
+ * other executors if available. Once it gets the chunks, it puts the chunks in its own
+ * BlockManager, ready for other executors to fetch from.
+ *
+ * This prevents the driver from being the bottleneck in sending out multiple copies of the
+ * broadcast data (one per executor).
+ *
+ * The object that is divided and stored is not the secret itself but the meta-information
+ * needed to ask the secret management system for it, i.e. the pair
+ * `(secretVaultPath, idJson)`. The secret is obtained with `VaultHelper.retrieveSecret` the
+ * first time the broadcast value is read.
+ *
+ * When initialized, TorrentSecretBroadcast objects read SparkEnv.get.conf.
+ *
+ * @param secretVaultPath path of the secret, passed to `VaultHelper.retrieveSecret`
+ * @param idJson key of the secret, passed to `VaultHelper.retrieveSecret`
+ * @param isLocal whether we are in local mode (single JVM process); not read by this class
+ * @param id A unique identifier for the broadcast variable.
+ */
+private[spark] class TorrentSecretBroadcast(secretVaultPath: String,
+                                            idJson: String,
+                                            isLocal: Boolean,
+                                            id: Long) extends Broadcast[String](id) {
+  /**
+   * The secret. [[readBroadcastBlock]] yields the `(secretVaultPath, idJson)` pair, reading it
+   * from the local block manager or from blocks held by the driver and/or other executors,
+   * and `VaultHelper.retrieveSecret` turns that pair into the secret.
+   *
+   * The field is `@transient` and lazy, so the secret is not written out when this object is
+   * serialized; it is looked up on first access.
+   */
+  @transient private lazy val _value: String = {
+    // Local names differ from the constructor parameters on purpose, to avoid shadowing them.
+    val (vaultPath, secretKey) = readBroadcastBlock()
+    VaultHelper.retrieveSecret(vaultPath, secretKey)
+  }
+
+  /** The compression codec to use, or None if compression is disabled */
+  @transient private var compressionCodec: Option[CompressionCodec] = _
+  /** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */
+  @transient private var blockSize: Int = _
+
+  /** Loads the compression, block size and checksum settings from `conf` into this object. */
+  private def setConf(conf: SparkConf): Unit = {
+    compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) {
+      Some(CompressionCodec.createCodec(conf))
+    } else {
+      None
+    }
+    // Note: use getSizeAsKb (not bytes) to maintain compatibility if no units are provided
+    blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024
+    checksumEnabled = conf.getBoolean("spark.broadcast.checksum", true)
+  }
+  setConf(SparkEnv.get.conf)
+
+  private val broadcastId = BroadcastBlockId(id)
+
+  /** Total number of blocks this broadcast variable contains. */
+  private val numBlocks: Int = writeBlocks((secretVaultPath, idJson))
+
+  /**
+   * Whether to generate checksum for blocks or not.
+   *
+   * NOTE(review): this initializer runs after the `setConf` and `writeBlocks` calls above, so
+   * the flag is `false` again once construction finishes, even if checksums were computed.
+   * `readBroadcastBlock` calls `setConf` again before reading blocks. Confirm the declaration
+   * order is intended before moving this field.
+   */
+  private var checksumEnabled: Boolean = false
+  /** The checksum for all the blocks. */
+  private var checksums: Array[Int] = _
+
+  override protected def getValue(): String = {
+    _value
+  }
+
+  /** Adler-32 checksum of the bytes between the position and the limit of `block`. */
+  private def calcChecksum(block: ByteBuffer): Int = {
+    val adler = new Adler32()
+    if (block.hasArray) {
+      adler.update(block.array, block.arrayOffset + block.position, block.limit - block.position)
+    } else {
+      val bytes = new Array[Byte](block.remaining())
+      block.duplicate.get(bytes)
+      adler.update(bytes)
+    }
+    adler.getValue.toInt
+  }
+
+  /**
+   * Divide the object into multiple blocks and put those blocks in the block manager.
+   *
+   * @param value the object to divide
+   * @return number of blocks this broadcast variable is divided into
+   */
+  private def writeBlocks(value: (String, String)): Int = {
+    import StorageLevel._
+    // Store a copy of the broadcast variable in the driver so that tasks run on the driver
+    // do not create a duplicate copy of the broadcast variable's value.
+    val blockManager = SparkEnv.get.blockManager
+    if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
+      throw new SparkException(s"Failed to store $broadcastId in BlockManager")
+    }
+    val blocks =
+      TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
+    if (checksumEnabled) {
+      checksums = new Array[Int](blocks.length)
+    }
+    blocks.zipWithIndex.foreach { case (block, i) =>
+      if (checksumEnabled) {
+        checksums(i) = calcChecksum(block)
+      }
+      val pieceId = BroadcastBlockId(id, "piece" + i)
+      val bytes = new ChunkedByteBuffer(block.duplicate())
+      if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
+        throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
+      }
+    }
+    blocks.length
+  }
+
+  /** Fetch torrent blocks from the driver and/or other executors. */
+  private def readBlocks(): Array[BlockData] = {
+    // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
+    // to the driver, so other executors can pull these chunks from this executor as well.
+    val blocks = new Array[BlockData](numBlocks)
+    val bm = SparkEnv.get.blockManager
+
+    for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
+      val pieceId = BroadcastBlockId(id, "piece" + pid)
+      logDebug(s"Reading piece $pieceId of $broadcastId")
+      // First try getLocalBytes because there is a chance that previous attempts to fetch the
+      // broadcast blocks have already fetched some of the blocks. In that case, some blocks
+      // would be available locally (on this executor).
+      bm.getLocalBytes(pieceId) match {
+        case Some(block) =>
+          blocks(pid) = block
+          releaseLock(pieceId)
+        case None =>
+          bm.getRemoteBytes(pieceId) match {
+            case Some(b) =>
+              if (checksumEnabled) {
+                val sum = calcChecksum(b.chunks(0))
+                if (sum != checksums(pid)) {
+                  throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" +
+                    s" $sum != ${checksums(pid)}")
+                }
+              }
+              // We found the block from remote executors/driver's BlockManager, so put the block
+              // in this executor's BlockManager.
+              if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
+                throw new SparkException(
+                  s"Failed to store $pieceId of $broadcastId in local BlockManager")
+              }
+              blocks(pid) = new ByteBufferBlockData(b, true)
+            case None =>
+              throw new SparkException(s"Failed to get $pieceId of $broadcastId")
+          }
+      }
+    }
+    blocks
+  }
+
+  /**
+   * Remove all persisted state associated with this Torrent broadcast on the executors.
+   */
+  override protected def doUnpersist(blocking: Boolean): Unit = {
+    TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking)
+  }
+
+  /**
+   * Remove all persisted state associated with this Torrent broadcast on the executors
+   * and driver.
+   */
+  override protected def doDestroy(blocking: Boolean): Unit = {
+    TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
+  }
+
+  /** Used by the JVM when serializing this object. */
+  private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
+    assertValid()
+    out.defaultWriteObject()
+  }
+
+  /**
+   * Returns the `(secretVaultPath, idJson)` pair, taking it from the local block manager when
+   * it is there and otherwise rebuilding it from fetched blocks and caching the result locally.
+   */
+  private def readBroadcastBlock(): (String, String) = Utils.tryOrIOException {
+    TorrentBroadcast.synchronized {
+      setConf(SparkEnv.get.conf)
+      val blockManager = SparkEnv.get.blockManager
+      blockManager.getLocalValues(broadcastId) match {
+        case Some(blockResult) =>
+          if (blockResult.data.hasNext) {
+            val x = blockResult.data.next().asInstanceOf[(String, String)]
+            releaseLock(broadcastId)
+            x
+          } else {
+            throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId")
+          }
+        case None =>
+          logInfo("Started reading broadcast variable " + id)
+          val startTimeMs = System.currentTimeMillis()
+          val blocks = readBlocks()
+          logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))
+
+          try {
+            val obj = TorrentBroadcast.unBlockifyObject[(String, String)](
+              blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec)
+            // Store the merged copy in BlockManager so other tasks on this executor don't
+            // need to re-fetch it.
+            val storageLevel = StorageLevel.MEMORY_AND_DISK
+            if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
+              throw new SparkException(s"Failed to store $broadcastId in BlockManager")
+            }
+            obj
+          } finally {
+            blocks.foreach(_.dispose())
+          }
+      }
+    }
+  }
+
+  /**
+   * If running in a task, register the given block's locks for release upon task completion.
+   * Otherwise, if not running in a task then immediately release the lock.
+   */
+  private def releaseLock(blockId: BlockId): Unit = {
+    val blockManager = SparkEnv.get.blockManager
+    Option(TaskContext.get()) match {
+      case Some(taskContext) =>
+        taskContext.addTaskCompletionListener(_ => blockManager.releaseLock(blockId))
+      case None =>
+        // This should only happen on the driver, where broadcast variables may be accessed
+        // outside of running tasks (e.g. when computing rdd.partitions()). In order to allow
+        // broadcast variables to be garbage collected we need to free the reference here
+        // which is slightly unsafe but is technically okay because broadcast variables aren't
+        // stored off-heap.
+        blockManager.releaseLock(blockId)
+    }
+  }
+
+}
+
+
+/**
+ * Companion helpers with the same signatures as those the class above calls on
+ * `TorrentBroadcast`. They forward to `TorrentBroadcast` so that there is a single
+ * implementation of the chunking and clean-up logic.
+ */
+private object TorrentSecretBroadcast extends Logging {
+
+  /** Splits the serialized `obj` into blocks via `TorrentBroadcast.blockifyObject`. */
+  def blockifyObject[T: ClassTag](
+      obj: T,
+      blockSize: Int,
+      serializer: Serializer,
+      compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = {
+    TorrentBroadcast.blockifyObject(obj, blockSize, serializer, compressionCodec)
+  }
+
+  /** Rebuilds an object from its blocks via `TorrentBroadcast.unBlockifyObject`. */
+  def unBlockifyObject[T: ClassTag](
+      blocks: Array[InputStream],
+      serializer: Serializer,
+      compressionCodec: Option[CompressionCodec]): T = {
+    TorrentBroadcast.unBlockifyObject[T](blocks, serializer, compressionCodec)
+  }
+
+  /**
+   * Remove all persisted blocks associated with this torrent broadcast on the executors.
+   * If removeFromDriver is true, also remove these persisted blocks on the driver.
+   */
+  def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = {
+    TorrentBroadcast.unpersist(id, removeFromDriver, blocking)
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/security/VaultHelper.scala b/core/src/main/scala/org/apache/spark/security/VaultHelper.scala
index f16453c201ed7..cb0d7e7afb853 100644
--- a/core/src/main/scala/org/apache/spark/security/VaultHelper.scala
+++ b/core/src/main/scala/org/apache/spark/security/VaultHelper.scala
@@ -18,6 +18,8 @@
 package org.apache.spark.security
 
+import scala.util.Try
+
 import org.apache.spark.internal.Logging
 
 object VaultHelper extends Logging {
 
@@ -127,4 +129,24 @@ object VaultHelper extends Logging {
       "data", Some(Seq(("X-Vault-Token", vaultTempToken.get)))
     )("token").asInstanceOf[String]
   }
+
+  /**
+   * Reads a single secret from Vault.
+   *
+   * @param secretVaultPath path appended to the configured Vault URI to build the request URL
+   * @param idJSonSecret key of the wanted value in the result obtained for that path
+   * @return the value stored under `idJSonSecret`
+   * @throws NoSuchElementException if the Vault URI or the Vault token is not available
+   */
+  def retrieveSecret(secretVaultPath: String, idJSonSecret: String): String = {
+    logDebug(s"Retriving Secret: $secretVaultPath")
+    val baseUri = ConfigSecurity.vaultURI.getOrElse(
+      throw new NoSuchElementException("Vault URI is not configured"))
+    val authToken = ConfigSecurity.vaultToken.getOrElse(
+      throw new NoSuchElementException("Vault token is not configured"))
+    val requestUrl = s"$baseUri/$secretVaultPath"
+
+    HTTPHelper.executeGet(requestUrl,
+      "data", Some(Seq(("X-Vault-Token", authToken))))(idJSonSecret).asInstanceOf[String]
+  }
 }