diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala
index 22e4c6380fcd5..a2153d27e9fef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala
@@ -17,6 +17,10 @@

 package org.apache.spark.sql.execution.streaming

+import java.util.UUID
+
+import scala.collection.mutable
+
 import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerEvent}
 import org.apache.spark.sql.streaming.StreamingQueryListener
 import org.apache.spark.util.ListenerBus
@@ -25,7 +29,11 @@ import org.apache.spark.util.ListenerBus
  * A bus to forward events to [[StreamingQueryListener]]s. This one will send received
  * [[StreamingQueryListener.Event]]s to the Spark listener bus. It also registers itself with
  * Spark listener bus, so that it can receive [[StreamingQueryListener.Event]]s and dispatch them
- * to StreamingQueryListener.
+ * to StreamingQueryListeners.
+ *
+ * Note that each bus and its registered listeners are associated with a single SparkSession
+ * and StreamingQueryManager. So this bus will dispatch events to registered listeners for only
+ * those queries that were started in the associated SparkSession.
  */
 class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus)
   extends SparkListener with ListenerBus[StreamingQueryListener, StreamingQueryListener.Event] {
@@ -35,12 +43,30 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus)
   sparkListenerBus.addListener(this)

   /**
-   * Post a StreamingQueryListener event to the Spark listener bus asynchronously. This event will
-   * be dispatched to all StreamingQueryListener in the thread of the Spark listener bus.
+   * RunIds of active queries whose events are supposed to be forwarded by this ListenerBus
+   * to registered `StreamingQueryListeners`.
+   *
+   * Note 1: We need to track runIds instead of ids because the runId is unique for every started
+   * query, even if it is a restart. So even if a query is restarted, this bus will identify the
+   * runs separately and correctly account for the restart.
+   *
+   * Note 2: This list needs to be maintained separately from
+   * `StreamingQueryManager.activeQueries` because a terminated query is cleared from
+   * `StreamingQueryManager.activeQueries` as soon as it is stopped, but this ListenerBus
+   * must clear a query only after the termination event of that query has been posted.
+   */
+  private val activeQueryRunIds = new mutable.HashSet[UUID]
+
+  /**
+   * Post a StreamingQueryListener event to the added StreamingQueryListeners.
+   * Note that only the QueryStartedEvent is posted to the listeners synchronously. Other events
+   * are dispatched to the Spark listener bus. This method is guaranteed to be called by queries
+   * in the same SparkSession as this listener bus.
    */
   def post(event: StreamingQueryListener.Event) {
     event match {
       case s: QueryStartedEvent =>
+        activeQueryRunIds.synchronized { activeQueryRunIds += s.runId }
         sparkListenerBus.post(s)
         // post to local listeners to trigger callbacks
         postToAll(s)
@@ -63,18 +89,32 @@
     }
   }

+  /**
+   * Dispatch events to registered StreamingQueryListeners. Only the events associated with
+   * queries started in the same SparkSession as this ListenerBus will be dispatched to the listeners.
+   */
   override protected def doPostEvent(
       listener: StreamingQueryListener,
       event: StreamingQueryListener.Event): Unit = {
+    def shouldReport(runId: UUID): Boolean = {
+      activeQueryRunIds.synchronized { activeQueryRunIds.contains(runId) }
+    }
+
     event match {
       case queryStarted: QueryStartedEvent =>
-        listener.onQueryStarted(queryStarted)
+        if (shouldReport(queryStarted.runId)) {
+          listener.onQueryStarted(queryStarted)
+        }
       case queryProgress: QueryProgressEvent =>
-        listener.onQueryProgress(queryProgress)
+        if (shouldReport(queryProgress.progress.runId)) {
+          listener.onQueryProgress(queryProgress)
+        }
       case queryTerminated: QueryTerminatedEvent =>
-        listener.onQueryTerminated(queryTerminated)
+        if (shouldReport(queryTerminated.runId)) {
+          listener.onQueryTerminated(queryTerminated)
+          activeQueryRunIds.synchronized { activeQueryRunIds -= queryTerminated.runId }
+        }
       case _ =>
     }
   }
-
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index b370845481ed7..b699be217e67f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -70,11 +70,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)

   def schema: StructType = encoder.schema

-  def toDS()(implicit sqlContext: SQLContext): Dataset[A] = {
+  def toDS(): Dataset[A] = {
     Dataset(sqlContext.sparkSession, logicalPlan)
   }

-  def toDF()(implicit sqlContext: SQLContext): DataFrame = {
+  def toDF(): DataFrame = {
     Dataset.ofRows(sqlContext.sparkSession, logicalPlan)
   }

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index 43322651296b9..10f267e115320 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -231,8 +231,8 @@
       outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = {

     val stream = _stream.toDF()
+    val sparkSession = stream.sparkSession // use the session in the DF, not the default session
     var pos = 0
-    var currentPlan: LogicalPlan = stream.logicalPlan
     var currentStream: StreamExecution = null
     var lastStream: StreamExecution = null
     val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for
@@ -319,7 +319,6 @@
        """.stripMargin)
     }

-    val testThread = Thread.currentThread()
     val metadataRoot = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
     var manualClockExpectedTime = -1L
     try {
@@ -337,14 +336,16 @@

       additionalConfs.foreach(pair => {
         val value =
-          if (spark.conf.contains(pair._1)) Some(spark.conf.get(pair._1)) else None
+          if (sparkSession.conf.contains(pair._1)) {
+            Some(sparkSession.conf.get(pair._1))
+          } else None
         resetConfValues(pair._1) = value
-        spark.conf.set(pair._1, pair._2)
+        sparkSession.conf.set(pair._1, pair._2)
       })

       lastStream = currentStream
       currentStream =
-        spark
+        sparkSession
           .streams
           .startQuery(
             None,
@@ -518,8 +519,8 @@

       // Rollback prev configuration values
       resetConfValues.foreach {
-        case (key, Some(value)) => spark.conf.set(key, value)
-        case (key, None) => spark.conf.unset(key)
+        case (key, Some(value)) => sparkSession.conf.set(key, value)
+        case (key, None) => sparkSession.conf.unset(key)
       }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
index 1cd503c6de696..229eaf3189939 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.streaming
 import java.util.UUID

 import scala.collection.mutable
+import scala.concurrent.duration._

 import org.scalactic.TolerantNumerics
 import org.scalatest.concurrent.AsyncAssertions.Waiter
@@ -30,6 +31,7 @@ import org.scalatest.PrivateMethodTester._

 import org.apache.spark.SparkException
 import org.apache.spark.scheduler._
+import org.apache.spark.sql.{Encoder, SparkSession}
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.StreamingQueryListener._
@@ -45,7 +47,7 @@
   after {
     spark.streams.active.foreach(_.stop())
     assert(spark.streams.active.isEmpty)
-    assert(addedListeners.isEmpty)
+    assert(addedListeners().isEmpty)
     // Make sure we don't leak any events to the next test
     spark.sparkContext.listenerBus.waitUntilEmpty(10000)
   }
@@ -148,7 +150,7 @@
       assert(isListenerActive(listener1) === false)
       assert(isListenerActive(listener2) === true)
     } finally {
-      addedListeners.foreach(spark.streams.removeListener)
+      addedListeners().foreach(spark.streams.removeListener)
     }
   }

@@ -251,6 +253,57 @@
     }
   }

+  test("listener only posts events from queries started in the related sessions") {
+    val session1 = spark.newSession()
+    val session2 = spark.newSession()
+    val collector1 = new EventCollector
+    val collector2 = new EventCollector
+
+    def runQuery(session: SparkSession): Unit = {
+      collector1.reset()
+      collector2.reset()
+      val mem = MemoryStream[Int](implicitly[Encoder[Int]], session.sqlContext)
+      testStream(mem.toDS)(
+        AddData(mem, 1, 2, 3),
+        CheckAnswer(1, 2, 3)
+      )
+      session.sparkContext.listenerBus.waitUntilEmpty(5000)
+    }
+
+    def assertEventsCollected(collector: EventCollector): Unit = {
+      assert(collector.startEvent !== null)
+      assert(collector.progressEvents.nonEmpty)
+      assert(collector.terminationEvent !== null)
+    }
+
+    def assertEventsNotCollected(collector: EventCollector): Unit = {
+      assert(collector.startEvent === null)
+      assert(collector.progressEvents.isEmpty)
+      assert(collector.terminationEvent === null)
+    }
+
+    assert(session1.ne(session2))
+    assert(session1.streams.ne(session2.streams))
+
+    withListenerAdded(collector1, session1) {
+      assert(addedListeners(session1).nonEmpty)
+
+      withListenerAdded(collector2, session2) {
+        assert(addedListeners(session2).nonEmpty)
+
+        // query on session1 should send events only to collector1
+        runQuery(session1)
+        assertEventsCollected(collector1)
+        assertEventsNotCollected(collector2)
+
+        // query on session2 should send events only to collector2
+        runQuery(session2)
+        assertEventsCollected(collector2)
+        assertEventsNotCollected(collector1)
+      }
+    }
+  }
+
   testQuietly("ReplayListenerBus should ignore broken event jsons generated in 2.0.0") {
     // query-event-logs-version-2.0.0.txt has all types of events generated by
     // Structured Streaming in Spark 2.0.0.
@@ -298,21 +351,23 @@
     }
   }

-  private def withListenerAdded(listener: StreamingQueryListener)(body: => Unit): Unit = {
+  private def withListenerAdded(
+      listener: StreamingQueryListener,
+      session: SparkSession = spark)(body: => Unit): Unit = {
     try {
       failAfter(streamingTimeout) {
-        spark.streams.addListener(listener)
+        session.streams.addListener(listener)
         body
       }
     } finally {
-      spark.streams.removeListener(listener)
+      session.streams.removeListener(listener)
     }
   }

-  private def addedListeners(): Array[StreamingQueryListener] = {
+  private def addedListeners(session: SparkSession = spark): Array[StreamingQueryListener] = {
     val listenerBusMethod = PrivateMethod[StreamingQueryListenerBus]('listenerBus)
-    val listenerBus = spark.streams invokePrivate listenerBusMethod()
+    val listenerBus = session.streams invokePrivate listenerBusMethod()
     listenerBus.listeners.toArray.map(_.asInstanceOf[StreamingQueryListener])
   }
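
Reviewer note, not part of the patch: the sketch below illustrates the user-visible behavior this change establishes, using only public APIs exercised by the patch (`SparkSession.newSession()`, `StreamingQueryManager.addListener`, and the three `StreamingQueryListener` callbacks). The `LabelledListener` class and the session variable names are hypothetical, for illustration only.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener._

// A listener that tags each callback with a label, so it is visible which
// session's StreamingQueryListenerBus delivered the event.
class LabelledListener(label: String) extends StreamingQueryListener {
  override def onQueryStarted(event: QueryStartedEvent): Unit =
    println(s"[$label] started: runId=${event.runId}")
  override def onQueryProgress(event: QueryProgressEvent): Unit =
    println(s"[$label] progress: runId=${event.progress.runId}")
  override def onQueryTerminated(event: QueryTerminatedEvent): Unit =
    println(s"[$label] terminated: runId=${event.runId}")
}

val session1 = SparkSession.builder().master("local[2]").appName("demo").getOrCreate()
val session2 = session1.newSession() // shares the SparkContext and its LiveListenerBus

session1.streams.addListener(new LabelledListener("session1"))
session2.streams.addListener(new LabelledListener("session2"))

// Both buses observe every event on the shared LiveListenerBus, but each bus
// tracks the runIds of queries started through its own StreamingQueryManager,
// so a query started via session2 fires only the "session2" listener.
```

The design choice worth noting for reviewers: scoping is enforced in the bus (via the `activeQueryRunIds` set keyed by runId) rather than in the shared `LiveListenerBus`, which is why restarted queries are accounted for correctly and why a query is only forgotten after its termination event has been dispatched.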