diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 828b70dfe92e9..af0e8f3a6b83c 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -783,6 +783,17 @@ package object config {
       .booleanConf
       .createWithDefault(false)

+  private[spark] val CACHE_CHECKPOINT_PREFERRED_LOCS_EXPIRE_TIME =
+    ConfigBuilder("spark.rdd.checkpoint.cachePreferredLocsExpireTime")
+      .internal()
+      .doc("Expire time in minutes for caching preferred locations of checkpointed RDD. " +
+        "Caching preferred locations can relieve query loading to DFS and save the query " +
+        "time. The drawback is that the cached locations can possibly be outdated and " +
+        "lose data locality. If this config is not specified, it will not cache.")
+      .timeConf(TimeUnit.MINUTES)
+      .checkValue(_ > 0, "The expire time for caching preferred locations must be positive.")
+      .createOptional
+
   private[spark] val SHUFFLE_ACCURATE_BLOCK_THRESHOLD =
     ConfigBuilder("spark.shuffle.accurateBlockThreshold")
       .doc("Threshold in bytes above which the size of shuffle blocks in " +
diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
index d165610291f1d..ce7c2be56c3cd 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
@@ -20,15 +20,16 @@ package org.apache.spark.rdd
 import java.io.{FileNotFoundException, IOException}
 import java.util.concurrent.TimeUnit

 import scala.reflect.ClassTag
 import scala.util.control.NonFatal

+import com.google.common.cache.{CacheBuilder, CacheLoader}
 import org.apache.hadoop.fs.Path

 import org.apache.spark._
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config.{BUFFER_SIZE, CHECKPOINT_COMPRESS}
+import org.apache.spark.internal.config.{BUFFER_SIZE, CACHE_CHECKPOINT_PREFERRED_LOCS_EXPIRE_TIME, CHECKPOINT_COMPRESS}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.util.{SerializableConfiguration, Utils}

@@ -82,16 +83,40 @@ private[spark] class ReliableCheckpointRDD[T: ClassTag](
     Array.tabulate(inputFiles.length)(i => new CheckpointRDDPartition(i))
   }

-  /**
-   * Return the locations of the checkpoint file associated with the given partition.
-   */
-  protected override def getPreferredLocations(split: Partition): Seq[String] = {
+  // Cache of preferred locations of checkpointed files.
+  @transient private[spark] lazy val cachedPreferredLocations = CacheBuilder.newBuilder()
+    .expireAfterWrite(
+      SparkEnv.get.conf.get(CACHE_CHECKPOINT_PREFERRED_LOCS_EXPIRE_TIME).get,
+      TimeUnit.MINUTES)
+    .build(
+      new CacheLoader[Partition, Seq[String]]() {
+        override def load(split: Partition): Seq[String] = {
+          getPartitionBlockLocations(split)
+        }
+      })
+
+  // Returns the block locations of the given partition on the file system.
+  private def getPartitionBlockLocations(split: Partition): Seq[String] = {
     val status = fs.getFileStatus(
       new Path(checkpointPath, ReliableCheckpointRDD.checkpointFileName(split.index)))
     val locations = fs.getFileBlockLocations(status, 0, status.getLen)
     locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
   }

+  private lazy val cachedExpireTime =
+    SparkEnv.get.conf.get(CACHE_CHECKPOINT_PREFERRED_LOCS_EXPIRE_TIME)
+
+  /**
+   * Return the locations of the checkpoint file associated with the given partition.
+   */
+  protected override def getPreferredLocations(split: Partition): Seq[String] = {
+    if (cachedExpireTime.isDefined && cachedExpireTime.get > 0) {
+      cachedPreferredLocations.get(split)
+    } else {
+      getPartitionBlockLocations(split)
+    }
+  }
+
   /**
    * Read the content of the checkpoint file associated with the given partition.
    */
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index 3a43f1a033da1..6a108a55045ee 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -24,6 +24,7 @@ import scala.reflect.ClassTag
 import com.google.common.io.ByteStreams
 import org.apache.hadoop.fs.Path

+import org.apache.spark.internal.config.CACHE_CHECKPOINT_PREFERRED_LOCS_EXPIRE_TIME
 import org.apache.spark.internal.config.UI._
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.rdd._
@@ -584,7 +585,7 @@ object CheckpointSuite {
   }
 }

-class CheckpointCompressionSuite extends SparkFunSuite with LocalSparkContext {
+class CheckpointStorageSuite extends SparkFunSuite with LocalSparkContext {

   test("checkpoint compression") {
     withTempDir { checkpointDir =>
@@ -618,4 +619,28 @@ class CheckpointCompressionSuite extends SparkFunSuite with LocalSparkContext {
       assert(rdd.collect().toSeq === (1 to 20))
     }
   }
+
+  test("cache checkpoint preferred location") {
+    withTempDir { checkpointDir =>
+      val conf = new SparkConf()
+        .set(CACHE_CHECKPOINT_PREFERRED_LOCS_EXPIRE_TIME.key, "10")
+        .set(UI_ENABLED.key, "false")
+      sc = new SparkContext("local", "test", conf)
+      sc.setCheckpointDir(checkpointDir.toString)
+      val rdd = sc.makeRDD(1 to 20, numSlices = 1)
+      rdd.checkpoint()
+      assert(rdd.collect().toSeq === (1 to 20))
+
+      // Verify that the RDD has been checkpointed.
+      assert(rdd.firstParent.isInstanceOf[ReliableCheckpointRDD[_]])
+      val checkpointedRDD = rdd.firstParent.asInstanceOf[ReliableCheckpointRDD[_]]
+      val partition = checkpointedRDD.partitions(0)
+      assert(!checkpointedRDD.cachedPreferredLocations.asMap.containsKey(partition))
+
+      // The first lookup populates the cache; later lookups are served from it.
+      val preferredLoc = checkpointedRDD.preferredLocations(partition)
+      assert(checkpointedRDD.cachedPreferredLocations.asMap.containsKey(partition))
+      assert(preferredLoc === checkpointedRDD.cachedPreferredLocations.get(partition))
+    }
+  }
 }
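Note on the pattern used above: `cachedPreferredLocations` is a stock Guava loading cache, where a miss invokes the `CacheLoader` and `expireAfterWrite` evicts each entry a fixed time after it was stored, which bounds how stale a cached location can get. The sketch below isolates that pattern outside Spark; `PreferredLocsSketch`, `expensiveLookup`, and `lookups` are hypothetical names standing in for the DFS query that `fs.getFileBlockLocations` performs, and are not part of this patch.

```scala
import java.util.concurrent.TimeUnit

import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}

object PreferredLocsSketch {
  // Hypothetical stand-in for the DFS query that fs.getFileBlockLocations performs.
  private var lookups = 0
  private def expensiveLookup(partitionIndex: Int): Seq[String] = {
    lookups += 1
    Seq(s"host-${partitionIndex % 2}")
  }

  def main(args: Array[String]): Unit = {
    // Same shape as cachedPreferredLocations above: misses are filled in by the
    // CacheLoader, and each entry expires a fixed time after it was written.
    val cache: LoadingCache[Integer, Seq[String]] = CacheBuilder.newBuilder()
      .expireAfterWrite(10, TimeUnit.MINUTES)
      .build(new CacheLoader[Integer, Seq[String]]() {
        override def load(partitionIndex: Integer): Seq[String] =
          expensiveLookup(partitionIndex)
      })

    cache.get(0) // miss: runs the loader
    cache.get(0) // hit: served from the cache until the entry expires
    assert(lookups == 1)
  }
}
```

The `expireAfterWrite(10, TimeUnit.MINUTES)` value here mirrors the test's setting of `spark.rdd.checkpoint.cachePreferredLocsExpireTime` to 10; in the patch itself the duration comes from the config, and `getPreferredLocations` falls back to the uncached `getPartitionBlockLocations` whenever no expire time is configured.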