diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala index 0ada1111ce30a..1a07929ee43e7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala @@ -132,22 +132,39 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT /** Method that generates a RDD for the given time */ override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = { // Get the previous state or create a new empty state RDD - val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse { - TrackStateRDD.createFromPairRDD[K, V, S, E]( - spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), - partitioner, validTime - ) + val prevStateRDD = getOrCompute(validTime - slideDuration) match { + case Some(rdd) => + if (rdd.partitioner != Some(partitioner)) { + // If the RDD is not partitioned the right way, let us repartition it using the + // partition index as the key. This is to ensure that state RDD is always partitioned + // before creating another state RDD using it + val kvRDD = rdd.mapPartitions { iter => + iter.map { x => (TaskContext.get().partitionId(), x)} + } + kvRDD.partitionBy(partitioner).mapPartitions(iter => iter.map { _._2 }, + preservesPartitioning = true) + } else { + rdd + } + case None => + TrackStateRDD.createFromPairRDD[K, V, S, E]( + spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), + partitioner, validTime + ) } + // Compute the new state RDD with previous state RDD and partitioned data RDD - parent.getOrCompute(validTime).map { dataRDD => - val partitionedDataRDD = dataRDD.partitionBy(partitioner) - val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => - (validTime - interval).milliseconds - } - new TrackStateRDD( - prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime) + // Even if there is no data RDD, use an empty one to create a new state RDD + val dataRDD = parent.getOrCompute(validTime).getOrElse { + context.sparkContext.emptyRDD[(K, V)] + } + val partitionedDataRDD = dataRDD.partitionBy(partitioner) + val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => + (validTime - interval).milliseconds } + Some(new TrackStateRDD( + prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime)) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index b1cbc7163bee3..bdab96cd95855 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -33,17 +33,133 @@ import org.mockito.Mockito.mock import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.TestUtils +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils} import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils} +/** + * A trait of that can be mixed in to get methods for testing DStream operations under + * DStream checkpointing. Note that the implementations of this trait has to implement + * the `setupCheckpointOperation` + */ +trait DStreamCheckpointTester { self: SparkFunSuite => + + /** + * Tests a streaming operation under checkpointing, by restarting the operation + * from checkpoint file and verifying whether the final output is correct. + * The output is assumed to have come from a reliable queue which an replay + * data as required. + * + * NOTE: This takes into consideration that the last batch processed before + * master failure will be re-processed after restart/recovery. + */ + protected def testCheckpointedOperation[U: ClassTag, V: ClassTag]( + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]], + numBatchesBeforeRestart: Int, + batchDuration: Duration = Seconds(1), + stopSparkContextAfterTest: Boolean = true + ) { + require(numBatchesBeforeRestart < expectedOutput.size, + "Number of batches before context restart less than number of expected output " + + "(i.e. number of total batches to run)") + require(StreamingContext.getActive().isEmpty, + "Cannot run test with already active streaming context") + + // Current code assumes that: + // number of inputs = number of outputs = number of batches to be run + val totalNumBatches = input.size + val nextNumBatches = totalNumBatches - numBatchesBeforeRestart + val initialNumExpectedOutputs = numBatchesBeforeRestart + val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1 + // because the last batch will be processed again + + // Setup the stream computation + val checkpointDir = Utils.createTempDir(this.getClass.getSimpleName()).toString + logDebug(s"Using checkpoint directory $checkpointDir") + val ssc = createContextForCheckpointOperation(batchDuration) + require(ssc.conf.get("spark.streaming.clock") === classOf[ManualClock].getName, + "Cannot run test without manual clock in the conf") + + val inputStream = new TestInputStream(ssc, input, numPartitions = 2) + val operatedStream = operation(inputStream) + val outputStream = new TestOutputStreamWithPartitions(operatedStream, + new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]]) + outputStream.register() + ssc.checkpoint(checkpointDir) + + // Do the computation for initial number of batches, create checkpoint file and quit + generateAndAssertOutput[V](ssc, batchDuration, checkpointDir, + expectedOutput.take(numBatchesBeforeRestart), stopSparkContextAfterTest) + + // Restart and complete the computation from checkpoint file + logInfo( + "\n-------------------------------------------\n" + + " Restarting stream computation " + + "\n-------------------------------------------\n" + ) + val restartedSsc = new StreamingContext(checkpointDir) + generateAndAssertOutput[V](restartedSsc, batchDuration, checkpointDir, + expectedOutput.takeRight(nextNumExpectedOutputs), stopSparkContextAfterTest) + } + + protected def createContextForCheckpointOperation(batchDuration: Duration): StreamingContext = { + val conf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName) + conf.set("spark.streaming.clock", classOf[ManualClock].getName()) + new StreamingContext(SparkContext.getOrCreate(conf), Seconds(1)) + } + + private def generateAndAssertOutput[V: ClassTag]( + ssc: StreamingContext, + batchDuration: Duration, + checkpointDir: String, + expectedOutput: Seq[Seq[V]], + stopSparkContext: Boolean + ) { + try { + ssc.start() + val numBatches = expectedOutput.size + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + logDebug("Manual clock before advancing = " + clock.getTimeMillis()) + clock.advance((batchDuration * numBatches).milliseconds) + logDebug("Manual clock after advancing = " + clock.getTimeMillis()) + + val outputStream = ssc.graph.getOutputStreams().filter { dstream => + dstream.isInstanceOf[TestOutputStreamWithPartitions[V]] + }.head.asInstanceOf[TestOutputStreamWithPartitions[V]] + + eventually(timeout(10 seconds)) { + ssc.awaitTerminationOrTimeout(10) + assert(outputStream.output.size === expectedOutput.size) + } + + eventually(timeout(10 seconds)) { + Checkpoint.getCheckpointFiles(checkpointDir).exists { + _.toString.contains(clock.getTimeMillis.toString) + } + } + + val output = outputStream.output.map(_.flatten) + assert( + output.zip(expectedOutput).forall { case (o, e) => o.toSet === e.toSet }, + s"Set comparison failed\n" + + s"Expected output (${expectedOutput.size} items):\n${expectedOutput.mkString("\n")}\n" + + s"Generated output (${output.size} items): ${output.mkString("\n")}" + ) + } finally { + ssc.stop(stopSparkContext = stopSparkContext) + } + } +} + /** * This test suites tests the checkpointing functionality of DStreams - * the checkpointing of a DStream's RDDs as well as the checkpointing of * the whole DStream graph. */ -class CheckpointSuite extends TestSuiteBase { +class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester { var ssc: StreamingContext = null @@ -56,7 +172,7 @@ class CheckpointSuite extends TestSuiteBase { override def afterFunction() { super.afterFunction() - if (ssc != null) ssc.stop() + StreamingContext.getActive().foreach { _.stop() } Utils.deleteRecursively(new File(checkpointDir)) } @@ -634,53 +750,6 @@ class CheckpointSuite extends TestSuiteBase { checkpointWriter.stop() } - /** - * Tests a streaming operation under checkpointing, by restarting the operation - * from checkpoint file and verifying whether the final output is correct. - * The output is assumed to have come from a reliable queue which an replay - * data as required. - * - * NOTE: This takes into consideration that the last batch processed before - * master failure will be re-processed after restart/recovery. - */ - def testCheckpointedOperation[U: ClassTag, V: ClassTag]( - input: Seq[Seq[U]], - operation: DStream[U] => DStream[V], - expectedOutput: Seq[Seq[V]], - initialNumBatches: Int - ) { - - // Current code assumes that: - // number of inputs = number of outputs = number of batches to be run - val totalNumBatches = input.size - val nextNumBatches = totalNumBatches - initialNumBatches - val initialNumExpectedOutputs = initialNumBatches - val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1 - // because the last batch will be processed again - - // Do the computation for initial number of batches, create checkpoint file and quit - ssc = setupStreams[U, V](input, operation) - ssc.start() - val output = advanceTimeWithRealDelay[V](ssc, initialNumBatches) - ssc.stop() - verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) - Thread.sleep(1000) - - // Restart and complete the computation from checkpoint file - logInfo( - "\n-------------------------------------------\n" + - " Restarting stream computation " + - "\n-------------------------------------------\n" - ) - ssc = new StreamingContext(checkpointDir) - ssc.start() - val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches) - // the first element will be re-processed data of the last batch before restart - verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) - ssc.stop() - ssc = null - } - /** * Advances the manual clock on the streaming scheduler by given number of batches. * It also waits for the expected amount of time for each batch. diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala index 58aef74c0040f..89eef2318fdff 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala @@ -25,31 +25,27 @@ import scala.reflect.ClassTag import org.scalatest.PrivateMethodTester._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} -import org.apache.spark.streaming.dstream.{InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl} +import org.apache.spark.streaming.dstream.{DStream, InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl} import org.apache.spark.util.{ManualClock, Utils} import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter { +class TrackStateByKeySuite extends SparkFunSuite + with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter { private var sc: SparkContext = null - private var ssc: StreamingContext = null - private var checkpointDir: File = null - private val batchDuration = Seconds(1) + protected var checkpointDir: File = null + protected val batchDuration = Seconds(1) before { - StreamingContext.getActive().foreach { - _.stop(stopSparkContext = false) - } + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } checkpointDir = Utils.createTempDir("checkpoint") - - ssc = new StreamingContext(sc, batchDuration) - ssc.checkpoint(checkpointDir.toString) } after { - StreamingContext.getActive().foreach { - _.stop(stopSparkContext = false) + if (checkpointDir != null) { + Utils.deleteRecursively(checkpointDir) } + StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } } override def beforeAll(): Unit = { @@ -242,7 +238,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef assert(dstreamImpl.stateClass === classOf[Double]) assert(dstreamImpl.emittedClass === classOf[Long]) } - + val ssc = new StreamingContext(sc, batchDuration) val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2) // Defining StateSpec inline with trackStateByKey and simple function implicitly gets the types @@ -451,8 +447,9 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef expectedCheckpointDuration: Duration, explicitCheckpointDuration: Option[Duration] = None ): Unit = { + val ssc = new StreamingContext(sc, batchDuration) + try { - ssc = new StreamingContext(sc, batchDuration) val inputStream = new TestInputStream(ssc, Seq.empty[Seq[Int]], 2).map(_ -> 1) val dummyFunc = (value: Option[Int], state: State[Int]) => 0 val trackStateStream = inputStream.trackStateByKey(StateSpec.function(dummyFunc)) @@ -462,11 +459,12 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef trackStateStream.checkpoint(d) } trackStateStream.register() + ssc.checkpoint(checkpointDir.toString) ssc.start() // should initialize all the checkpoint durations assert(trackStateStream.checkpointDuration === null) assert(internalTrackStateStream.checkpointDuration === expectedCheckpointDuration) } finally { - StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) } + ssc.stop(stopSparkContext = false) } } @@ -479,6 +477,50 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef testCheckpointDuration(Seconds(10), Seconds(20), Some(Seconds(20))) } + + test("trackStateByKey - drivery failure recovery") { + val inputData = + Seq( + Seq(), + Seq("a"), + Seq("a", "b"), + Seq("a", "b", "c"), + Seq("a", "b"), + Seq("a"), + Seq() + ) + + val stateData = + Seq( + Seq(), + Seq(("a", 1)), + Seq(("a", 2), ("b", 1)), + Seq(("a", 3), ("b", 2), ("c", 1)), + Seq(("a", 4), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)), + Seq(("a", 5), ("b", 3), ("c", 1)) + ) + + def operation(dstream: DStream[String]): DStream[(String, Int)] = { + + val checkpointDuration = batchDuration * (stateData.size / 2) + + val runningCount = (value: Option[Int], state: State[Int]) => { + state.update(state.getOption().getOrElse(0) + value.getOrElse(0)) + state.get() + } + + val trackStateStream = dstream.map { _ -> 1 }.trackStateByKey( + StateSpec.function(runningCount)) + // Set internval make sure there is one RDD checkpointing + trackStateStream.checkpoint(checkpointDuration) + trackStateStream.stateSnapshots() + } + + testCheckpointedOperation(inputData, operation, stateData, inputData.size / 2, + batchDuration = batchDuration, stopSparkContextAfterTest = false) + } + private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag]( input: Seq[Seq[K]], trackStateSpec: StateSpec[K, Int, S, T], @@ -500,6 +542,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef ): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = { // Setup the stream computation + val ssc = new StreamingContext(sc, Seconds(1)) val inputStream = new TestInputStream(ssc, input, numPartitions = 2) val trackeStateStream = inputStream.map(x => (x, 1)).trackStateByKey(trackStateSpec) val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]] @@ -511,12 +554,14 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef stateSnapshotStream.register() val batchCounter = new BatchCounter(ssc) + ssc.checkpoint(checkpointDir.toString) ssc.start() val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] clock.advance(batchDuration.milliseconds * numBatches) batchCounter.waitUntilBatchesCompleted(numBatches, 10000) + ssc.stop(stopSparkContext = false) (collectedOutputs, collectedStateSnapshots) }