diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
index 661c11f8de53c..21efe2333fc90 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
@@ -51,9 +51,9 @@ private[streaming] class BlockGenerator(
   private val clock = new SystemClock()
   private val blockInterval = conf.getLong("spark.streaming.blockInterval", 200)
   private val blockIntervalTimer =
-    new RecurringTimer(clock, blockInterval, updateCurrentBuffer,
-      "BlockGenerator")
-  private val blocksForPushing = new ArrayBlockingQueue[Block](10)
+    new RecurringTimer(clock, blockInterval, updateCurrentBuffer, "BlockGenerator")
+  private val blockQueueSize = conf.getInt("spark.streaming.blockQueueSize", 10)
+  private val blocksForPushing = new ArrayBlockingQueue[Block](blockQueueSize)
   private val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } }
 
   @volatile private var currentBuffer = new ArrayBuffer[Any]
@@ -68,8 +68,10 @@ private[streaming] class BlockGenerator(
 
   /** Stop all threads. */
   def stop() {
-    blockIntervalTimer.stop(false)
+    logInfo("Stopping BlockGenerator")
+    blockIntervalTimer.stop(interruptTimer = false)
     stopped = true
+    logInfo("Waiting for block pushing thread")
     blockPushingThread.join()
     logInfo("Stopped BlockGenerator")
   }
@@ -90,7 +92,7 @@ private[streaming] class BlockGenerator(
       if (newBlockBuffer.size > 0) {
         val blockId = StreamBlockId(receiverId, time - blockInterval)
         val newBlock = new Block(blockId, newBlockBuffer)
-        blocksForPushing.add(newBlock)
+        blocksForPushing.put(newBlock)  // put is blocking when queue is full
         logDebug("Last element in " + blockId + " is " + newBlockBuffer.last)
       }
     } catch {
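Note: the add-to-put switch above is the heart of this change. A minimal standalone sketch of the queue semantics it relies on (illustrative object name, not part of the patch): ArrayBlockingQueue.add throws IllegalStateException when the queue is full, while put blocks until space frees up, so a slow block-pushing thread now applies backpressure to the timer thread instead of crashing it.

    import java.util.concurrent.ArrayBlockingQueue

    object QueueSemanticsSketch extends App {
      val queue = new ArrayBlockingQueue[Int](1)
      queue.add(1)       // succeeds, the queue is now full
      // queue.add(2)    // would throw java.lang.IllegalStateException: Queue full
      new Thread() {
        override def run() {
          Thread.sleep(100)
          queue.take()   // consumer frees a slot
        }
      }.start()
      queue.put(2)       // blocks until the consumer takes an element
      println("put() returned after the consumer freed a slot")
    }
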
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutor.scala
index 77c53112493c9..01b9283568dcf 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutor.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutor.scala
@@ -23,6 +23,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.{Logging, SparkConf}
 import org.apache.spark.storage.StreamBlockId
+import java.util.concurrent.CountDownLatch
 
 /**
  * Abstract class that is responsible for executing a NetworkReceiver in the worker.
@@ -39,10 +40,10 @@ private[streaming] abstract class NetworkReceiverExecutor(
   protected val receiverId = receiver.receiverId
 
   /** Thread that starts the receiver and stays blocked while data is being received. */
-  @volatile protected var receivingThread: Option[Thread] = None
+  @volatile protected var executionThread: Option[Thread] = None
 
   /** Has the receiver been marked for stop. */
-  @volatile private var stopped = false
+  val stopLatch = new CountDownLatch(1)
 
   /** Push a single data item to backend data store. */
   def pushSingle(data: Any)
@@ -77,17 +78,15 @@ private[streaming] abstract class NetworkReceiverExecutor(
    */
  def run() {
     // Remember this thread as the receiving thread
-    receivingThread = Some(Thread.currentThread())
+    executionThread = Some(Thread.currentThread())
 
     try {
       // Call user-defined onStart()
       logInfo("Calling onStart")
       receiver.onStart()
 
-      // Wait until interrupt is called on this thread
-      while(true) {
-        Thread.sleep(100)
-      }
+      awaitStop()
+      logInfo("Done waiting for stop")
     } catch {
       case ie: InterruptedException =>
         logInfo("Receiving thread has been interrupted, receiver " + receiverId + " stopped")
@@ -106,27 +105,17 @@ private[streaming] abstract class NetworkReceiverExecutor(
   }
 
   /**
-   * Stop receiving data.
+   * Mark the executor and the receiver as stopped
    */
   def stop() {
-    // Mark has stopped
-
-    if (receivingThread.isDefined) {
-      // Interrupt the thread
-      receivingThread.get.interrupt()
-
-      // Wait for the receiving thread to finish on its own
-      receivingThread.get.join(conf.getLong("spark.streaming.receiverStopTimeout", 2000))
-
-      // Stop receiving by interrupting the receiving thread
-      receivingThread.get.interrupt()
-      logInfo("Interrupted receiving thread of receiver " + receiverId + " for stopping")
-    }
-
-    stopped = true
-    logInfo("Marked as stop")
+    // Mark for stop
+    stopLatch.countDown()
+    logInfo("Marked for stop")
   }
 
-  /** Check if receiver has been marked for stopping. */
-  def isStopped = stopped
+  /** Check if the receiver has been marked for stopping. */
+  def isStopped() = (stopLatch.getCount == 0L)
+
+  /** Block the calling thread until the executor has been stopped. */
+  def awaitStop() = stopLatch.await()
 }
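Note: a self-contained sketch (illustrative names, not part of the patch) of the latch-based lifecycle that replaces the sleep/interrupt loop above. run() parks on the latch instead of polling, and stop() releases every waiter exactly once.

    import java.util.concurrent.CountDownLatch

    class StoppableExecutorSketch {
      private val stopLatch = new CountDownLatch(1)

      def run() {
        // start resources here (e.g. the receiver)
        stopLatch.await()   // park here instead of a Thread.sleep loop
        // tear down resources here, after stop() has released the latch
      }

      def stop() { stopLatch.countDown() }

      def isStopped = stopLatch.getCount == 0L
    }

A CountDownLatch initialized to 1 gives one-shot, idempotent stop semantics: repeated stop() calls are harmless, and awaitStop() returns immediately once the latch has reached zero.
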
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutorImpl.scala
index 173eb88276684..dcdd14637e3d7 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutorImpl.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutorImpl.scala
@@ -161,9 +161,7 @@ private[streaming] class NetworkReceiverExecutorImpl(
   override def run() {
     // Starting the block generator
     blockGenerator.start()
-
     super.run()
-
     // Stopping BlockGenerator
     blockGenerator.stop()
     reportStop()
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index e564eccba2df5..d50b270124faa 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -38,6 +38,7 @@ private[streaming]
 class JobGenerator(jobScheduler: JobScheduler) extends Logging {
 
   private val ssc = jobScheduler.ssc
+  private val conf = ssc.conf
   private val graph = ssc.graph
 
   val clock = {
@@ -93,13 +94,18 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
     if (processReceivedData) {
       logInfo("Stopping JobGenerator gracefully")
       val timeWhenStopStarted = System.currentTimeMillis()
-      val stopTimeout = 10 * ssc.graph.batchDuration.milliseconds
+      val stopTimeout = conf.getLong(
+        "spark.streaming.gracefulStopTimeout",
+        10 * ssc.graph.batchDuration.milliseconds
+      )
       val pollTime = 100
 
       // To prevent graceful stop to get stuck permanently
       def hasTimedOut = {
         val timedOut = System.currentTimeMillis() - timeWhenStopStarted > stopTimeout
-        if (timedOut) logWarning("Timed out while stopping the job generator")
+        if (timedOut) {
+          logWarning("Timed out while stopping the job generator (timeout = " + stopTimeout + " ms)")
+        }
         timedOut
       }
@@ -112,7 +118,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
       logInfo("Waited for all received blocks to be consumed for job generation")
 
       // Stop generating jobs
-      val stopTime = timer.stop(interruptTimer = false)
+      val stopTime = timer.stop(interruptTimer = false)
       graph.stop()
       logInfo("Stopped generation timer")
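Note: with the graceful-stop timeout now read from SparkConf instead of being hardcoded, it can be tuned per application. A hedged usage sketch (the property name comes from the change above; the app name and value are illustrative):

    import org.apache.spark.SparkConf

    object GracefulStopConfSketch {
      val conf = new SparkConf()
        .setAppName("MyStreamingApp")                         // illustrative name
        .set("spark.streaming.gracefulStopTimeout", "30000")  // 30 s instead of 10 x batchDuration
    }
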
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
index e016377c94c0d..1a616a0434f2c 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
@@ -77,7 +77,9 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name:
   def stop(interruptTimer: Boolean): Long = synchronized {
     if (!stopped) {
       stopped = true
-      if (interruptTimer) thread.interrupt()
+      if (interruptTimer) {
+        thread.interrupt()
+      }
       thread.join()
       logInfo("Stopped timer for " + name + " after time " + prevTime)
     }
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala
index 4c3ac00cf36b0..5e0a9d7238ac9 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala
@@ -28,36 +28,39 @@ import org.scalatest.time.SpanSugar._
 import org.apache.spark.SparkConf
 import org.apache.spark.storage.{StorageLevel, StreamBlockId}
 import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, NetworkReceiver, NetworkReceiverExecutor}
-
+/** Test suite for the network receiver behavior. */
 class NetworkReceiverSuite extends FunSuite with Timeouts {
 
-  test("network receiver with fake executor") {
-    val receiver = new MockReceiver
-    val executor = new MockReceiverExecutor(receiver)
+  test("network receiver life cycle") {
+    val receiver = new FakeReceiver
+    val executor = new FakeReceiverExecutor(receiver)
 
-    val receivingThread = new Thread() {
+    // Thread that runs the executor
+    val executingThread = new Thread() {
       override def run() {
         println("Running receiver")
         executor.run()
         println("Finished receiver")
       }
     }
-    receivingThread.start()
 
-    // Verify that NetworkReceiver.run() blocks
+    // Start the receiver
+    executingThread.start()
+
+    // Verify that the executor stays blocked while the receiver is running
     intercept[Exception] {
       failAfter(200 millis) {
-        receivingThread.join()
+        executingThread.join()
      }
     }
 
     // Verify that onStart was called, and onStop wasn't called
     assert(receiver.started)
+    assert(receiver.otherThread.isAlive)
     assert(!receiver.stopped)
     assert(executor.isAllEmpty)
 
-    // Verify whether the data stored by the receiver was
-    // sent to the executor
+    // Verify whether the data stored by the receiver was sent to the executor
     val byteBuffer = ByteBuffer.allocate(100)
     val arrayBuffer = new ArrayBuffer[Int]()
     val iterator = arrayBuffer.iterator
@@ -74,17 +77,17 @@ class NetworkReceiverSuite extends FunSuite with Timeouts {
     assert(executor.arrayBuffers.size === 1)
     assert(executor.arrayBuffers.head.eq(arrayBuffer))
 
-    // Verify whether the exceptions reported by the receiver
-    // was sent to the executor
+    // Verify whether the exceptions reported by the receiver were sent to the executor
     val exception = new Exception
     receiver.reportError("Error", exception)
     assert(executor.errors.size === 1)
     assert(executor.errors.head.eq(exception))
 
     // Verify that stopping actually stops the thread
-    failAfter(500 millis) {
+    failAfter(100 millis) {
       receiver.stop()
-      receivingThread.join()
+      executingThread.join()
+      assert(!receiver.otherThread.isAlive)
     }
 
     // Verify that onStop was called
@@ -92,7 +95,7 @@ class NetworkReceiverSuite extends FunSuite with Timeouts {
   }
 
   test("block generator") {
-    val blockGeneratorListener = new MockBlockGeneratorListener
+    val blockGeneratorListener = new FakeBlockGeneratorListener
     val blockInterval = 200
     val conf = new SparkConf().set("spark.streaming.blockInterval", blockInterval.toString)
     val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf)
@@ -114,25 +117,47 @@ class NetworkReceiverSuite extends FunSuite with Timeouts {
     val recordedData = blockGeneratorListener.arrayBuffers.flatten
     assert(blockGeneratorListener.arrayBuffers.size > 0)
-    assert(recordedData.size <= count)
-    //assert(generatedData.toList === recordedData.toList)
+    assert(recordedData.toSet === generatedData.toSet)
   }
 }
 
-class MockReceiver extends NetworkReceiver[Int](StorageLevel.MEMORY_ONLY) {
+/**
+ * An implementation of NetworkReceiver that is used for testing a receiver's life cycle.
+ */
+class FakeReceiver extends NetworkReceiver[Int](StorageLevel.MEMORY_ONLY) {
   var started = false
   var stopped = false
 
-  def onStart() { started = true }
-  def onStop() { stopped = true }
+  val otherThread = new Thread() {
+    override def run() {
+      while(!stopped) {
+        Thread.sleep(10)
+      }
+    }
+  }
+
+  def onStart() {
+    otherThread.start()
+    started = true
+  }
+  def onStop() {
+    stopped = true
+    otherThread.join()
+  }
 }
 
-class MockReceiverExecutor(receiver: MockReceiver) extends NetworkReceiverExecutor(receiver) {
+/**
+ * An implementation of NetworkReceiverExecutor used for testing a NetworkReceiver.
+ * Instead of storing the data in the BlockManager, it stores all the data in a local buffer
+ * that can be used for verifying that the data has been forwarded correctly.
+ */
+class FakeReceiverExecutor(receiver: FakeReceiver) extends NetworkReceiverExecutor(receiver) {
   val singles = new ArrayBuffer[Any]
   val byteBuffers = new ArrayBuffer[ByteBuffer]
   val iterators = new ArrayBuffer[Iterator[_]]
   val arrayBuffers = new ArrayBuffer[ArrayBuffer[_]]
   val errors = new ArrayBuffer[Throwable]
 
+  /** Check if all the data structures are empty. */
   def isAllEmpty = {
     singles.isEmpty && byteBuffers.isEmpty && iterators.isEmpty &&
       arrayBuffers.isEmpty && errors.isEmpty
@@ -171,16 +196,21 @@ class MockReceiverExecutor(receiver: MockReceiver) extends NetworkReceiverExecut
   }
 }
 
-class MockBlockGeneratorListener extends BlockGeneratorListener {
+/**
+ * An implementation of BlockGeneratorListener that is used to test the BlockGenerator.
+ */
+class FakeBlockGeneratorListener(pushDelay: Long = 0) extends BlockGeneratorListener {
+  // Buffer of data received as ArrayBuffers
   val arrayBuffers = new ArrayBuffer[ArrayBuffer[Int]]
   val errors = new ArrayBuffer[Throwable]
 
   def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]) {
     val bufferOfInts = arrayBuffer.map(_.asInstanceOf[Int])
     arrayBuffers += bufferOfInts
+    Thread.sleep(pushDelay)
   }
 
   def onError(message: String, throwable: Throwable) {
     errors += throwable
   }
-}
+}
\ No newline at end of file
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index 1b81f2643cc51..b88f26de3869a 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -55,7 +55,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
       sc = null
     }
   }
-/*
+
   test("from no conf constructor") {
     ssc = new StreamingContext(master, appName, batchDuration)
     assert(ssc.sparkContext.conf.get("spark.master") === master)
@@ -174,22 +174,24 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
     ssc.start()
     ssc.stop()
   }
-*/
+
   test("stop gracefully") {
     val conf = new SparkConf().setMaster(master).setAppName(appName)
     conf.set("spark.cleaner.ttl", "3600")
     sc = new SparkContext(conf)
     for (i <- 1 to 4) {
       logInfo("==================================")
-      ssc = new StreamingContext(sc, batchDuration)
+      logInfo("Round " + i)
+      ssc = new StreamingContext(sc, Milliseconds(100))
       var runningCount = 0
+      val startTime = System.currentTimeMillis()
       TestReceiver.counter.set(1)
       val input = ssc.networkStream(new TestReceiver)
       input.count.foreachRDD(rdd => {
         val count = rdd.first()
         runningCount += count.toInt
         logInfo("Count = " + count + ", Running count = " + runningCount)
 
       })
       ssc.start()
       ssc.awaitTermination(500)
@@ -203,9 +205,10 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
         "Received records = " + TestReceiver.counter.get() + ", " +
         "processed records = " + runningCount
       )
+      logInfo("Time taken = " + (System.currentTimeMillis() - startTime) + " ms")
     }
   }
-/*
+
   test("awaitTermination") {
     ssc = new StreamingContext(master, appName, batchDuration)
     val inputStream = addInputStream(ssc)
@@ -265,7 +268,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
     }
     assert(exception.getMessage.contains("transform"), "Expected exception not thrown")
   }
-*/
+
   def addInputStream(s: StreamingContext): DStream[Int] = {
     val input = (1 to 100).map(i => (1 to i))
     val inputStream = new TestInputStream(s, input, 1)
@@ -273,22 +276,29 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
   }
 }
 
 class TestException(msg: String) extends Exception(msg)
 
 /** Custom receiver for testing whether all data received by a receiver gets processed or not */
 class TestReceiver extends NetworkReceiver[Int](StorageLevel.MEMORY_ONLY) with Logging {
+
+  var receivingThreadOption: Option[Thread] = None
+
   def onStart() {
-    try {
-      while(true) {
-        store(TestReceiver.counter.getAndIncrement)
-        Thread.sleep(0)
+    val thread = new Thread() {
+      override def run() {
+        while (!isStopped) {
+          store(TestReceiver.counter.getAndIncrement)
+        }
+        logInfo("Receiving stopped at count value of " + TestReceiver.counter.get())
       }
-    } finally {
-      logInfo("Receiving stopped at count value of " + TestReceiver.counter.get())
     }
+    receivingThreadOption = Some(thread)
+    thread.start()
   }
 
-  def onStop() { }
+  def onStop() {
+    // no cleanup to be done, the receiving thread should stop on its own
+  }
 }
 
 object TestReceiver {