diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index 9c7f698840778..85b354ff4aa0d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -45,7 +45,7 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { startTime = time outputStreams.foreach(_.initialize(zeroTime)) outputStreams.foreach(_.remember(rememberDuration)) - outputStreams.foreach(_.validate) + outputStreams.foreach(_.validateAtStart) inputStreams.par.foreach(_.start()) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 7092a3d3f0b86..64de7526a6a34 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -60,6 +60,8 @@ abstract class DStream[T: ClassTag] ( @transient private[streaming] var ssc: StreamingContext ) extends Serializable with Logging { + validateAtInit() + // ======================================================================= // Methods that should be implemented by subclasses of DStream // ======================================================================= @@ -171,7 +173,22 @@ abstract class DStream[T: ClassTag] ( dependencies.foreach(_.initialize(zeroTime)) } - private[streaming] def validate() { + private def validateAtInit(): Unit = { + ssc.getState() match { + case StreamingContextState.INITIALIZED => + // good to go + case StreamingContextState.ACTIVE => + throw new SparkException( + "Adding new inputs, transformations, and output operations after " + + "starting a context is not supported") + case StreamingContextState.STOPPED => + throw new SparkException( + "Adding new inputs, transformations, and output operations after " + + "stopping a context is not supported") + } + } + + private[streaming] def validateAtStart() { assert(rememberDuration != null, "Remember duration is set to null") assert( @@ -226,7 +243,7 @@ abstract class DStream[T: ClassTag] ( math.ceil(rememberDuration.milliseconds / 1000.0).toInt + " seconds." ) - dependencies.foreach(_.validate()) + dependencies.foreach(_.validateAtStart()) logInfo("Slide time = " + slideDuration) logInfo("Storage level = " + storageLevel) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 5d09b234f77ce..5f93332896de1 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -651,6 +651,45 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w testPackage.test() } + test("throw exception on using active or stopped context") { + val conf = new SparkConf() + .setMaster(master) + .setAppName(appName) + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock") + ssc = new StreamingContext(conf, batchDuration) + require(ssc.getState() === StreamingContextState.INITIALIZED) + val input = addInputStream(ssc) + val transformed = input.map { x => x} + transformed.foreachRDD { rdd => rdd.count } + + def testForException(clue: String, expectedErrorMsg: String)(body: => Unit): Unit = { + withClue(clue) { + val ex = intercept[SparkException] { + body + } + assert(ex.getMessage.toLowerCase().contains(expectedErrorMsg)) + } + } + + ssc.start() + require(ssc.getState() === StreamingContextState.ACTIVE) + testForException("no error on adding input after start", "start") { + addInputStream(ssc) } + testForException("no error on adding transformation after start", "start") { + input.map { x => x * 2 } } + testForException("no error on adding output operation after start", "start") { + transformed.foreachRDD { rdd => rdd.collect() } } + + ssc.stop() + require(ssc.getState() === StreamingContextState.STOPPED) + testForException("no error on adding input after stop", "stop") { + addInputStream(ssc) } + testForException("no error on adding transformation after stop", "stop") { + input.map { x => x * 2 } } + testForException("no error on adding output operation after stop", "stop") { + transformed.foreachRDD { rdd => rdd.collect() } } + } + def addInputStream(s: StreamingContext): DStream[Int] = { val input = (1 to 100).map(i => 1 to i) val inputStream = new TestInputStream(s, input, 1)