From a36cc48fb949a72edffdd9769bae99eaac08ec2d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 1 Apr 2014 16:26:35 -0700 Subject: [PATCH] Refactored the NetworkReceiver API for future stability. --- .../streaming/flume/FlumeInputDStream.scala | 23 +- .../streaming/kafka/KafkaInputDStream.scala | 12 +- .../streaming/mqtt/MQTTInputDStream.scala | 37 +- .../spark/streaming/mqtt/MQTTUtils.scala | 2 +- .../twitter/TwitterInputDStream.scala | 12 +- .../dstream/NetworkInputDStream.scala | 462 +++++++++++++----- .../streaming/dstream/RawInputDStream.scala | 9 +- .../dstream/SocketInputDStream.scala | 21 +- .../streaming/receivers/ActorReceiver.scala | 23 +- .../scheduler/NetworkInputTracker.scala | 10 +- .../spark/streaming/InputStreamsSuite.scala | 10 +- .../streaming/StreamingContextSuite.scala | 5 +- 12 files changed, 399 insertions(+), 227 deletions(-) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 34012b846e21e..4b2373473c7cc 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -34,6 +34,7 @@ import org.apache.spark.util.Utils import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.dstream._ +import org.apache.spark.Logging private[streaming] class FlumeInputDStream[T: ClassTag]( @@ -115,13 +116,13 @@ private[streaming] object SparkFlumeEvent { private[streaming] class FlumeEventServer(receiver : FlumeReceiver) extends AvroSourceProtocol { override def append(event : AvroFlumeEvent) : Status = { - receiver.blockGenerator += SparkFlumeEvent.fromAvroFlumeEvent(event) + receiver.store(SparkFlumeEvent.fromAvroFlumeEvent(event)) Status.OK } override def appendBatch(events : java.util.List[AvroFlumeEvent]) : Status = { events.foreach (event => - receiver.blockGenerator += SparkFlumeEvent.fromAvroFlumeEvent(event)) + receiver.store(SparkFlumeEvent.fromAvroFlumeEvent(event))) Status.OK } } @@ -133,23 +134,21 @@ class FlumeReceiver( host: String, port: Int, storageLevel: StorageLevel - ) extends NetworkReceiver[SparkFlumeEvent] { + ) extends NetworkReceiver[SparkFlumeEvent](storageLevel) with Logging { - lazy val blockGenerator = new BlockGenerator(storageLevel) + lazy val responder = new SpecificResponder( + classOf[AvroSourceProtocol], new FlumeEventServer(this)) + lazy val server = new NettyServer(responder, new InetSocketAddress(host, port)) - protected override def onStart() { - val responder = new SpecificResponder( - classOf[AvroSourceProtocol], new FlumeEventServer(this)) - val server = new NettyServer(responder, new InetSocketAddress(host, port)) - blockGenerator.start() + def onStart() { server.start() logInfo("Flume receiver started") } - protected override def onStop() { - blockGenerator.stop() + def onStop() { + server.close() logInfo("Flume receiver stopped") } - override def getLocationPreference = Some(host) + override def preferredLocation = Some(host) } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index c2d9dcbfaac7a..7c10c4a0d6a16 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -70,21 +70,15 @@ class KafkaReceiver[ kafkaParams: Map[String, String], topics: Map[String, Int], storageLevel: StorageLevel - ) extends NetworkReceiver[Any] { + ) extends NetworkReceiver[Any](storageLevel) with Logging { - // Handles pushing data into the BlockManager - lazy protected val blockGenerator = new BlockGenerator(storageLevel) // Connection to Kafka var consumerConnector : ConsumerConnector = null - def onStop() { - blockGenerator.stop() - } + def onStop() { } def onStart() { - blockGenerator.start() - // In case we are using multiple Threads to handle Kafka Messages val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _)) @@ -130,7 +124,7 @@ class KafkaReceiver[ def run() { logInfo("Starting MessageHandler.") for (msgAndMetadata <- stream) { - blockGenerator += (msgAndMetadata.key, msgAndMetadata.message) + store((msgAndMetadata.key, msgAndMetadata.message)) } } } diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala index 41e813d48c7b8..5f8d1463dc46b 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala @@ -49,38 +49,34 @@ import org.apache.spark.streaming.dstream._ */ private[streaming] -class MQTTInputDStream[T: ClassTag]( +class MQTTInputDStream( @transient ssc_ : StreamingContext, brokerUrl: String, topic: String, storageLevel: StorageLevel - ) extends NetworkInputDStream[T](ssc_) with Logging { + ) extends NetworkInputDStream[String](ssc_) with Logging { - def getReceiver(): NetworkReceiver[T] = { - new MQTTReceiver(brokerUrl, topic, storageLevel).asInstanceOf[NetworkReceiver[T]] + def getReceiver(): NetworkReceiver[String] = { + new MQTTReceiver(brokerUrl, topic, storageLevel) } } private[streaming] -class MQTTReceiver(brokerUrl: String, - topic: String, - storageLevel: StorageLevel - ) extends NetworkReceiver[Any] { - lazy protected val blockGenerator = new BlockGenerator(storageLevel) - - def onStop() { - blockGenerator.stop() - } +class MQTTReceiver( + brokerUrl: String, + topic: String, + storageLevel: StorageLevel + ) extends NetworkReceiver[String](storageLevel) { + + def onStop() { } def onStart() { - blockGenerator.start() - // Set up persistence for messages - var peristance: MqttClientPersistence = new MemoryPersistence() + val peristance: MqttClientPersistence = new MemoryPersistence() // Initializing Mqtt Client specifying brokerUrl, clientID and MqttClientPersistance - var client: MqttClient = new MqttClient(brokerUrl, MqttClient.generateClientId(), peristance) + val client: MqttClient = new MqttClient(brokerUrl, MqttClient.generateClientId(), peristance) // Connect to MqttBroker client.connect() @@ -89,18 +85,19 @@ class MQTTReceiver(brokerUrl: String, client.subscribe(topic) // Callback automatically triggers as and when new message arrives on specified topic - var callback: MqttCallback = new MqttCallback() { + val callback: MqttCallback = new MqttCallback() { // Handles Mqtt message override def messageArrived(arg0: String, arg1: MqttMessage) { - blockGenerator += new String(arg1.getPayload()) + store(new String(arg1.getPayload())) } override def deliveryComplete(arg0: IMqttDeliveryToken) { } override def connectionLost(arg0: Throwable) { - logInfo("Connection lost " + arg0) + store("Connection lost " + arg0) + stopOnError(new Exception(arg0)) } } diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala index 1b09ee5dc8f65..2f97b3bc6d919 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala @@ -37,7 +37,7 @@ object MQTTUtils { topic: String, storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2 ): DStream[String] = { - new MQTTInputDStream[String](ssc, brokerUrl, topic, storageLevel) + new MQTTInputDStream(ssc, brokerUrl, topic, storageLevel) } /** diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala index 3316b6dc39d6b..30cf3bd1a8efe 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala @@ -25,6 +25,7 @@ import twitter4j.auth.OAuthAuthorization import org.apache.spark.streaming._ import org.apache.spark.streaming.dstream._ import org.apache.spark.storage.StorageLevel +import org.apache.spark.Logging /* A stream of Twitter statuses, potentially filtered by one or more keywords. * @@ -59,17 +60,15 @@ class TwitterReceiver( twitterAuth: Authorization, filters: Seq[String], storageLevel: StorageLevel - ) extends NetworkReceiver[Status] { + ) extends NetworkReceiver[Status](storageLevel) with Logging { var twitterStream: TwitterStream = _ - lazy val blockGenerator = new BlockGenerator(storageLevel) - protected override def onStart() { - blockGenerator.start() + def onStart() { twitterStream = new TwitterStreamFactory().getInstance(twitterAuth) twitterStream.addListener(new StatusListener { def onStatus(status: Status) = { - blockGenerator += status + store(status) } // Unimplemented def onDeletionNotice(statusDeletionNotice: StatusDeletionNotice) {} @@ -89,8 +88,7 @@ class TwitterReceiver( logInfo("Twitter receiver started") } - protected override def onStop() { - blockGenerator.stop() + def onStop() { twitterStream.shutdown() logInfo("Twitter receiver stopped") } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala index 72ad0bae75bfb..77cf5ee4cc075 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala @@ -17,14 +17,14 @@ package org.apache.spark.streaming.dstream -import java.util.concurrent.ArrayBlockingQueue -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.Await -import scala.concurrent.duration._ import scala.reflect.ClassTag +import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.{TimeUnit, ArrayBlockingQueue} + import akka.actor.{Props, Actor} import akka.pattern.ask @@ -34,6 +34,7 @@ import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.rdd.{RDD, BlockRDD} import org.apache.spark.storage.{BlockId, StorageLevel, StreamBlockId} import org.apache.spark.streaming.scheduler.{DeregisterReceiver, AddBlocks, RegisterReceiver} +import org.apache.spark.util.AkkaUtils /** * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]] @@ -85,188 +86,383 @@ private[streaming] case class ReportBlock(blockId: BlockId, metadata: Any) private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage /** - * Abstract class of a receiver that can be run on worker nodes to receive external data. See - * [[org.apache.spark.streaming.dstream.NetworkInputDStream]] for an explanation. + * Abstract class of a receiver that can be run on worker nodes to receive external data. A + * custom receiver can be defined by defining the functions onStart() and onStop(). onStart() + * should define the setup steps necessary to start receiving data, + * and onStop() should define the cleanup steps necessary to stop receiving data. A custom + * receiver would look something like this. + * + * class MyReceiver(storageLevel) extends NetworkReceiver[String](storageLevel) { + * def onStart() { + * // Setup stuff (start threads, open sockets, etc.) to start receiving data. + * // Call store(...) to store received data into Spark's memory. + * // Optionally, wait for other threads to complete or watch for exceptions. + * // Call stopOnError(...) if there is an error that you cannot ignore and need + * // the receiver to be terminated. + * } + * + * def onStop() { + * // Cleanup stuff (stop threads, close sockets, etc.) to stop receiving data. + * } + * } */ -abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging { +abstract class NetworkReceiver[T: ClassTag](val storageLevel: StorageLevel) + extends Serializable { - lazy protected val env = SparkEnv.get + /** + * This method is called by the system when the receiver is started to start receiving data. + * All threads and resources set up in this method must be cleaned up in onStop(). + * If there are exceptions on other threads such that the receiver must be terminated, + * then you must call stopOnError(exception). However, the thread that called onStart() must + * never catch and ignore InterruptedException (it can catch and rethrow). + */ + def onStart() - lazy protected val actor = env.actorSystem.actorOf( - Props(new NetworkReceiverActor()), "NetworkReceiver-" + streamId) + /** + * This method is called by the system when the receiver is stopped to stop receiving data. + * All threads and resources setup in onStart() must be cleaned up in this method. + */ + def onStop() - lazy protected val receivingThread = Thread.currentThread() + /** Override this to specify a preferred location (hostname). */ + def preferredLocation : Option[String] = None - protected var streamId: Int = -1 + /** Store a single item of received data to Spark's memory/ */ + def store(dataItem: T) { + handler.pushSingle(dataItem) + } - /** - * This method will be called to start receiving data. All your receiver - * starting code should be implemented by defining this function. - */ - protected def onStart() + /** Store a sequence of received data block into Spark's memory. */ + def store(dataBuffer: ArrayBuffer[T]) { + handler.pushArrayBuffer(dataBuffer) + } + + /** Store a sequence of received data block into Spark's memory. */ + def store(dataIterator: Iterator[T]) { + handler.pushIterator(dataIterator) + } + + /** Store the bytes of received data block into Spark's memory. */ + def store(bytes: ByteBuffer) { + handler.pushBytes(bytes) + } + + /** Stop the receiver. */ + def stop() { + handler.stop() + } + + /** Stop the receiver when an error occurred. */ + def stopOnError(e: Exception) { + handler.stop(e) + } + + /** Check if receiver has been marked for stopping */ + def isStopped: Boolean = { + handler.isStopped + } + + /** Get unique identifier of this receiver. */ + def receiverId = id - /** This method will be called to stop receiving data. */ - protected def onStop() + /** Identifier of the stream this receiver is associated with. */ + private var id: Int = -1 - /** Conveys a placement preference (hostname) for this receiver. */ - def getLocationPreference() : Option[String] = None + /** Handler object that runs the receiver. This is instantiated lazily in the worker. */ + private[streaming] lazy val handler = new NetworkReceiverHandler(this) + + /** Set the ID of the DStream that this receiver is associated with */ + private[streaming] def setReceiverId(id_ : Int) { + id = id_ + } +} + + +private[streaming] class NetworkReceiverHandler(receiver: NetworkReceiver[_]) extends Logging { + + val env = SparkEnv.get + val receiverId = receiver.receiverId + val storageLevel = receiver.storageLevel + + /** Remote Akka actor for the NetworkInputTracker */ + private val trackerActor = { + val ip = env.conf.get("spark.driver.host", "localhost") + val port = env.conf.getInt("spark.driver.port", 7077) + val url = "akka.tcp://spark@%s:%s/user/NetworkInputTracker".format(ip, port) + env.actorSystem.actorSelection(url) + } + + /** Timeout for Akka actor messages */ + private val askTimeout = AkkaUtils.askTimeout(env.conf) + + /** Akka actor for receiving messages from the NetworkInputTracker in the driver */ + private val actor = env.actorSystem.actorOf( + Props(new Actor { + override def preStart() { + logInfo("Registered receiver " + receiverId) + val future = trackerActor.ask(RegisterReceiver(receiverId, self))(askTimeout) + Await.result(future, askTimeout) + } + + override def receive() = { + case StopReceiver => + logInfo("Received stop signal") + stop() + } + }), "NetworkReceiver-" + receiverId) + + /** Divides received data records into data blocks for pushing in BlockManager */ + private val blockGenerator = new BlockGenerator(this) + + /** Exceptions that occurs while receiving data */ + private val exceptions = new ArrayBuffer[Exception] with SynchronizedBuffer[Exception] + + /** Unique block ids if one wants to add blocks directly */ + private val newBlockId = new AtomicLong(System.currentTimeMillis()) + + /** Thread that starts the receiver and stays blocked while data is being received */ + private var receivingThread: Option[Thread] = None + + /** Has the receiver been marked for stop */ + private var stopped = false /** * Starts the receiver. First is accesses all the lazy members to * materialize them. Then it calls the user-defined onStart() method to start - * other threads, etc required to receiver the data. + * other threads, etc. required to receive the data. */ - def start() { - try { - // Access the lazy vals to materialize them - env - actor - receivingThread + def run() { + // Remember this thread as the receiving thread + receivingThread = Some(Thread.currentThread()) + + // Starting the block generator + blockGenerator.start() + try { // Call user-defined onStart() - onStart() + logInfo("Calling onStart") + receiver.onStart() + + // Wait until interrupt is called on this thread + while(true) Thread.sleep(100000) } catch { case ie: InterruptedException => - logInfo("Receiving thread interrupted") + logInfo("Receiving thread has been interrupted, receiver " + receiverId + " stopped") case e: Exception => - stopOnError(e) + logError("Error receiving data in receiver " + receiverId, e) + exceptions += e } + + // Call user-defined onStop() + try { + logInfo("Calling onStop") + receiver.onStop() + } catch { + case e: Exception => + logError("Error stopping receiver " + receiverId, e) + exceptions += e + } + + // Stopping BlockGenerator + blockGenerator.stop() + + val message = if (exceptions.isEmpty) { + null + } else if (exceptions.size == 1) { + val e = exceptions.head + "Exception in receiver " + receiverId + ": " + e.getMessage + "\n" + e.getStackTraceString + } else { + "Multiple exceptions in receiver " + receiverId + "(" + exceptions.size + "):\n" + exceptions.zipWithIndex.map { + case (e, i) => "Exception " + i + ": " + e.getMessage + "\n" + e.getStackTraceString + }.mkString("\n") + } + logInfo("Deregistering receiver " + receiverId) + val future = trackerActor.ask(DeregisterReceiver(receiverId, message))(askTimeout) + Await.result(future, askTimeout) + logInfo("Deregistered receiver " + receiverId) + env.actorSystem.stop(actor) + logInfo("Stopped receiver " + receiverId) } - /** - * Stops the receiver. First it interrupts the main receiving thread, - * that is, the thread that called receiver.start(). Then it calls the user-defined - * onStop() method to stop other threads and/or do cleanup. - */ - def stop() { - receivingThread.interrupt() - onStop() - // TODO: terminate the actor + + /** Push a single record of received data into block generator. */ + def pushSingle(data: Any) { + blockGenerator += data + } + + /** Push a block of received data into block manager. */ + def pushArrayBuffer( + arrayBuffer: ArrayBuffer[_], + blockId: StreamBlockId = nextBlockId, + metadata: Any = null + ) { + logDebug("Pushing block " + blockId) + val time = System.currentTimeMillis + env.blockManager.put(blockId, arrayBuffer.asInstanceOf[ArrayBuffer[Any]], storageLevel, true) + logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms") + trackerActor ! AddBlocks(receiverId, Array(blockId), null) + logDebug("Reported block " + blockId) } /** - * Stops the receiver and reports exception to the tracker. - * This should be called whenever an exception is to be handled on any thread - * of the receiver. + * Push a received data into Spark as . Call this method from the data receiving + * thread to submit + * a block of data. */ - protected def stopOnError(e: Exception) { - logError("Error receiving data", e) - stop() - actor ! ReportError(e.toString) + def pushIterator( + iterator: Iterator[_], + blockId: StreamBlockId = nextBlockId, + metadata: Any = null + ) { + env.blockManager.put(blockId, iterator, storageLevel, true) + trackerActor ! AddBlocks(receiverId, Array(blockId), null) + logInfo("Pushed block " + blockId) } /** - * Pushes a block (as an ArrayBuffer filled with data) into the block manager. + * Push a block (as bytes) into the block manager. */ - def pushBlock(blockId: BlockId, arrayBuffer: ArrayBuffer[T], metadata: Any, level: StorageLevel) { - env.blockManager.put(blockId, arrayBuffer.asInstanceOf[ArrayBuffer[Any]], level) - actor ! ReportBlock(blockId, metadata) + def pushBytes( + bytes: ByteBuffer, + blockId: StreamBlockId = nextBlockId, + metadata: Any = null + ) { + env.blockManager.putBytes(blockId, bytes, storageLevel, true) + trackerActor ! AddBlocks(receiverId, Array(blockId), null) + logInfo("Pushed block " + blockId) } /** - * Pushes a block (as bytes) into the block manager. + * Stop receiving data. */ - def pushBlock(blockId: BlockId, bytes: ByteBuffer, metadata: Any, level: StorageLevel) { - env.blockManager.putBytes(blockId, bytes, level) - actor ! ReportBlock(blockId, metadata) - } - - /** A helper actor that communicates with the NetworkInputTracker */ - private class NetworkReceiverActor extends Actor { - logInfo("Attempting to register with tracker") - val ip = env.conf.get("spark.driver.host", "localhost") - val port = env.conf.getInt("spark.driver.port", 7077) - val url = "akka.tcp://spark@%s:%s/user/NetworkInputTracker".format(ip, port) - val tracker = env.actorSystem.actorSelection(url) - val timeout = 5.seconds - - override def preStart() { - val future = tracker.ask(RegisterReceiver(streamId, self))(timeout) - Await.result(future, timeout) + def stop(e: Exception = null) { + // Mark has stopped + stopped = true + logInfo("Marked as stop") + + // Store the exception if any + if (e != null) { + logError("Error receiving data", e) + exceptions += e } - override def receive() = { - case ReportBlock(blockId, metadata) => - tracker ! AddBlocks(streamId, Array(blockId), metadata) - case ReportError(msg) => - tracker ! DeregisterReceiver(streamId, msg) - case StopReceiver(msg) => - stop() - tracker ! DeregisterReceiver(streamId, msg) + if (receivingThread.isDefined) { + // Wait for the receiving thread to finish on its own + receivingThread.get.join(env.conf.getLong("spark.streaming.receiverStopTimeout", 2000)) + + // Stop receiving by interrupting the receiving thread + receivingThread.get.interrupt() + logInfo("Interrupted receiving thread of receiver " + receiverId + " for stopping") } } - protected[streaming] def setStreamId(id: Int) { - streamId = id + /** Check if receiver has been marked for stopping. */ + def isStopped = stopped + + private def nextBlockId = StreamBlockId(receiverId, newBlockId.getAndIncrement) +} + +/** + * Batches objects created by a [[org.apache.spark.streaming.dstream.NetworkReceiver]] and puts them into + * appropriately named blocks at regular intervals. This class starts two threads, + * one to periodically start a new batch and prepare the previous batch of as a block, + * the other to push the blocks into the block manager. + */ +private[streaming] class BlockGenerator(handler: NetworkReceiverHandler) extends Logging { + + private case class Block(id: StreamBlockId, buffer: ArrayBuffer[Any], metadata: Any = null) + + private val env = handler.env + private val blockInterval = env.conf.getLong("spark.streaming.blockInterval", 200) + private val blockIntervalTimer = + new RecurringTimer(new SystemClock(), blockInterval, updateCurrentBuffer) + private val blocksForPushing = new ArrayBlockingQueue[Block](10) + private val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } + + private var currentBuffer = new ArrayBuffer[Any] + private var stopped = false + + def start() { + blockIntervalTimer.start() + blockPushingThread.start() + logInfo("Started BlockGenerator") } - /** - * Batches objects created by a [[org.apache.spark.streaming.dstream.NetworkReceiver]] and puts - * them into appropriately named blocks at regular intervals. This class starts two threads, - * one to periodically start a new batch and prepare the previous batch of as a block, - * the other to push the blocks into the block manager. - */ - class BlockGenerator(storageLevel: StorageLevel) - extends Serializable with Logging { + def stop() { + // Stop generating blocks + blockIntervalTimer.stop() - case class Block(id: BlockId, buffer: ArrayBuffer[T], metadata: Any = null) + // Mark as stopped + synchronized { stopped = true } - val clock = new SystemClock() - val blockInterval = env.conf.getLong("spark.streaming.blockInterval", 200) - val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) - val blockStorageLevel = storageLevel - val blocksForPushing = new ArrayBlockingQueue[Block](1000) - val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } + // Wait for all blocks to be pushed + logDebug("Waiting for block pushing thread to terminate") + blockPushingThread.join() + logInfo("Stopped BlockGenerator") + } - var currentBuffer = new ArrayBuffer[T] + def += (obj: Any): Unit = synchronized { + currentBuffer += obj + } - def start() { - blockIntervalTimer.start() - blockPushingThread.start() - logInfo("Data handler started") - } + private def isStopped = synchronized { stopped } - def stop() { - blockIntervalTimer.stop() - blockPushingThread.interrupt() - logInfo("Data handler stopped") + private def updateCurrentBuffer(time: Long): Unit = synchronized { + try { + val newBlockBuffer = currentBuffer + currentBuffer = new ArrayBuffer[Any] + if (newBlockBuffer.size > 0) { + val blockId = StreamBlockId(handler.receiverId, time - blockInterval) + val newBlock = new Block(blockId, newBlockBuffer) + blocksForPushing.add(newBlock) + logDebug("Last element in " + blockId + " is " + newBlockBuffer.last) + } + } catch { + case ie: InterruptedException => + logInfo("Block updating timer thread was interrupted") + case e: Exception => + handler.stop(e) } + } - def += (obj: T): Unit = synchronized { - currentBuffer += obj - } + private def keepPushingBlocks() { + logInfo("Started block pushing thread") - private def updateCurrentBuffer(time: Long): Unit = synchronized { - try { - val newBlockBuffer = currentBuffer - currentBuffer = new ArrayBuffer[T] - if (newBlockBuffer.size > 0) { - val blockId = StreamBlockId(NetworkReceiver.this.streamId, time - blockInterval) - val newBlock = new Block(blockId, newBlockBuffer) - blocksForPushing.add(newBlock) - } - } catch { - case ie: InterruptedException => - logInfo("Block interval timer thread interrupted") - case e: Exception => - NetworkReceiver.this.stop() + def pushNextBlock() { + Option(blocksForPushing.poll(100, TimeUnit.MILLISECONDS)) match { + case Some(block) => + handler.pushArrayBuffer(block.buffer, block.id, block.metadata) + logInfo("Pushed block "+ block.id) + case None => } } - private def keepPushingBlocks() { - logInfo("Block pushing thread started") - try { - while(true) { - val block = blocksForPushing.take() - NetworkReceiver.this.pushBlock(block.id, block.buffer, block.metadata, storageLevel) + try { + while(!isStopped) { + Option(blocksForPushing.poll(100, TimeUnit.MILLISECONDS)) match { + case Some(block) => + handler.pushArrayBuffer(block.buffer, block.id, block.metadata) + logInfo("Pushed block "+ block.id) + case None => } - } catch { - case ie: InterruptedException => - logInfo("Block pushing thread interrupted") - case e: Exception => - NetworkReceiver.this.stop() } + // Push out the blocks that are still left + logInfo("Pushing out the last " + blocksForPushing.size() + " blocks") + while (!blocksForPushing.isEmpty) { + logDebug("Getting block ") + val block = blocksForPushing.take() + logDebug("Got block") + handler.pushArrayBuffer(block.buffer, block.id, block.metadata) + logInfo("Blocks left to push " + blocksForPushing.size()) + } + logInfo("Stopped block pushing thread") + } catch { + case ie: InterruptedException => + logInfo("Block pushing thread was interrupted") + case e: Exception => + handler.stop(e) } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala index dea0f26f908fb..b920dae60cd66 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming.StreamingContext @@ -51,12 +51,10 @@ class RawInputDStream[T: ClassTag]( private[streaming] class RawNetworkReceiver(host: String, port: Int, storageLevel: StorageLevel) - extends NetworkReceiver[Any] { + extends NetworkReceiver[Any](storageLevel) with Logging { var blockPushingThread: Thread = null - override def getLocationPreference = None - def onStart() { // Open a socket to the target address and keep reading from it logInfo("Connecting to " + host + ":" + port) @@ -73,9 +71,8 @@ class RawNetworkReceiver(host: String, port: Int, storageLevel: StorageLevel) var nextBlockNumber = 0 while (true) { val buffer = queue.take() - val blockId = StreamBlockId(streamId, nextBlockNumber) nextBlockNumber += 1 - pushBlock(blockId, buffer, null, storageLevel) + store(buffer) } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index 2cdd13f205313..53ead3d22f736 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -25,6 +25,7 @@ import scala.reflect.ClassTag import java.io._ import java.net.Socket +import org.apache.spark.Logging private[streaming] class SocketInputDStream[T: ClassTag]( @@ -46,26 +47,22 @@ class SocketReceiver[T: ClassTag]( port: Int, bytesToObjects: InputStream => Iterator[T], storageLevel: StorageLevel - ) extends NetworkReceiver[T] { + ) extends NetworkReceiver[T](storageLevel) with Logging { - lazy protected val blockGenerator = new BlockGenerator(storageLevel) + var socket: Socket = null - override def getLocationPreference = None - - protected def onStart() { + def onStart() { logInfo("Connecting to " + host + ":" + port) - val socket = new Socket(host, port) + socket = new Socket(host, port) logInfo("Connected to " + host + ":" + port) - blockGenerator.start() val iterator = bytesToObjects(socket.getInputStream()) - while(iterator.hasNext) { - val obj = iterator.next - blockGenerator += obj + while(!isStopped && iterator.hasNext) { + store(iterator.next) } } - protected def onStop() { - blockGenerator.stop() + def onStop() { + if (socket != null) socket.close() } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala index bd78bae8a5c51..da07878cc3070 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala @@ -31,6 +31,7 @@ import org.apache.spark.streaming.dstream.NetworkReceiver import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.ArrayBuffer +import org.apache.spark.{SparkEnv, Logging} /** A helper with set of defaults for supervisor strategy */ object ReceiverSupervisorStrategy { @@ -120,13 +121,10 @@ private[streaming] class ActorReceiver[T: ClassTag]( name: String, storageLevel: StorageLevel, receiverSupervisorStrategy: SupervisorStrategy) - extends NetworkReceiver[T] { + extends NetworkReceiver[T](storageLevel) with Logging { - protected lazy val blocksGenerator: BlockGenerator = - new BlockGenerator(storageLevel) - - protected lazy val supervisor = env.actorSystem.actorOf(Props(new Supervisor), - "Supervisor" + streamId) + protected lazy val supervisor = SparkEnv.get.actorSystem.actorOf(Props(new Supervisor), + "Supervisor" + receiverId) class Supervisor extends Actor { @@ -141,8 +139,8 @@ private[streaming] class ActorReceiver[T: ClassTag]( case Data(iter: Iterator[_]) => pushBlock(iter.asInstanceOf[Iterator[T]]) - case Data(msg) => - blocksGenerator += msg.asInstanceOf[T] + case Data(msg) ⇒ + store(msg.asInstanceOf[T]) n.incrementAndGet case props: Props => @@ -165,18 +163,15 @@ private[streaming] class ActorReceiver[T: ClassTag]( } protected def pushBlock(iter: Iterator[T]) { - val buffer = new ArrayBuffer[T] - buffer ++= iter - pushBlock(StreamBlockId(streamId, System.nanoTime()), buffer, null, storageLevel) + store(iter) } - protected def onStart() = { - blocksGenerator.start() + def onStart() = { supervisor logInfo("Supervision tree for receivers initialized at:" + supervisor.path) } - protected def onStop() = { + def onStop() = { supervisor ! PoisonPill } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala index cad68e248ab29..6ac54cf7be29e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala @@ -153,19 +153,17 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging { def startReceivers() { val receivers = networkInputStreams.map(nis => { val rcvr = nis.getReceiver() - rcvr.setStreamId(nis.id) + rcvr.setReceiverId(nis.id) rcvr }) // Right now, we only honor preferences if all receivers have them - val hasLocationPreferences = receivers.map(_.getLocationPreference().isDefined) - .reduce(_ && _) + val hasLocationPreferences = receivers.map(_.preferredLocation.isDefined).reduce(_ && _) // Create the parallel collection of receivers to distributed them on the worker nodes val tempRDD = if (hasLocationPreferences) { - val receiversWithPreferences = - receivers.map(r => (r, Seq(r.getLocationPreference().toString))) + val receiversWithPreferences = receivers.map(r => (r, Seq(r.preferredLocation.get))) ssc.sc.makeRDD[NetworkReceiver[_]](receiversWithPreferences) } else { @@ -177,7 +175,7 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging { if (!iterator.hasNext) { throw new Exception("Could not start receiver as details not found.") } - iterator.next().start() + iterator.next().handler.run() } // Run the dummy Spark job to ensure that all slaves have registered. // This avoids all the receivers to be scheduled on the same node. diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 7df206241beb6..e29685bc91fb6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -315,18 +315,16 @@ class TestActor(port: Int) extends Actor with Receiver { /** This is a receiver to test multiple threads inserting data using block generator */ class MultiThreadTestReceiver(numThreads: Int, numRecordsPerThread: Int) - extends NetworkReceiver[Int] { + extends NetworkReceiver[Int](StorageLevel.MEMORY_ONLY_SER) with Logging { lazy val executorPool = Executors.newFixedThreadPool(numThreads) - lazy val blockGenerator = new BlockGenerator(StorageLevel.MEMORY_ONLY) lazy val finishCount = new AtomicInteger(0) - protected def onStart() { - blockGenerator.start() + def onStart() { (1 to numThreads).map(threadId => { val runnable = new Runnable { def run() { (1 to numRecordsPerThread).foreach(i => - blockGenerator += (threadId * numRecordsPerThread + i) ) + store(threadId * numRecordsPerThread + i) ) if (finishCount.incrementAndGet == numThreads) { MultiThreadTestReceiver.haveAllThreadsFinished = true } @@ -337,7 +335,7 @@ class MultiThreadTestReceiver(numThreads: Int, numRecordsPerThread: Int) }) } - protected def onStop() { + def onStop() { executorPool.shutdown() } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 717da8e00462b..4d8c82d78ba40 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -145,6 +145,9 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { ssc = null assert(sc.makeRDD(1 to 100).collect().size === 100) ssc = new StreamingContext(sc, batchDuration) + addInputStream(ssc).register + ssc.start() + ssc.stop() } test("awaitTermination") { @@ -215,4 +218,4 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { } } -class TestException(msg: String) extends Exception(msg) \ No newline at end of file +class TestException(msg: String) extends Exception(msg)