Partition previous state RDD if partitioner not present
tdas committed Nov 26, 2015
1 parent cc243a0 commit 0c5fe55
Showing 3 changed files with 209 additions and 78 deletions.
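
The heart of this commit is the first hunk below: if the previous state RDD returned by getOrCompute does not carry the stream's partitioner (for example after checkpoint recovery), it is re-keyed by its partition index, shuffled with that partitioner, and unkeyed again, so that the next state RDD is always built from a correctly partitioned predecessor. A minimal, self-contained sketch of that technique using only stock Spark APIs; the object and method names here are illustrative and not part of the patch:

import scala.reflect.ClassTag

import org.apache.spark.{HashPartitioner, Partitioner, SparkConf, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

// Sketch of the repartitioning trick applied in the first hunk below:
// key each element by the partition it currently lives in, shuffle with the
// desired partitioner, then drop the temporary key again.
object EnsurePartitionedSketch {

  def ensurePartitioned[T: ClassTag](rdd: RDD[T], partitioner: Partitioner): RDD[T] = {
    if (rdd.partitioner == Some(partitioner)) {
      rdd // already partitioned as required, nothing to do
    } else {
      rdd
        .mapPartitions { iter =>
          // Temporary key: the id of the partition the element currently lives in.
          iter.map(x => (TaskContext.get().partitionId(), x))
        }
        .partitionBy(partitioner) // shuffle into the desired partitioning
        .mapPartitions(iter => iter.map(_._2), preservesPartitioning = true)
    }
  }

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
    val rdd = sc.parallelize(1 to 10, numSlices = 2) // no partitioner attached
    val partitioned = ensurePartitioned(rdd, new HashPartitioner(4))
    assert(partitioned.partitioner == Some(new HashPartitioner(4)))
    sc.stop()
  }
}
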
@@ -132,22 +132,39 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT
/** Method that generates a RDD for the given time */
override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = {
// Get the previous state or create a new empty state RDD
val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse {
TrackStateRDD.createFromPairRDD[K, V, S, E](
spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
partitioner, validTime
)
val prevStateRDD = getOrCompute(validTime - slideDuration) match {
case Some(rdd) =>
if (rdd.partitioner != Some(partitioner)) {
// If the RDD is not partitioned the right way, repartition it using the
// partition index as the key. This ensures that the state RDD is always
// correctly partitioned before it is used to create another state RDD.
val kvRDD = rdd.mapPartitions { iter =>
iter.map { x => (TaskContext.get().partitionId(), x)}
}
kvRDD.partitionBy(partitioner).mapPartitions(iter => iter.map { _._2 },
preservesPartitioning = true)
} else {
rdd
}
case None =>
TrackStateRDD.createFromPairRDD[K, V, S, E](
spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
partitioner, validTime
)
}


// Compute the new state RDD with previous state RDD and partitioned data RDD
parent.getOrCompute(validTime).map { dataRDD =>
val partitionedDataRDD = dataRDD.partitionBy(partitioner)
val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
(validTime - interval).milliseconds
}
new TrackStateRDD(
prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime)
// Even if there is no data RDD, use an empty one to create a new state RDD
val dataRDD = parent.getOrCompute(validTime).getOrElse {
context.sparkContext.emptyRDD[(K, V)]
}
val partitionedDataRDD = dataRDD.partitionBy(partitioner)
val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
(validTime - interval).milliseconds
}
Some(new TrackStateRDD(
prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime))
}
}

@@ -33,17 +33,133 @@ import org.mockito.Mockito.mock
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

import org.apache.spark.TestUtils
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils}
import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
import org.apache.spark.streaming.scheduler._
import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils}

/**
* A trait that can be mixed in to get methods for testing DStream operations under
* DStream checkpointing. Note that implementations of this trait have to implement
* the `setupCheckpointOperation` method.
*/
trait DStreamCheckpointTester { self: SparkFunSuite =>

/**
* Tests a streaming operation under checkpointing, by restarting the operation
* from the checkpoint file and verifying whether the final output is correct.
* The output is assumed to have come from a reliable queue which can replay
* data as required.
*
* NOTE: This takes into consideration that the last batch processed before
* master failure will be re-processed after restart/recovery.
*/
protected def testCheckpointedOperation[U: ClassTag, V: ClassTag](
input: Seq[Seq[U]],
operation: DStream[U] => DStream[V],
expectedOutput: Seq[Seq[V]],
numBatchesBeforeRestart: Int,
batchDuration: Duration = Seconds(1),
stopSparkContextAfterTest: Boolean = true
) {
require(numBatchesBeforeRestart < expectedOutput.size,
"Number of batches before context restart less than number of expected output " +
"(i.e. number of total batches to run)")
require(StreamingContext.getActive().isEmpty,
"Cannot run test with already active streaming context")

// Current code assumes that:
// number of inputs = number of outputs = number of batches to be run
val totalNumBatches = input.size
val nextNumBatches = totalNumBatches - numBatchesBeforeRestart
val initialNumExpectedOutputs = numBatchesBeforeRestart
val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1
// because the last batch will be processed again

// Setup the stream computation
val checkpointDir = Utils.createTempDir(this.getClass.getSimpleName()).toString
logDebug(s"Using checkpoint directory $checkpointDir")
val ssc = createContextForCheckpointOperation(batchDuration)
require(ssc.conf.get("spark.streaming.clock") === classOf[ManualClock].getName,
"Cannot run test without manual clock in the conf")

val inputStream = new TestInputStream(ssc, input, numPartitions = 2)
val operatedStream = operation(inputStream)
val outputStream = new TestOutputStreamWithPartitions(operatedStream,
new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]])
outputStream.register()
ssc.checkpoint(checkpointDir)

// Do the computation for initial number of batches, create checkpoint file and quit
generateAndAssertOutput[V](ssc, batchDuration, checkpointDir,
expectedOutput.take(numBatchesBeforeRestart), stopSparkContextAfterTest)

// Restart and complete the computation from checkpoint file
logInfo(
"\n-------------------------------------------\n" +
" Restarting stream computation " +
"\n-------------------------------------------\n"
)
val restartedSsc = new StreamingContext(checkpointDir)
generateAndAssertOutput[V](restartedSsc, batchDuration, checkpointDir,
expectedOutput.takeRight(nextNumExpectedOutputs), stopSparkContextAfterTest)
}

protected def createContextForCheckpointOperation(batchDuration: Duration): StreamingContext = {
val conf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName)
conf.set("spark.streaming.clock", classOf[ManualClock].getName())
new StreamingContext(SparkContext.getOrCreate(conf), batchDuration)
}

private def generateAndAssertOutput[V: ClassTag](
ssc: StreamingContext,
batchDuration: Duration,
checkpointDir: String,
expectedOutput: Seq[Seq[V]],
stopSparkContext: Boolean
) {
try {
ssc.start()
val numBatches = expectedOutput.size
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
logDebug("Manual clock before advancing = " + clock.getTimeMillis())
clock.advance((batchDuration * numBatches).milliseconds)
logDebug("Manual clock after advancing = " + clock.getTimeMillis())

val outputStream = ssc.graph.getOutputStreams().filter { dstream =>
dstream.isInstanceOf[TestOutputStreamWithPartitions[V]]
}.head.asInstanceOf[TestOutputStreamWithPartitions[V]]

eventually(timeout(10 seconds)) {
ssc.awaitTerminationOrTimeout(10)
assert(outputStream.output.size === expectedOutput.size)
}

eventually(timeout(10 seconds)) {
assert(Checkpoint.getCheckpointFiles(checkpointDir).exists {
_.toString.contains(clock.getTimeMillis.toString)
})
}

val output = outputStream.output.map(_.flatten)
assert(
output.zip(expectedOutput).forall { case (o, e) => o.toSet === e.toSet },
s"Set comparison failed\n" +
s"Expected output (${expectedOutput.size} items):\n${expectedOutput.mkString("\n")}\n" +
s"Generated output (${output.size} items): ${output.mkString("\n")}"
)
} finally {
ssc.stop(stopSparkContext = stopSparkContext)
}
}
}
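
The restart bookkeeping in testCheckpointedOperation assumes that the last batch processed before the simulated failure is re-processed on recovery, so the post-restart run is checked against one more batch of output than it newly generates. A small worked example of that arithmetic, assuming the 7-batch input with a restart after 3 batches used by the trackStateByKey driver-failure-recovery test added later in this commit (the object name is illustrative):

object CheckpointBatchAccountingExample {
  def main(args: Array[String]): Unit = {
    // Assumes 7 total batches and a restart after batch 3, mirroring the
    // "trackStateByKey - driver failure recovery" test further down.
    val totalNumBatches = 7
    val numBatchesBeforeRestart = 3

    // Batches still to be generated after the restart: 4, 5, 6, 7.
    val nextNumBatches = totalNumBatches - numBatchesBeforeRestart // 4

    // Outputs verified before the restart: batches 1..3.
    val initialNumExpectedOutputs = numBatchesBeforeRestart // 3

    // Batch 3 is re-processed on recovery, so the second run is verified
    // against the outputs of batches 3..7, i.e. one extra batch.
    val nextNumExpectedOutputs = totalNumBatches - initialNumExpectedOutputs + 1 // 5

    println(s"batches after restart: $nextNumBatches, outputs checked after restart: $nextNumExpectedOutputs")
  }
}
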

/**
* This test suite tests the checkpointing functionality of DStreams -
* the checkpointing of a DStream's RDDs as well as the checkpointing of
* the whole DStream graph.
*/
class CheckpointSuite extends TestSuiteBase {
class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester {

var ssc: StreamingContext = null

@@ -56,7 +172,7 @@ class CheckpointSuite extends TestSuiteBase {

override def afterFunction() {
super.afterFunction()
if (ssc != null) ssc.stop()
StreamingContext.getActive().foreach { _.stop() }
Utils.deleteRecursively(new File(checkpointDir))
}

@@ -634,53 +750,6 @@ class CheckpointSuite extends TestSuiteBase {
checkpointWriter.stop()
}

/**
* Tests a streaming operation under checkpointing, by restarting the operation
* from checkpoint file and verifying whether the final output is correct.
* The output is assumed to have come from a reliable queue which an replay
* data as required.
*
* NOTE: This takes into consideration that the last batch processed before
* master failure will be re-processed after restart/recovery.
*/
def testCheckpointedOperation[U: ClassTag, V: ClassTag](
input: Seq[Seq[U]],
operation: DStream[U] => DStream[V],
expectedOutput: Seq[Seq[V]],
initialNumBatches: Int
) {

// Current code assumes that:
// number of inputs = number of outputs = number of batches to be run
val totalNumBatches = input.size
val nextNumBatches = totalNumBatches - initialNumBatches
val initialNumExpectedOutputs = initialNumBatches
val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1
// because the last batch will be processed again

// Do the computation for initial number of batches, create checkpoint file and quit
ssc = setupStreams[U, V](input, operation)
ssc.start()
val output = advanceTimeWithRealDelay[V](ssc, initialNumBatches)
ssc.stop()
verifyOutput[V](output, expectedOutput.take(initialNumBatches), true)
Thread.sleep(1000)

// Restart and complete the computation from checkpoint file
logInfo(
"\n-------------------------------------------\n" +
" Restarting stream computation " +
"\n-------------------------------------------\n"
)
ssc = new StreamingContext(checkpointDir)
ssc.start()
val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches)
// the first element will be re-processed data of the last batch before restart
verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true)
ssc.stop()
ssc = null
}

/**
* Advances the manual clock on the streaming scheduler by given number of batches.
* It also waits for the expected amount of time for each batch.
@@ -25,31 +25,27 @@ import scala.reflect.ClassTag
import org.scalatest.PrivateMethodTester._
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}

import org.apache.spark.streaming.dstream.{InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl}
import org.apache.spark.streaming.dstream.{DStream, InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl}
import org.apache.spark.util.{ManualClock, Utils}
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}

class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter {
class TrackStateByKeySuite extends SparkFunSuite
with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter {

private var sc: SparkContext = null
private var ssc: StreamingContext = null
private var checkpointDir: File = null
private val batchDuration = Seconds(1)
protected var checkpointDir: File = null
protected val batchDuration = Seconds(1)

before {
StreamingContext.getActive().foreach {
_.stop(stopSparkContext = false)
}
StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) }
checkpointDir = Utils.createTempDir("checkpoint")

ssc = new StreamingContext(sc, batchDuration)
ssc.checkpoint(checkpointDir.toString)
}

after {
StreamingContext.getActive().foreach {
_.stop(stopSparkContext = false)
if (checkpointDir != null) {
Utils.deleteRecursively(checkpointDir)
}
StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) }
}

override def beforeAll(): Unit = {
@@ -242,7 +238,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
assert(dstreamImpl.stateClass === classOf[Double])
assert(dstreamImpl.emittedClass === classOf[Long])
}

val ssc = new StreamingContext(sc, batchDuration)
val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2)

// Defining StateSpec inline with trackStateByKey and simple function implicitly gets the types
@@ -451,8 +447,9 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
expectedCheckpointDuration: Duration,
explicitCheckpointDuration: Option[Duration] = None
): Unit = {
val ssc = new StreamingContext(sc, batchDuration)

try {
ssc = new StreamingContext(sc, batchDuration)
val inputStream = new TestInputStream(ssc, Seq.empty[Seq[Int]], 2).map(_ -> 1)
val dummyFunc = (value: Option[Int], state: State[Int]) => 0
val trackStateStream = inputStream.trackStateByKey(StateSpec.function(dummyFunc))
@@ -462,11 +459,12 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
trackStateStream.checkpoint(d)
}
trackStateStream.register()
ssc.checkpoint(checkpointDir.toString)
ssc.start() // should initialize all the checkpoint durations
assert(trackStateStream.checkpointDuration === null)
assert(internalTrackStateStream.checkpointDuration === expectedCheckpointDuration)
} finally {
StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) }
ssc.stop(stopSparkContext = false)
}
}

@@ -479,6 +477,50 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
testCheckpointDuration(Seconds(10), Seconds(20), Some(Seconds(20)))
}


test("trackStateByKey - drivery failure recovery") {
val inputData =
Seq(
Seq(),
Seq("a"),
Seq("a", "b"),
Seq("a", "b", "c"),
Seq("a", "b"),
Seq("a"),
Seq()
)

val stateData =
Seq(
Seq(),
Seq(("a", 1)),
Seq(("a", 2), ("b", 1)),
Seq(("a", 3), ("b", 2), ("c", 1)),
Seq(("a", 4), ("b", 3), ("c", 1)),
Seq(("a", 5), ("b", 3), ("c", 1)),
Seq(("a", 5), ("b", 3), ("c", 1))
)

def operation(dstream: DStream[String]): DStream[(String, Int)] = {

val checkpointDuration = batchDuration * (stateData.size / 2)

val runningCount = (value: Option[Int], state: State[Int]) => {
state.update(state.getOption().getOrElse(0) + value.getOrElse(0))
state.get()
}

val trackStateStream = dstream.map { _ -> 1 }.trackStateByKey(
StateSpec.function(runningCount))
// Set the checkpoint interval to make sure there is at least one RDD checkpoint
trackStateStream.checkpoint(checkpointDuration)
trackStateStream.stateSnapshots()
}

testCheckpointedOperation(inputData, operation, stateData, inputData.size / 2,
batchDuration = batchDuration, stopSparkContextAfterTest = false)
}

private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag](
input: Seq[Seq[K]],
trackStateSpec: StateSpec[K, Int, S, T],
@@ -500,6 +542,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = {

// Setup the stream computation
val ssc = new StreamingContext(sc, Seconds(1))
val inputStream = new TestInputStream(ssc, input, numPartitions = 2)
val trackeStateStream = inputStream.map(x => (x, 1)).trackStateByKey(trackStateSpec)
val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]]
@@ -511,12 +554,14 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
stateSnapshotStream.register()

val batchCounter = new BatchCounter(ssc)
ssc.checkpoint(checkpointDir.toString)
ssc.start()

val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
clock.advance(batchDuration.milliseconds * numBatches)

batchCounter.waitUntilBatchesCompleted(numBatches, 10000)
ssc.stop(stopSparkContext = false)
(collectedOutputs, collectedStateSnapshots)
}

