diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index e7c9c47c960fa..5ea4817bfde18 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -58,17 +58,19 @@ object MimaBuild { SparkBuild.SPARK_VERSION match { case v if v.startsWith("1.0") => Seq( - excludePackage("org.apache.spark.api.java"), - excludePackage("org.apache.spark.streaming.api.java"), - excludePackage("org.apache.spark.mllib") - ) ++ - excludeSparkClass("rdd.ClassTags") ++ - excludeSparkClass("util.XORShiftRandom") ++ - excludeSparkClass("mllib.recommendation.MFDataGenerator") ++ - excludeSparkClass("mllib.optimization.SquaredGradient") ++ - excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++ - excludeSparkClass("mllib.regression.LassoWithSGD") ++ - excludeSparkClass("mllib.regression.LinearRegressionWithSGD") + excludePackage("org.apache.spark.api.java"), + excludePackage("org.apache.spark.streaming.api.java"), + excludePackage("org.apache.spark.mllib") + ) ++ + excludeSparkClass("rdd.ClassTags") ++ + excludeSparkClass("util.XORShiftRandom") ++ + excludeSparkClass("mllib.recommendation.MFDataGenerator") ++ + excludeSparkClass("mllib.optimization.SquaredGradient") ++ + excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++ + excludeSparkClass("mllib.regression.LassoWithSGD") ++ + excludeSparkClass("mllib.regression.LinearRegressionWithSGD") ++ + excludeSparkClass("streaming.dstream.NetworkReceiver") ++ + excludeSparkClass("streaming.dstream.NetworkReceiver#NetworkReceiverActor") case _ => Seq() } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index baf80fe2a91b7..93023e8dced57 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -194,19 +194,19 @@ class CheckpointWriter( } } - def stop() { - synchronized { - if (stopped) { - return - } - stopped = true - } + def stop(): Unit = synchronized { + if (stopped) return + executor.shutdown() val startTime = System.currentTimeMillis() val terminated = executor.awaitTermination(10, java.util.concurrent.TimeUnit.SECONDS) + if (!terminated) { + executor.shutdownNow() + } val endTime = System.currentTimeMillis() logInfo("CheckpointWriter executor terminated ? " + terminated + ", waited for " + (endTime - startTime) + " ms.") + stopped = true } private def fs = synchronized { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index e198c69470c1f..a4e236c65ff86 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -158,6 +158,15 @@ class StreamingContext private[streaming] ( private[streaming] val waiter = new ContextWaiter + /** Enumeration to identify current state of the StreamingContext */ + private[streaming] object StreamingContextState extends Enumeration { + type CheckpointState = Value + val Initialized, Started, Stopped = Value + } + + import StreamingContextState._ + private[streaming] var state = Initialized + /** * Return the associated Spark context */ @@ -405,9 +414,18 @@ class StreamingContext private[streaming] ( /** * Start the execution of the streams. 
*/ - def start() = synchronized { + def start(): Unit = synchronized { + // Throw exception if the context has already been started once + // or if a stopped context is being started again + if (state == Started) { + throw new SparkException("StreamingContext has already been started") + } + if (state == Stopped) { + throw new SparkException("StreamingContext has already been stopped") + } validate() scheduler.start() + state = Started } /** @@ -428,14 +446,38 @@ class StreamingContext private[streaming] ( } /** - * Stop the execution of the streams. + * Stop the execution of the streams immediately (does not wait for all received data + * to be processed). * @param stopSparkContext Stop the associated SparkContext or not + * */ def stop(stopSparkContext: Boolean = true): Unit = synchronized { - scheduler.stop() + stop(stopSparkContext, false) + } + + /** + * Stop the execution of the streams, with option of ensuring all received data + * has been processed. + * @param stopSparkContext Stop the associated SparkContext or not + * @param stopGracefully Stop gracefully by waiting for the processing of all + * received data to be completed + */ + def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = synchronized { + // Warn (but not fail) if context is stopped twice, + // or context is stopped before starting + if (state == Initialized) { + logWarning("StreamingContext has not been started yet") + return + } + if (state == Stopped) { + logWarning("StreamingContext has already been stopped") + return + } // no need to throw an exception as its okay to stop twice + scheduler.stop(stopGracefully) logInfo("StreamingContext stopped successfully") waiter.notifyStop() if (stopSparkContext) sc.stop() + state = Stopped } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index b705d2ec9a58e..c800602d0959b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -509,8 +509,16 @@ class JavaStreamingContext(val ssc: StreamingContext) { * Stop the execution of the streams. * @param stopSparkContext Stop the associated SparkContext or not */ - def stop(stopSparkContext: Boolean): Unit = { - ssc.stop(stopSparkContext) + def stop(stopSparkContext: Boolean) = ssc.stop(stopSparkContext) + + /** + * Stop the execution of the streams. 
+ * @param stopSparkContext Stop the associated SparkContext or not + * @param stopGracefully Stop gracefully by waiting for the processing of all + * received data to be completed + */ + def stop(stopSparkContext: Boolean, stopGracefully: Boolean) = { + ssc.stop(stopSparkContext, stopGracefully) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala index 72ad0bae75bfb..d19a635fe8eca 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import java.util.concurrent.ArrayBlockingQueue +import java.util.concurrent.{TimeUnit, ArrayBlockingQueue} import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer @@ -34,6 +34,7 @@ import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.rdd.{RDD, BlockRDD} import org.apache.spark.storage.{BlockId, StorageLevel, StreamBlockId} import org.apache.spark.streaming.scheduler.{DeregisterReceiver, AddBlocks, RegisterReceiver} +import org.apache.spark.util.AkkaUtils /** * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]] @@ -69,7 +70,7 @@ abstract class NetworkInputDStream[T: ClassTag](@transient ssc_ : StreamingConte // then this returns an empty RDD. This may happen when recovering from a // master failure if (validTime >= graph.startTime) { - val blockIds = ssc.scheduler.networkInputTracker.getBlockIds(id, validTime) + val blockIds = ssc.scheduler.networkInputTracker.getBlocks(id, validTime) Some(new BlockRDD[T](ssc.sc, blockIds)) } else { Some(new BlockRDD[T](ssc.sc, Array[BlockId]())) @@ -79,7 +80,7 @@ abstract class NetworkInputDStream[T: ClassTag](@transient ssc_ : StreamingConte private[streaming] sealed trait NetworkReceiverMessage -private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage +private[streaming] case class StopReceiver() extends NetworkReceiverMessage private[streaming] case class ReportBlock(blockId: BlockId, metadata: Any) extends NetworkReceiverMessage private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage @@ -90,13 +91,31 @@ private[streaming] case class ReportError(msg: String) extends NetworkReceiverMe */ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging { + /** Local SparkEnv */ lazy protected val env = SparkEnv.get + /** Remote Akka actor for the NetworkInputTracker */ + lazy protected val trackerActor = { + val ip = env.conf.get("spark.driver.host", "localhost") + val port = env.conf.getInt("spark.driver.port", 7077) + val url = "akka.tcp://spark@%s:%s/user/NetworkInputTracker".format(ip, port) + env.actorSystem.actorSelection(url) + } + + /** Akka actor for receiving messages from the NetworkInputTracker in the driver */ lazy protected val actor = env.actorSystem.actorOf( Props(new NetworkReceiverActor()), "NetworkReceiver-" + streamId) + /** Timeout for Akka actor messages */ + lazy protected val askTimeout = AkkaUtils.askTimeout(env.conf) + + /** Thread that starts the receiver and stays blocked while data is being received */ lazy protected val receivingThread = Thread.currentThread() + /** Exceptions that occurs while receiving data */ + protected lazy val exceptions = new ArrayBuffer[Exception] + + /** Identifier of the stream this receiver is 
associated with */ protected var streamId: Int = -1 /** @@ -112,7 +131,7 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging def getLocationPreference() : Option[String] = None /** - * Starts the receiver. First is accesses all the lazy members to + * Start the receiver. First is accesses all the lazy members to * materialize them. Then it calls the user-defined onStart() method to start * other threads, etc required to receiver the data. */ @@ -124,83 +143,107 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging receivingThread // Call user-defined onStart() + logInfo("Starting receiver") onStart() + + // Wait until interrupt is called on this thread + while(true) Thread.sleep(100000) } catch { case ie: InterruptedException => - logInfo("Receiving thread interrupted") + logInfo("Receiving thread has been interrupted, receiver " + streamId + " stopped") case e: Exception => - stopOnError(e) + logError("Error receiving data in receiver " + streamId, e) + exceptions += e + } + + // Call user-defined onStop() + logInfo("Stopping receiver") + try { + onStop() + } catch { + case e: Exception => + logError("Error stopping receiver " + streamId, e) + exceptions += e + } + + val message = if (exceptions.isEmpty) { + null + } else if (exceptions.size == 1) { + val e = exceptions.head + "Exception in receiver " + streamId + ": " + e.getMessage + "\n" + e.getStackTraceString + } else { + "Multiple exceptions in receiver " + streamId + "(" + exceptions.size + "):\n" + exceptions.zipWithIndex.map { + case (e, i) => "Exception " + i + ": " + e.getMessage + "\n" + e.getStackTraceString + }.mkString("\n") } + logInfo("Deregistering receiver " + streamId) + val future = trackerActor.ask(DeregisterReceiver(streamId, message))(askTimeout) + Await.result(future, askTimeout) + logInfo("Deregistered receiver " + streamId) + env.actorSystem.stop(actor) + logInfo("Stopped receiver " + streamId) } /** - * Stops the receiver. First it interrupts the main receiving thread, - * that is, the thread that called receiver.start(). Then it calls the user-defined - * onStop() method to stop other threads and/or do cleanup. + * Stop the receiver. First it interrupts the main receiving thread, + * that is, the thread that called receiver.start(). */ def stop() { + // Stop receiving by interrupting the receiving thread receivingThread.interrupt() - onStop() - // TODO: terminate the actor + logInfo("Interrupted receiving thread " + receivingThread + " for stopping") } /** - * Stops the receiver and reports exception to the tracker. + * Stop the receiver and reports exception to the tracker. * This should be called whenever an exception is to be handled on any thread * of the receiver. */ protected def stopOnError(e: Exception) { logError("Error receiving data", e) + exceptions += e stop() - actor ! ReportError(e.toString) } - /** - * Pushes a block (as an ArrayBuffer filled with data) into the block manager. + * Push a block (as an ArrayBuffer filled with data) into the block manager. */ def pushBlock(blockId: BlockId, arrayBuffer: ArrayBuffer[T], metadata: Any, level: StorageLevel) { env.blockManager.put(blockId, arrayBuffer.asInstanceOf[ArrayBuffer[Any]], level) - actor ! ReportBlock(blockId, metadata) + trackerActor ! AddBlocks(streamId, Array(blockId), metadata) + logDebug("Pushed block " + blockId) } /** - * Pushes a block (as bytes) into the block manager. + * Push a block (as bytes) into the block manager. 
*/ def pushBlock(blockId: BlockId, bytes: ByteBuffer, metadata: Any, level: StorageLevel) { env.blockManager.putBytes(blockId, bytes, level) - actor ! ReportBlock(blockId, metadata) + trackerActor ! AddBlocks(streamId, Array(blockId), metadata) + } + + /** Set the ID of the DStream that this receiver is associated with */ + protected[streaming] def setStreamId(id: Int) { + streamId = id } /** A helper actor that communicates with the NetworkInputTracker */ private class NetworkReceiverActor extends Actor { - logInfo("Attempting to register with tracker") - val ip = env.conf.get("spark.driver.host", "localhost") - val port = env.conf.getInt("spark.driver.port", 7077) - val url = "akka.tcp://spark@%s:%s/user/NetworkInputTracker".format(ip, port) - val tracker = env.actorSystem.actorSelection(url) - val timeout = 5.seconds override def preStart() { - val future = tracker.ask(RegisterReceiver(streamId, self))(timeout) - Await.result(future, timeout) + logInfo("Registered receiver " + streamId) + val future = trackerActor.ask(RegisterReceiver(streamId, self))(askTimeout) + Await.result(future, askTimeout) } override def receive() = { - case ReportBlock(blockId, metadata) => - tracker ! AddBlocks(streamId, Array(blockId), metadata) - case ReportError(msg) => - tracker ! DeregisterReceiver(streamId, msg) - case StopReceiver(msg) => + case StopReceiver => + logInfo("Received stop signal") stop() - tracker ! DeregisterReceiver(streamId, msg) } } - protected[streaming] def setStreamId(id: Int) { - streamId = id - } - /** * Batches objects created by a [[org.apache.spark.streaming.dstream.NetworkReceiver]] and puts * them into appropriately named blocks at regular intervals. This class starts two threads, @@ -214,23 +257,26 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging val clock = new SystemClock() val blockInterval = env.conf.getLong("spark.streaming.blockInterval", 200) - val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) + val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer, + "BlockGenerator") val blockStorageLevel = storageLevel val blocksForPushing = new ArrayBlockingQueue[Block](1000) val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } var currentBuffer = new ArrayBuffer[T] + var stopped = false def start() { blockIntervalTimer.start() blockPushingThread.start() - logInfo("Data handler started") + logInfo("Started BlockGenerator") } def stop() { - blockIntervalTimer.stop() - blockPushingThread.interrupt() - logInfo("Data handler stopped") + blockIntervalTimer.stop(false) + stopped = true + blockPushingThread.join() + logInfo("Stopped BlockGenerator") } def += (obj: T): Unit = synchronized { @@ -248,24 +294,35 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging } } catch { case ie: InterruptedException => - logInfo("Block interval timer thread interrupted") + logInfo("Block updating timer thread was interrupted") case e: Exception => - NetworkReceiver.this.stop() + NetworkReceiver.this.stopOnError(e) } } private def keepPushingBlocks() { - logInfo("Block pushing thread started") + logInfo("Started block pushing thread") try { - while(true) { + while(!stopped) { + Option(blocksForPushing.poll(100, TimeUnit.MILLISECONDS)) match { + case Some(block) => + NetworkReceiver.this.pushBlock(block.id, block.buffer, block.metadata, storageLevel) + case None => + } + } + // Push out the blocks that are still left + logInfo("Pushing out 
the last " + blocksForPushing.size() + " blocks") + while (!blocksForPushing.isEmpty) { val block = blocksForPushing.take() NetworkReceiver.this.pushBlock(block.id, block.buffer, block.metadata, storageLevel) + logInfo("Blocks left to push " + blocksForPushing.size()) } + logInfo("Stopped blocks pushing thread") } catch { case ie: InterruptedException => - logInfo("Block pushing thread interrupted") + logInfo("Block pushing thread was interrupted") case e: Exception => - NetworkReceiver.this.stop() + NetworkReceiver.this.stopOnError(e) } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index 2cdd13f205313..63d94d1cc670a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -67,7 +67,6 @@ class SocketReceiver[T: ClassTag]( protected def onStop() { blockGenerator.stop() } - } private[streaming] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala index bd78bae8a5c51..44eb2750c6c7a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala @@ -174,10 +174,10 @@ private[streaming] class ActorReceiver[T: ClassTag]( blocksGenerator.start() supervisor logInfo("Supervision tree for receivers initialized at:" + supervisor.path) + } protected def onStop() = { supervisor ! PoisonPill } - } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index c7306248b1950..92d885c4bc5a5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -39,16 +39,22 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { private val ssc = jobScheduler.ssc private val graph = ssc.graph + val clock = { val clockClass = ssc.sc.conf.get( "spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") Class.forName(clockClass).newInstance().asInstanceOf[Clock] } + private val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, - longTime => eventActor ! GenerateJobs(new Time(longTime))) - private lazy val checkpointWriter = - if (ssc.checkpointDuration != null && ssc.checkpointDir != null) { - new CheckpointWriter(this, ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration) + longTime => eventActor ! GenerateJobs(new Time(longTime)), "JobGenerator") + + // This is marked lazy so that this is initialized after checkpoint duration has been set + // in the context and the generator has been started. 
+ private lazy val shouldCheckpoint = ssc.checkpointDuration != null && ssc.checkpointDir != null + + private lazy val checkpointWriter = if (shouldCheckpoint) { + new CheckpointWriter(this, ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration) } else { null } @@ -57,17 +63,16 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // This not being null means the scheduler has been started and not stopped private var eventActor: ActorRef = null + // last batch whose completion,checkpointing and metadata cleanup has been completed + private var lastProcessedBatch: Time = null + /** Start generation of jobs */ - def start() = synchronized { - if (eventActor != null) { - throw new SparkException("JobGenerator already started") - } + def start(): Unit = synchronized { + if (eventActor != null) return // generator has already been started eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { def receive = { - case event: JobGeneratorEvent => - logDebug("Got event of type " + event.getClass.getName) - processEvent(event) + case event: JobGeneratorEvent => processEvent(event) } }), "JobGenerator") if (ssc.isCheckpointPresent) { @@ -77,30 +82,79 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { } } - /** Stop generation of jobs */ - def stop() = synchronized { - if (eventActor != null) { - timer.stop() - ssc.env.actorSystem.stop(eventActor) - if (checkpointWriter != null) checkpointWriter.stop() - ssc.graph.stop() - logInfo("JobGenerator stopped") + /** + * Stop generation of jobs. processReceivedData = true makes this wait until jobs + * of current ongoing time interval has been generated, processed and corresponding + * checkpoints written. + */ + def stop(processReceivedData: Boolean): Unit = synchronized { + if (eventActor == null) return // generator has already been stopped + + if (processReceivedData) { + logInfo("Stopping JobGenerator gracefully") + val timeWhenStopStarted = System.currentTimeMillis() + val stopTimeout = 10 * ssc.graph.batchDuration.milliseconds + val pollTime = 100 + + // To prevent graceful stop to get stuck permanently + def hasTimedOut = { + val timedOut = System.currentTimeMillis() - timeWhenStopStarted > stopTimeout + if (timedOut) logWarning("Timed out while stopping the job generator") + timedOut + } + + // Wait until all the received blocks in the network input tracker has + // been consumed by network input DStreams, and jobs have been generated with them + logInfo("Waiting for all received blocks to be consumed for job generation") + while(!hasTimedOut && jobScheduler.networkInputTracker.hasMoreReceivedBlockIds) { + Thread.sleep(pollTime) + } + logInfo("Waited for all received blocks to be consumed for job generation") + + // Stop generating jobs + val stopTime = timer.stop(false) + graph.stop() + logInfo("Stopped generation timer") + + // Wait for the jobs to complete and checkpoints to be written + def haveAllBatchesBeenProcessed = { + lastProcessedBatch != null && lastProcessedBatch.milliseconds == stopTime + } + logInfo("Waiting for jobs to be processed and checkpoints to be written") + while (!hasTimedOut && !haveAllBatchesBeenProcessed) { + Thread.sleep(pollTime) + } + logInfo("Waited for jobs to be processed and checkpoints to be written") + } else { + logInfo("Stopping JobGenerator immediately") + // Stop timer and graph immediately, ignore unprocessed data and pending jobs + timer.stop(true) + graph.stop() } + + // Stop the actor and checkpoint writer + if (shouldCheckpoint) checkpointWriter.stop() + 
ssc.env.actorSystem.stop(eventActor) + logInfo("Stopped JobGenerator") } /** - * On batch completion, clear old metadata and checkpoint computation. + * Callback called when a batch has been completely processed. */ def onBatchCompletion(time: Time) { eventActor ! ClearMetadata(time) } - + + /** + * Callback called when the checkpoint of a batch has been written. + */ def onCheckpointCompletion(time: Time) { eventActor ! ClearCheckpointData(time) } /** Processes all events */ private def processEvent(event: JobGeneratorEvent) { + logDebug("Got event " + event) event match { case GenerateJobs(time) => generateJobs(time) case ClearMetadata(time) => clearMetadata(time) @@ -114,7 +168,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val startTime = new Time(timer.getStartTime()) graph.start(startTime - graph.batchDuration) timer.start(startTime.milliseconds) - logInfo("JobGenerator started at " + startTime) + logInfo("Started JobGenerator at " + startTime) } /** Restarts the generator based on the information in checkpoint */ @@ -152,15 +206,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // Restart the timer timer.start(restartTime.milliseconds) - logInfo("JobGenerator restarted at " + restartTime) + logInfo("Restarted JobGenerator at " + restartTime) } /** Generate jobs and perform checkpoint for the given `time`. */ private def generateJobs(time: Time) { SparkEnv.set(ssc.env) Try(graph.generateJobs(time)) match { - case Success(jobs) => jobScheduler.runJobs(time, jobs) - case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) + case Success(jobs) => + jobScheduler.runJobs(time, jobs) + case Failure(e) => + jobScheduler.reportError("Error generating jobs for time " + time, e) } eventActor ! DoCheckpoint(time) } @@ -168,20 +224,32 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Clear DStream metadata for the given `time`. */ private def clearMetadata(time: Time) { ssc.graph.clearMetadata(time) - eventActor ! DoCheckpoint(time) + + // If checkpointing is enabled, then checkpoint, + // else mark batch to be fully processed + if (shouldCheckpoint) { + eventActor ! DoCheckpoint(time) + } else { + markBatchFullyProcessed(time) + } } /** Clear DStream checkpoint data for the given `time`. */ private def clearCheckpointData(time: Time) { ssc.graph.clearCheckpointData(time) + markBatchFullyProcessed(time) } /** Perform checkpoint for the give `time`. 
*/ - private def doCheckpoint(time: Time) = synchronized { - if (checkpointWriter != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) { + private def doCheckpoint(time: Time) { + if (shouldCheckpoint && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) { logInfo("Checkpointing graph for time " + time) ssc.graph.updateCheckpointData(time) checkpointWriter.write(new Checkpoint(ssc, time)) } } + + private def markBatchFullyProcessed(time: Time) { + lastProcessedBatch = time + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index de675d3c7fb94..04e0a6a283cfb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -39,7 +39,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { private val jobSets = new ConcurrentHashMap[Time, JobSet] private val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1) - private val executor = Executors.newFixedThreadPool(numConcurrentJobs) + private val jobExecutor = Executors.newFixedThreadPool(numConcurrentJobs) private val jobGenerator = new JobGenerator(this) val clock = jobGenerator.clock val listenerBus = new StreamingListenerBus() @@ -50,36 +50,54 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { private var eventActor: ActorRef = null - def start() = synchronized { - if (eventActor != null) { - throw new SparkException("JobScheduler already started") - } + def start(): Unit = synchronized { + if (eventActor != null) return // scheduler has already been started + logDebug("Starting JobScheduler") eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { def receive = { case event: JobSchedulerEvent => processEvent(event) } }), "JobScheduler") + listenerBus.start() networkInputTracker = new NetworkInputTracker(ssc) networkInputTracker.start() - Thread.sleep(1000) jobGenerator.start() - logInfo("JobScheduler started") + logInfo("Started JobScheduler") } - def stop() = synchronized { - if (eventActor != null) { - jobGenerator.stop() - networkInputTracker.stop() - executor.shutdown() - if (!executor.awaitTermination(2, TimeUnit.SECONDS)) { - executor.shutdownNow() - } - listenerBus.stop() - ssc.env.actorSystem.stop(eventActor) - logInfo("JobScheduler stopped") + def stop(processAllReceivedData: Boolean): Unit = synchronized { + if (eventActor == null) return // scheduler has already been stopped + logDebug("Stopping JobScheduler") + + // First, stop receiving + networkInputTracker.stop() + + // Second, stop generating jobs. If it has to process all received data, + // then this will wait for all the processing through JobScheduler to be over. 
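The processAllReceivedData flag threaded through this shutdown sequence is driven by the new user-facing StreamingContext.stop(stopSparkContext, stopGracefully) overload added in this patch. A hedged usage sketch follows; the socket source, port, and timings are illustrative only, and any registered input stream and output operation would do. The shutdown sequence then continues with jobGenerator.stop below.

    // Sketch of how an application would request a graceful shutdown.
    object GracefulStopSketch extends App {
      import org.apache.spark.streaming.{Seconds, StreamingContext}

      val ssc = new StreamingContext("local[2]", "graceful-stop-sketch", Seconds(1))
      // Illustrative source; registers both an input stream and an output operation
      ssc.socketTextStream("localhost", 9999).count().print()

      ssc.start()
      ssc.awaitTermination(5000) // let a few batches run

      // Waits for all received data to be processed before shutting down,
      // and keeps the underlying SparkContext alive for reuse
      ssc.stop(stopSparkContext = false, stopGracefully = true)
    }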
+ jobGenerator.stop(processAllReceivedData) + + // Stop the executor for receiving new jobs + logDebug("Stopping job executor") + jobExecutor.shutdown() + + // Wait for the queued jobs to complete if indicated + val terminated = if (processAllReceivedData) { + jobExecutor.awaitTermination(1, TimeUnit.HOURS) // just a very large period of time + } else { + jobExecutor.awaitTermination(2, TimeUnit.SECONDS) } + if (!terminated) { + jobExecutor.shutdownNow() + } + logDebug("Stopped job executor") + + // Stop everything else + listenerBus.stop() + ssc.env.actorSystem.stop(eventActor) + eventActor = null + logInfo("Stopped JobScheduler") } def runJobs(time: Time, jobs: Seq[Job]) { @@ -88,7 +106,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } else { val jobSet = new JobSet(time, jobs) jobSets.put(time, jobSet) - jobSet.jobs.foreach(job => executor.execute(new JobHandler(job))) + jobSet.jobs.foreach(job => jobExecutor.execute(new JobHandler(job))) logInfo("Added jobs for time " + time) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala index cad68e248ab29..067e804202236 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala @@ -17,20 +17,14 @@ package org.apache.spark.streaming.scheduler -import org.apache.spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver} -import org.apache.spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError} -import org.apache.spark.{SparkException, Logging, SparkEnv} -import org.apache.spark.SparkContext._ - -import scala.collection.mutable.HashMap -import scala.collection.mutable.Queue -import scala.concurrent.duration._ +import scala.collection.mutable.{HashMap, Queue, SynchronizedMap} import akka.actor._ -import akka.pattern.ask -import akka.dispatch._ +import org.apache.spark.{Logging, SparkEnv, SparkException} +import org.apache.spark.SparkContext._ import org.apache.spark.storage.BlockId -import org.apache.spark.streaming.{Time, StreamingContext} +import org.apache.spark.streaming.{StreamingContext, Time} +import org.apache.spark.streaming.dstream.{NetworkReceiver, StopReceiver} import org.apache.spark.util.AkkaUtils private[streaming] sealed trait NetworkInputTrackerMessage @@ -52,8 +46,8 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging { val networkInputStreams = ssc.graph.getNetworkInputStreams() val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*) val receiverExecutor = new ReceiverExecutor() - val receiverInfo = new HashMap[Int, ActorRef] - val receivedBlockIds = new HashMap[Int, Queue[BlockId]] + val receiverInfo = new HashMap[Int, ActorRef] with SynchronizedMap[Int, ActorRef] + val receivedBlockIds = new HashMap[Int, Queue[BlockId]] with SynchronizedMap[Int, Queue[BlockId]] val timeout = AkkaUtils.askTimeout(ssc.conf) @@ -63,7 +57,7 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging { var currentTime: Time = null /** Start the actor and receiver execution thread. */ - def start() { + def start() = synchronized { if (actor != null) { throw new SparkException("NetworkInputTracker already started") } @@ -77,72 +71,99 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging { } /** Stop the receiver execution thread. 
*/ - def stop() { + def stop() = synchronized { if (!networkInputStreams.isEmpty && actor != null) { - receiverExecutor.interrupt() - receiverExecutor.stopReceivers() + // First, stop the receivers + receiverExecutor.stop() + + // Finally, stop the actor ssc.env.actorSystem.stop(actor) + actor = null logInfo("NetworkInputTracker stopped") } } - /** Return all the blocks received from a receiver. */ - def getBlockIds(receiverId: Int, time: Time): Array[BlockId] = synchronized { - val queue = receivedBlockIds.synchronized { - receivedBlockIds.getOrElse(receiverId, new Queue[BlockId]()) + /** Register a receiver */ + def registerReceiver(streamId: Int, receiverActor: ActorRef, sender: ActorRef) { + if (!networkInputStreamMap.contains(streamId)) { + throw new Exception("Register received for unexpected id " + streamId) } - val result = queue.synchronized { - queue.dequeueAll(x => true) - } - logInfo("Stream " + receiverId + " received " + result.size + " blocks") - result.toArray + receiverInfo += ((streamId, receiverActor)) + logInfo("Registered receiver for network stream " + streamId + " from " + sender.path.address) + } + + /** Deregister a receiver */ + def deregisterReceiver(streamId: Int, message: String) { + receiverInfo -= streamId + logError("Deregistered receiver for network stream " + streamId + " with message:\n" + message) + } + + /** Get all the received blocks for the given stream. */ + def getBlocks(streamId: Int, time: Time): Array[BlockId] = { + val queue = receivedBlockIds.getOrElseUpdate(streamId, new Queue[BlockId]()) + val result = queue.dequeueAll(x => true).toArray + logInfo("Stream " + streamId + " received " + result.size + " blocks") + result + } + + /** Add new blocks for the given stream */ + def addBlocks(streamId: Int, blockIds: Seq[BlockId], metadata: Any) = { + val queue = receivedBlockIds.getOrElseUpdate(streamId, new Queue[BlockId]) + queue ++= blockIds + networkInputStreamMap(streamId).addMetadata(metadata) + logDebug("Stream " + streamId + " received new blocks: " + blockIds.mkString("[", ", ", "]")) + } + + /** Check if any blocks are left to be processed */ + def hasMoreReceivedBlockIds: Boolean = { + !receivedBlockIds.forall(_._2.isEmpty) } /** Actor to receive messages from the receivers. */ private class NetworkInputTrackerActor extends Actor { def receive = { - case RegisterReceiver(streamId, receiverActor) => { - if (!networkInputStreamMap.contains(streamId)) { - throw new Exception("Register received for unexpected id " + streamId) - } - receiverInfo += ((streamId, receiverActor)) - logInfo("Registered receiver for network stream " + streamId + " from " - + sender.path.address) + case RegisterReceiver(streamId, receiverActor) => + registerReceiver(streamId, receiverActor, sender) + sender ! true + case AddBlocks(streamId, blockIds, metadata) => + addBlocks(streamId, blockIds, metadata) + case DeregisterReceiver(streamId, message) => + deregisterReceiver(streamId, message) sender ! 
true - } - case AddBlocks(streamId, blockIds, metadata) => { - val tmp = receivedBlockIds.synchronized { - if (!receivedBlockIds.contains(streamId)) { - receivedBlockIds += ((streamId, new Queue[BlockId])) - } - receivedBlockIds(streamId) - } - tmp.synchronized { - tmp ++= blockIds - } - networkInputStreamMap(streamId).addMetadata(metadata) - } - case DeregisterReceiver(streamId, msg) => { - receiverInfo -= streamId - logError("De-registered receiver for network stream " + streamId - + " with message " + msg) - // TODO: Do something about the corresponding NetworkInputDStream - } } } /** This thread class runs all the receivers on the cluster. */ - class ReceiverExecutor extends Thread { - val env = ssc.env - - override def run() { - try { - SparkEnv.set(env) - startReceivers() - } catch { - case ie: InterruptedException => logInfo("ReceiverExecutor interrupted") - } finally { - stopReceivers() + class ReceiverExecutor { + @transient val env = ssc.env + @transient val thread = new Thread() { + override def run() { + try { + SparkEnv.set(env) + startReceivers() + } catch { + case ie: InterruptedException => logInfo("ReceiverExecutor interrupted") + } + } + } + + def start() { + thread.start() + } + + def stop() { + // Send the stop signal to all the receivers + stopReceivers() + + // Wait for the Spark job that runs the receivers to be over + // That is, for the receivers to quit gracefully. + thread.join(10000) + + // Check if all the receivers have been deregistered or not + if (!receiverInfo.isEmpty) { + logWarning("All of the receivers have not deregistered, " + receiverInfo) + } else { + logInfo("All of the receivers have deregistered successfully") } } @@ -150,7 +171,7 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging { * Get the receivers from the NetworkInputDStreams, distributes them to the * worker nodes as a parallel collection, and runs them. */ - def startReceivers() { + private def startReceivers() { val receivers = networkInputStreams.map(nis => { val rcvr = nis.getReceiver() rcvr.setStreamId(nis.id) @@ -186,13 +207,16 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging { } // Distribute the receivers and start them + logInfo("Starting " + receivers.length + " receivers") ssc.sparkContext.runJob(tempRDD, startReceiver) + logInfo("All of the receivers have been terminated") } /** Stops the receivers. */ - def stopReceivers() { + private def stopReceivers() { // Signal the receivers to stop receiverInfo.values.foreach(_ ! 
StopReceiver) + logInfo("Sent stop signal to all " + receiverInfo.size + " receivers") } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala index c3a849d2769a7..c5ef2cc8c390d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala @@ -48,14 +48,11 @@ class SystemClock() extends Clock { minPollTime } } - - + while (true) { currentTime = System.currentTimeMillis() waitTime = targetTime - currentTime - if (waitTime <= 0) { - return currentTime } val sleepTime = diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala index 559c2473851b3..f71938ac55ccb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala @@ -17,44 +17,84 @@ package org.apache.spark.streaming.util +import org.apache.spark.Logging + private[streaming] -class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => Unit) { +class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: String) + extends Logging { - private val thread = new Thread("RecurringTimer") { + private val thread = new Thread("RecurringTimer - " + name) { + setDaemon(true) override def run() { loop } } - - private var nextTime = 0L + @volatile private var prevTime = -1L + @volatile private var nextTime = -1L + @volatile private var stopped = false + + /** + * Get the time when this timer will fire if it is started right now. + * The time will be a multiple of this timer's period and more than + * current system time. + */ def getStartTime(): Long = { (math.floor(clock.currentTime.toDouble / period) + 1).toLong * period } + /** + * Get the time when the timer will fire if it is restarted right now. + * This time depends on when the timer was started the first time, and was stopped + * for whatever reason. The time must be a multiple of this timer's period and + * more than current time. + */ def getRestartTime(originalStartTime: Long): Long = { val gap = clock.currentTime - originalStartTime (math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime } - def start(startTime: Long): Long = { + /** + * Start at the given start time. + */ + def start(startTime: Long): Long = synchronized { nextTime = startTime thread.start() + logInfo("Started timer for " + name + " at time " + nextTime) nextTime } + /** + * Start at the earliest time it can start based on the period. + */ def start(): Long = { start(getStartTime()) } - def stop() { - thread.interrupt() + /** + * Stop the timer, and return the last time the callback was made. + * interruptTimer = true will interrupt the callback + * if it is in progress (not guaranteed to give correct time in this case). + */ + def stop(interruptTimer: Boolean): Long = synchronized { + if (!stopped) { + stopped = true + if (interruptTimer) thread.interrupt() + thread.join() + logInfo("Stopped timer for " + name + " after time " + prevTime) + } + prevTime } - + + /** + * Repeatedly call the callback every interval. 
+ */ private def loop() { try { - while (true) { + while (!stopped) { clock.waitTillTime(nextTime) callback(nextTime) + prevTime = nextTime nextTime += period + logDebug("Callback for " + name + " called at time " + prevTime) } } catch { case e: InterruptedException => @@ -74,10 +114,10 @@ object RecurringTimer { println("" + currentTime + ": " + (currentTime - lastRecurTime)) lastRecurTime = currentTime } - val timer = new RecurringTimer(new SystemClock(), period, onRecur) + val timer = new RecurringTimer(new SystemClock(), period, onRecur, "Test") timer.start() Thread.sleep(30 * 1000) - timer.stop() + timer.stop(true) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index bcb0c28bf07a0..bb73dbf29b649 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -324,7 +324,7 @@ class BasicOperationsSuite extends TestSuiteBase { val updateStateOperation = (s: DStream[String]) => { val updateFunc = (values: Seq[Int], state: Option[Int]) => { - Some(values.foldLeft(0)(_ + _) + state.getOrElse(0)) + Some(values.sum + state.getOrElse(0)) } s.map(x => (x, 1)).updateStateByKey[Int](updateFunc) } @@ -359,7 +359,7 @@ class BasicOperationsSuite extends TestSuiteBase { // updateFunc clears a state when a StateObject is seen without new values twice in a row val updateFunc = (values: Seq[Int], state: Option[StateObject]) => { val stateObj = state.getOrElse(new StateObject) - values.foldLeft(0)(_ + _) match { + values.sum match { case 0 => stateObj.expireCounter += 1 // no new values case n => { // has new values, increment and reset expireCounter stateObj.counter += n diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 717da8e00462b..9cc27ef7f03b5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -17,19 +17,22 @@ package org.apache.spark.streaming -import org.scalatest.{FunSuite, BeforeAndAfter} -import org.scalatest.exceptions.TestFailedDueToTimeoutException +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.dstream.{DStream, NetworkReceiver} +import org.apache.spark.util.{MetadataCleaner, Utils} +import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Timeouts +import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkException, SparkConf, SparkContext} -import org.apache.spark.util.{Utils, MetadataCleaner} -import org.apache.spark.streaming.dstream.DStream -class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { +class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging { val master = "local[2]" val appName = this.getClass.getSimpleName - val batchDuration = Seconds(1) + val batchDuration = Milliseconds(500) val sparkHome = "someDir" val envPair = "key" -> "value" val ttl = StreamingContext.DEFAULT_CLEANER_TTL + 100 @@ -108,19 +111,31 @@ class StreamingContextSuite extends FunSuite with 
BeforeAndAfter with Timeouts { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) myConf.set("spark.cleaner.ttl", ttl.toString) val ssc1 = new StreamingContext(myConf, batchDuration) + addInputStream(ssc1).register + ssc1.start() val cp = new Checkpoint(ssc1, Time(1000)) assert(MetadataCleaner.getDelaySeconds(cp.sparkConf) === ttl) ssc1.stop() val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) assert(MetadataCleaner.getDelaySeconds(newCp.sparkConf) === ttl) - ssc = new StreamingContext(null, cp, null) + ssc = new StreamingContext(null, newCp, null) assert(MetadataCleaner.getDelaySeconds(ssc.conf) === ttl) } - test("start multiple times") { + test("start and stop state check") { ssc = new StreamingContext(master, appName, batchDuration) addInputStream(ssc).register + assert(ssc.state === ssc.StreamingContextState.Initialized) + ssc.start() + assert(ssc.state === ssc.StreamingContextState.Started) + ssc.stop() + assert(ssc.state === ssc.StreamingContextState.Stopped) + } + + test("start multiple times") { + ssc = new StreamingContext(master, appName, batchDuration) + addInputStream(ssc).register ssc.start() intercept[SparkException] { ssc.start() @@ -133,18 +148,61 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { ssc.start() ssc.stop() ssc.stop() - ssc = null } + test("stop before start and start after stop") { + ssc = new StreamingContext(master, appName, batchDuration) + addInputStream(ssc).register + ssc.stop() // stop before start should not throw exception + ssc.start() + ssc.stop() + intercept[SparkException] { + ssc.start() // start after stop should throw exception + } + } + + test("stop only streaming context") { ssc = new StreamingContext(master, appName, batchDuration) sc = ssc.sparkContext addInputStream(ssc).register ssc.start() ssc.stop(false) - ssc = null assert(sc.makeRDD(1 to 100).collect().size === 100) ssc = new StreamingContext(sc, batchDuration) + addInputStream(ssc).register + ssc.start() + ssc.stop() + } + + test("stop gracefully") { + val conf = new SparkConf().setMaster(master).setAppName(appName) + conf.set("spark.cleaner.ttl", "3600") + sc = new SparkContext(conf) + for (i <- 1 to 4) { + logInfo("==================================") + ssc = new StreamingContext(sc, batchDuration) + var runningCount = 0 + TestReceiver.counter.set(1) + val input = ssc.networkStream(new TestReceiver) + input.count.foreachRDD(rdd => { + val count = rdd.first() + logInfo("Count = " + count) + runningCount += count.toInt + }) + ssc.start() + ssc.awaitTermination(500) + ssc.stop(stopSparkContext = false, stopGracefully = true) + logInfo("Running count = " + runningCount) + logInfo("TestReceiver.counter = " + TestReceiver.counter.get()) + assert(runningCount > 0) + assert( + (TestReceiver.counter.get() == runningCount + 1) || + (TestReceiver.counter.get() == runningCount + 2), + "Received records = " + TestReceiver.counter.get() + ", " + + "processed records = " + runningCount + ) + } } test("awaitTermination") { @@ -199,7 +257,6 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { test("awaitTermination with error in job generation") { ssc = new StreamingContext(master, appName, batchDuration) val inputStream = addInputStream(ssc) - inputStream.transform(rdd => { throw new TestException("error in transform"); rdd }).register val exception = intercept[TestException] { ssc.start() @@ -215,4 +272,29 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { } } 
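The tests above pin down the new start/stop state machine (Initialized, Started, Stopped). The following is a hedged, self-contained sketch of that lifecycle from an application's point of view; the socket source, port, and object name are illustrative, not taken from the patch.

    // Sketch of the lifecycle behaviour these tests verify.
    object LifecycleSketch extends App {
      import org.apache.spark.SparkException
      import org.apache.spark.streaming.{Seconds, StreamingContext}

      val ssc = new StreamingContext("local[2]", "lifecycle-sketch", Seconds(1))
      ssc.socketTextStream("localhost", 9999).print() // register an input and an output

      ssc.stop()  // stop before start: only logs a warning
      ssc.start() // Initialized -> Started
      try { ssc.start() } catch { case e: SparkException => println("already started") }
      ssc.stop()  // Started -> Stopped (also stops the SparkContext by default)
      ssc.stop()  // stopping twice: only logs a warning
      try { ssc.start() } catch { case e: SparkException => println("already stopped") }
    }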
-class TestException(msg: String) extends Exception(msg) \ No newline at end of file +class TestException(msg: String) extends Exception(msg) + +/** Custom receiver for testing whether all data received by a receiver gets processed or not */ +class TestReceiver extends NetworkReceiver[Int] { + protected lazy val blockGenerator = new BlockGenerator(StorageLevel.MEMORY_ONLY) + protected def onStart() { + blockGenerator.start() + logInfo("BlockGenerator started on thread " + receivingThread) + try { + while(true) { + blockGenerator += TestReceiver.counter.getAndIncrement + Thread.sleep(0) + } + } finally { + logInfo("Receiving stopped at count value of " + TestReceiver.counter.get()) + } + } + + protected def onStop() { + blockGenerator.stop() + } +} + +object TestReceiver { + val counter = new AtomicInteger(1) +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 201630672ab4c..aa2d5c2fc2454 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -277,7 +277,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") assert(output.size === numExpectedOutput, "Unexpected number of outputs generated") - Thread.sleep(500) // Give some time for the forgetting old RDDs to complete + Thread.sleep(100) // Give some time for the forgetting old RDDs to complete } catch { case e: Exception => {e.printStackTrace(); throw e} } finally {