Fixed graceful shutdown by removing interrupts on receiving thread.

tdas committed Apr 15, 2014
1 parent 9e37a0b commit 2c94579

Showing 7 changed files with 112 additions and 74 deletions.
@@ -51,9 +51,9 @@ private[streaming] class BlockGenerator(
   private val clock = new SystemClock()
   private val blockInterval = conf.getLong("spark.streaming.blockInterval", 200)
   private val blockIntervalTimer =
-    new RecurringTimer(clock, blockInterval, updateCurrentBuffer,
-      "BlockGenerator")
-  private val blocksForPushing = new ArrayBlockingQueue[Block](10)
+    new RecurringTimer(clock, blockInterval, updateCurrentBuffer, "BlockGenerator")
+  private val blockQueueSize = conf.getInt("spark.streaming.blockQueueSize", 10)
+  private val blocksForPushing = new ArrayBlockingQueue[Block](blockQueueSize)
   private val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } }

   @volatile private var currentBuffer = new ArrayBuffer[Any]
@@ -68,8 +68,10 @@ private[streaming] class BlockGenerator(

   /** Stop all threads. */
   def stop() {
-    blockIntervalTimer.stop(false)
+    logInfo("Stopping BlockGenerator")
+    blockIntervalTimer.stop(interruptTimer = false)
     stopped = true
+    logInfo("Waiting for block pushing thread")
     blockPushingThread.join()
     logInfo("Stopped BlockGenerator")
   }
@@ -90,7 +92,7 @@
       if (newBlockBuffer.size > 0) {
         val blockId = StreamBlockId(receiverId, time - blockInterval)
         val newBlock = new Block(blockId, newBlockBuffer)
-        blocksForPushing.add(newBlock)
+        blocksForPushing.put(newBlock)  // put is blocking when queue is full
         logDebug("Last element in " + blockId + " is " + newBlockBuffer.last)
       }
     } catch {
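The add-to-put switch above is the substantive fix in this file: on a full ArrayBlockingQueue, add fails immediately with IllegalStateException, whereas put blocks the caller until the pushing thread drains a slot, which lets shutdown wait for queued blocks instead of losing them. A minimal standalone sketch of the difference, using only java.util.concurrent (names here are illustrative, not Spark code):

import java.util.concurrent.ArrayBlockingQueue

object BlockingQueueDemo {
  def main(args: Array[String]) {
    val queue = new ArrayBlockingQueue[Int](1)  // capacity 1, a tiny blocksForPushing
    queue.add(1)                                // fine: there is room
    // queue.add(2) here would throw IllegalStateException: Queue full

    // Consumer that drains one element after a delay, like the block pushing thread
    new Thread() {
      override def run() {
        Thread.sleep(100)
        println("consumed " + queue.take())
      }
    }.start()

    queue.put(2)  // blocks until the consumer frees a slot; nothing is lost
  }
}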
@@ -23,6 +23,7 @@ import scala.collection.mutable.ArrayBuffer

 import org.apache.spark.{Logging, SparkConf}
 import org.apache.spark.storage.StreamBlockId
+import java.util.concurrent.CountDownLatch

 /**
  * Abstract class that is responsible for executing a NetworkReceiver in the worker.
@@ -39,10 +40,11 @@ private[streaming] abstract class NetworkReceiverExecutor(
   protected val receiverId = receiver.receiverId

   /** Thread that starts the receiver and stays blocked while data is being received. */
-  @volatile protected var receivingThread: Option[Thread] = None
+  @volatile protected var executionThread: Option[Thread] = None

   /** Has the receiver been marked for stop. */
-  @volatile private var stopped = false
+  //@volatile private var stopped = false
+  val stopLatch = new CountDownLatch(1)

   /** Push a single data item to backend data store. */
   def pushSingle(data: Any)
@@ -77,17 +79,15 @@
    */
   def run() {
     // Remember this thread as the receiving thread
-    receivingThread = Some(Thread.currentThread())
+    executionThread = Some(Thread.currentThread())

     try {
       // Call user-defined onStart()
       logInfo("Calling onStart")
       receiver.onStart()

-      // Wait until interrupt is called on this thread
-      while(true) {
-        Thread.sleep(100)
-      }
+      awaitStop()
+      logInfo("Outside latch")
     } catch {
       case ie: InterruptedException =>
         logInfo("Receiving thread has been interrupted, receiver " + receiverId + " stopped")
@@ -106,27 +106,17 @@
   }

   /**
-   * Stop receiving data.
+   * Mark the executor and the receiver as stopped
    */
   def stop() {
-    // Mark has stopped
-
-    if (receivingThread.isDefined) {
-      // Interrupt the thread
-      receivingThread.get.interrupt()
-
-      // Wait for the receiving thread to finish on its own
-      receivingThread.get.join(conf.getLong("spark.streaming.receiverStopTimeout", 2000))
-
-      // Stop receiving by interrupting the receiving thread
-      receivingThread.get.interrupt()
-      logInfo("Interrupted receiving thread of receiver " + receiverId + " for stopping")
-    }
-
-    stopped = true
-    logInfo("Marked as stop")
+    // Mark for stop
+    stopLatch.countDown()
+    logInfo("Marked for stop " + stopLatch.getCount)
   }

-  /** Check if receiver has been marked for stopping. */
-  def isStopped = stopped
+  /** Check if receiver has been marked for stopping */
+  def isStopped() = (stopLatch.getCount == 0L)
+
+  /** Wait the thread until the executor is stopped */
+  def awaitStop() = stopLatch.await()
 }
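The executor no longer spins in a sleep loop waiting to be interrupted; run() parks on a CountDownLatch inside awaitStop(), and stop() merely counts the latch down, so the thread wakes without an InterruptedException tearing through receiver code. A minimal sketch of this latch-based stop protocol, with hypothetical names (not Spark API):

import java.util.concurrent.CountDownLatch

class StoppableRunner {
  private val stopLatch = new CountDownLatch(1)

  def isStopped = stopLatch.getCount == 0L

  def run() {
    println("running; parked until stop is signalled")
    stopLatch.await()   // no polling, no interrupts
    println("stop observed; doing orderly cleanup")
  }

  def stop() {
    stopLatch.countDown()   // callable from any thread, idempotent
  }
}

object StoppableRunnerDemo {
  def main(args: Array[String]) {
    val runner = new StoppableRunner
    val t = new Thread() { override def run() { runner.run() } }
    t.start()
    Thread.sleep(50)
    runner.stop()
    t.join()   // returns as soon as the latch opens
    println("isStopped = " + runner.isStopped)
  }
}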
@@ -161,9 +161,7 @@ private[streaming] class NetworkReceiverExecutorImpl(
   override def run() {
     // Starting the block generator
     blockGenerator.start()
-
     super.run()
-
     // Stopping BlockGenerator
     blockGenerator.stop()
     reportStop()
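With the latch-based run() above, super.run() returns only once stop() has been called, so starting the block generator, blocking, stopping the generator, and reporting now execute as one orderly sequence instead of racing an interrupt.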
@@ -38,6 +38,7 @@ private[streaming]
 class JobGenerator(jobScheduler: JobScheduler) extends Logging {

   private val ssc = jobScheduler.ssc
+  private val conf = ssc.conf
   private val graph = ssc.graph

   val clock = {
@@ -93,13 +94,18 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
     if (processReceivedData) {
       logInfo("Stopping JobGenerator gracefully")
       val timeWhenStopStarted = System.currentTimeMillis()
-      val stopTimeout = 10 * ssc.graph.batchDuration.milliseconds
+      val stopTimeout = conf.getLong(
+        "spark.streaming.gracefulStopTimeout",
+        10 * ssc.graph.batchDuration.milliseconds
+      )
       val pollTime = 100

       // To prevent graceful stop to get stuck permanently
       def hasTimedOut = {
         val timedOut = System.currentTimeMillis() - timeWhenStopStarted > stopTimeout
-        if (timedOut) logWarning("Timed out while stopping the job generator")
+        if (timedOut) {
+          logWarning("Timed out while stopping the job generator (timeout = " + stopTimeout + ")")
+        }
         timedOut
       }
@@ -112,7 +118,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
       logInfo("Waited for all received blocks to be consumed for job generation")

       // Stop generating jobs
-      val stopTime = timer.stop(false)
+      val stopTime = timer.stop(interruptTimer = false)
       graph.stop()
       logInfo("Stopped generation timer")
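The graceful-stop deadline now comes from the new spark.streaming.gracefulStopTimeout setting, defaulting to ten batch durations, and the wait loop re-checks hasTimedOut on every poll so a stuck receiver cannot hang shutdown forever. A standalone sketch of this poll-with-deadline pattern (illustrative names, not Spark API):

object PollWithDeadline {
  /** Poll `condition` every `pollTime` ms until it holds or `stopTimeout` ms elapse. */
  def waitFor(condition: () => Boolean, stopTimeout: Long, pollTime: Long = 100): Boolean = {
    val start = System.currentTimeMillis()
    def hasTimedOut = System.currentTimeMillis() - start > stopTimeout
    while (!condition() && !hasTimedOut) {
      Thread.sleep(pollTime)
    }
    condition()   // false means the wait gave up after stopTimeout
  }

  def main(args: Array[String]) {
    val readyAt = System.currentTimeMillis() + 300
    val ok = waitFor(() => System.currentTimeMillis() >= readyAt, stopTimeout = 1000)
    println(if (ok) "condition met before timeout" else "timed out")
  }
}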
@@ -77,7 +77,9 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name:
   def stop(interruptTimer: Boolean): Long = synchronized {
     if (!stopped) {
       stopped = true
-      if (interruptTimer) thread.interrupt()
+      if (interruptTimer) {
+        thread.interrupt()
+      }
       thread.join()
       logInfo("Stopped timer for " + name + " after time " + prevTime)
     }
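Both call sites above now pass interruptTimer = false, so stop() only sets the flag and joins: the timer thread finishes its current callback, notices stopped at the next loop check, and exits on its own. A condensed sketch of that loop structure (illustrative, not the actual RecurringTimer internals):

class PoliteTimer(period: Long, callback: Long => Unit) {
  @volatile private var stopped = false

  private val thread = new Thread() {
    override def run() {
      try {
        while (!stopped) {
          callback(System.currentTimeMillis())  // never cut short when stop(false) is used
          Thread.sleep(period)
        }
      } catch {
        case ie: InterruptedException =>  // only reachable via stop(true)
      }
    }
  }

  def start() { thread.start() }

  def stop(interruptTimer: Boolean) {
    stopped = true
    if (interruptTimer) {
      thread.interrupt()  // forceful: may abort a sleep or a callback mid-flight
    }
    thread.join()  // graceful: wait for the loop to observe `stopped`
  }
}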
@@ -28,36 +28,39 @@ import org.scalatest.time.SpanSugar._
 import org.apache.spark.SparkConf
 import org.apache.spark.storage.{StorageLevel, StreamBlockId}
 import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, NetworkReceiver, NetworkReceiverExecutor}

 /** Testsuite for testing the network receiver behavior */
 class NetworkReceiverSuite extends FunSuite with Timeouts {

-  test("network receiver with fake executor") {
-    val receiver = new MockReceiver
-    val executor = new MockReceiverExecutor(receiver)
+  test("network receiver life cycle") {
+    val receiver = new FakeReceiver
+    val executor = new FakeReceiverExecutor(receiver)

-    val receivingThread = new Thread() {
+    // Thread that runs the executor
+    val executingThread = new Thread() {
       override def run() {
         println("Running receiver")
         executor.run()
         println("Finished receiver")
       }
     }
-    receivingThread.start()
-
-    // Verify that NetworkReceiver.run() blocks
+    // Start the receiver
+    executingThread.start()
+
+    // Verify that the receiver stays blocked
     intercept[Exception] {
       failAfter(200 millis) {
-        receivingThread.join()
+        executingThread.join()
       }
     }

+    // Verify that onStart was called, and onStop wasn't called
     assert(receiver.started)
+    assert(receiver.otherThread.isAlive)
     assert(!receiver.stopped)
     assert(executor.isAllEmpty)

-    // Verify whether the data stored by the receiver was
-    // sent to the executor
+    // Verify whether the data stored by the receiver was sent to the executor
     val byteBuffer = ByteBuffer.allocate(100)
     val arrayBuffer = new ArrayBuffer[Int]()
     val iterator = arrayBuffer.iterator
@@ -74,25 +77,25 @@ class NetworkReceiverSuite extends FunSuite with Timeouts {
     assert(executor.arrayBuffers.size === 1)
     assert(executor.arrayBuffers.head.eq(arrayBuffer))

-    // Verify whether the exceptions reported by the receiver
-    // was sent to the executor
+    // Verify whether the exceptions reported by the receiver were sent to the executor
     val exception = new Exception
     receiver.reportError("Error", exception)
     assert(executor.errors.size === 1)
     assert(executor.errors.head.eq(exception))

     // Verify that stopping actually stops the thread
-    failAfter(500 millis) {
+    failAfter(100 millis) {
       receiver.stop()
-      receivingThread.join()
+      executingThread.join()
+      assert(!receiver.otherThread.isAlive)
     }

     // Verify that onStop was called
     assert(receiver.stopped)
   }

   test("block generator") {
-    val blockGeneratorListener = new MockBlockGeneratorListener
+    val blockGeneratorListener = new FakeBlockGeneratorListener
     val blockInterval = 200
     val conf = new SparkConf().set("spark.streaming.blockInterval", blockInterval.toString)
     val blockGenerator = new BlockGenerator(blockGeneratorListener, 1, conf)
@@ -114,25 +117,47 @@

     val recordedData = blockGeneratorListener.arrayBuffers.flatten
     assert(blockGeneratorListener.arrayBuffers.size > 0)
-    assert(recordedData.size <= count)
+    //assert(generatedData.toList === recordedData.toList)
+    assert(recordedData.toSet === generatedData.toSet)
   }
 }

-class MockReceiver extends NetworkReceiver[Int](StorageLevel.MEMORY_ONLY) {
+/**
+ * An implementation of NetworkReceiver that is used for testing a receiver's life cycle.
+ */
+class FakeReceiver extends NetworkReceiver[Int](StorageLevel.MEMORY_ONLY) {
   var started = false
   var stopped = false
-  def onStart() { started = true }
-  def onStop() { stopped = true }
+  val otherThread = new Thread() {
+    override def run() {
+      while(!stopped) {
+        Thread.sleep(10)
+      }
+    }
+  }
+
+  def onStart() {
+    otherThread.start()
+    started = true
+  }
+  def onStop() {
+    stopped = true
+    otherThread.join()
+  }
 }

-class MockReceiverExecutor(receiver: MockReceiver) extends NetworkReceiverExecutor(receiver) {
+/**
+ * An implementation of NetworkReceiverExecutor used for testing a NetworkReceiver.
+ * Instead of storing the data in the BlockManager, it stores all the data in a local buffer
+ * that can be used for verifying that the data has been forwarded correctly.
+ */
+class FakeReceiverExecutor(receiver: FakeReceiver) extends NetworkReceiverExecutor(receiver) {
   val singles = new ArrayBuffer[Any]
   val byteBuffers = new ArrayBuffer[ByteBuffer]
   val iterators = new ArrayBuffer[Iterator[_]]
   val arrayBuffers = new ArrayBuffer[ArrayBuffer[_]]
   val errors = new ArrayBuffer[Throwable]

   /** Check if all data structures are clean */
   def isAllEmpty = {
     singles.isEmpty && byteBuffers.isEmpty && iterators.isEmpty &&
     arrayBuffers.isEmpty && errors.isEmpty
@@ -171,16 +196,21 @@
   }
 }

-class MockBlockGeneratorListener extends BlockGeneratorListener {
+/**
+ * An implementation of BlockGeneratorListener that is used to test the BlockGenerator.
+ */
+class FakeBlockGeneratorListener(pushDelay: Long = 0) extends BlockGeneratorListener {
+  // buffer of data received as ArrayBuffers
   val arrayBuffers = new ArrayBuffer[ArrayBuffer[Int]]
   val errors = new ArrayBuffer[Throwable]

   def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]) {
     val bufferOfInts = arrayBuffer.map(_.asInstanceOf[Int])
     arrayBuffers += bufferOfInts
+    Thread.sleep(pushDelay)
   }

   def onError(message: String, throwable: Throwable) {
     errors += throwable
   }
 }
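A note on the test idiom used above: ScalaTest's failAfter throws a timeout exception when its block overruns, so wrapping it in intercept[Exception] turns "executor.run() is still blocked after 200 ms" into a passing assertion, while the later failAfter(100 millis) without an intercept asserts the opposite: once stop() is called, everything unblocks well inside the timeout.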