[SPARK-18758][SS] StreamingQueryListener events from a StreamingQuery should be sent only to the listeners in the same session as the query

## What changes were proposed in this pull request?

Listeners added with `sparkSession.streams.addListener(l)` are added to a specific SparkSession, so only events from queries in the same session as a listener should be posted to it. Currently, all events get rerouted through Spark's main listener bus, that is:
- A StreamingQuery posts its events to the StreamingQueryListenerBus; only queries associated with the same session as the bus post events to it.
- StreamingQueryListenerBus posts the event to Spark's main LiveListenerBus as a SparkListenerEvent.
- StreamingQueryListenerBus also subscribes to LiveListenerBus events, thus getting the posted event back on a different thread.
- The received event is then posted to the registered listeners.

The problem is that *every StreamingQueryListenerBus in every session* gets the events and posts them to its own listeners. This is wrong.
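
For illustration, this is the per-session scoping that listeners are supposed to have (a minimal, self-contained sketch; `CountingListener` and the demo object are hypothetical, not code from this patch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener._

// A listener that only counts the start events it receives.
class CountingListener extends StreamingQueryListener {
  @volatile var started = 0
  override def onQueryStarted(event: QueryStartedEvent): Unit = started += 1
  override def onQueryProgress(event: QueryProgressEvent): Unit = ()
  override def onQueryTerminated(event: QueryTerminatedEvent): Unit = ()
}

object CrossSessionLeakDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    val session1 = spark.newSession()
    val session2 = spark.newSession()

    val listener1 = new CountingListener
    val listener2 = new CountingListener
    session1.streams.addListener(listener1) // registered only on session1
    session2.streams.addListener(listener2) // registered only on session2

    // A query started through session1 should notify only listener1.
    // Before this patch, listener2 was notified as well, because every
    // session's StreamingQueryListenerBus re-dispatched every event it
    // got back from the shared LiveListenerBus.
  }
}
```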

In this PR, I solve this by making StreamingQueryListenerBus track the active queries (by their runIds) when a query posts its QueryStarted event to the bus. This allows the rerouted events to be filtered against the tracked queries.

Note that this list needs to be maintained separately from `StreamingQueryManager.activeQueries`: a terminated query is cleared from `StreamingQueryManager.activeQueries` as soon as it is stopped, whereas this ListenerBus must clear a query only after the termination event of that query has been posted, which happens asynchronously, well after the query has terminated.
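
As background for the runId choice, a query restarted from the same checkpoint keeps its `id` but gets a fresh `runId` on every start. A rough sketch of that behavior (the paths, source, and sink are illustrative assumptions, not part of this patch):

```scala
import org.apache.spark.sql.SparkSession

// Illustrative sketch: a restart from the same checkpoint keeps query.id
// but assigns a fresh query.runId, so tracking runIds lets the bus tell
// individual runs apart.
object RunIdVsIdDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()

    // "rate" is a built-in source in newer Spark versions; any
    // checkpointable source would behave the same way here.
    val input = spark.readStream.format("rate").load()

    def startQuery() = input.writeStream
      .format("console")
      .option("checkpointLocation", "/tmp/demo-checkpoint") // hypothetical path
      .start()

    val run1 = startQuery()
    val (id1, runId1) = (run1.id, run1.runId)
    run1.stop()

    val run2 = startQuery()
    assert(run2.id == id1)        // id is stable across restarts from the checkpoint
    assert(run2.runId != runId1)  // runId is unique to every started run
    run2.stop()
    spark.stop()
  }
}
```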

Credit goes to zsxwing for coming up with the initial idea.

## How was this patch tested?
Updated the test harness code to use the correct session, and added a new unit test.
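
(For reference, the updated suite can be run locally with the usual Spark sbt invocation, e.g. `build/sbt "sql/testOnly *StreamingQueryListenerSuite"`; the exact command assumes the standard Spark build layout.)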

Author: Tathagata Das <[email protected]>

Closes #16186 from tdas/SPARK-18758.

(cherry picked from commit 9ab725e)
Signed-off-by: Tathagata Das <[email protected]>
tdas committed Dec 8, 2016
1 parent 839c2eb commit 617ce3b
Showing 4 changed files with 119 additions and 23 deletions.
StreamingQueryListenerBus.scala
@@ -17,6 +17,10 @@

package org.apache.spark.sql.execution.streaming

import java.util.UUID

import scala.collection.mutable

import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerEvent}
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.util.ListenerBus
@@ -25,7 +29,11 @@ import org.apache.spark.util.ListenerBus
* A bus to forward events to [[StreamingQueryListener]]s. This one will send received
* [[StreamingQueryListener.Event]]s to the Spark listener bus. It also registers itself with
* Spark listener bus, so that it can receive [[StreamingQueryListener.Event]]s and dispatch them
* to StreamingQueryListener.
* to StreamingQueryListeners.
*
* Note that each bus and its registered listeners are associated with a single SparkSession
* and StreamingQueryManager. So this bus will dispatch events to registered listeners for only
* those queries that were started in the associated SparkSession.
*/
class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus)
extends SparkListener with ListenerBus[StreamingQueryListener, StreamingQueryListener.Event] {
@@ -35,12 +43,30 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus)
sparkListenerBus.addListener(this)

/**
* Post a StreamingQueryListener event to the Spark listener bus asynchronously. This event will
* be dispatched to all StreamingQueryListener in the thread of the Spark listener bus.
* RunIds of active queries whose events are supposed to be forwarded by this ListenerBus
* to registered `StreamingQueryListeners`.
*
* Note 1: We need to track runIds instead of ids because the runId is unique for every started
* query, even if it is a restart. So even if a query is restarted, this bus will identify each
* run separately and correctly account for the restart.
*
* Note 2: This list needs to be maintained separately from the
* `StreamingQueryManager.activeQueries` because a terminated query is cleared from
* `StreamingQueryManager.activeQueries` as soon as it is stopped, but this ListenerBus
* must clear a query only after the termination event of that query has been posted.
*/
private val activeQueryRunIds = new mutable.HashSet[UUID]

/**
* Post a StreamingQueryListener event to the added StreamingQueryListeners.
* Note that only the QueryStarted event is posted to the listener synchronously. Other events
* are dispatched to the Spark listener bus. This method is guaranteed to be called by queries in
* the same SparkSession as this listener.
*/
def post(event: StreamingQueryListener.Event) {
event match {
case s: QueryStartedEvent =>
activeQueryRunIds.synchronized { activeQueryRunIds += s.runId }
sparkListenerBus.post(s)
// post to local listeners to trigger callbacks
postToAll(s)
@@ -63,18 +89,32 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus)
}
}

/**
* Dispatch events to registered StreamingQueryListeners. Only the events associated with queries
* started in the same SparkSession as this ListenerBus will be dispatched to the listeners.
*/
override protected def doPostEvent(
listener: StreamingQueryListener,
event: StreamingQueryListener.Event): Unit = {
def shouldReport(runId: UUID): Boolean = {
activeQueryRunIds.synchronized { activeQueryRunIds.contains(runId) }
}

event match {
case queryStarted: QueryStartedEvent =>
listener.onQueryStarted(queryStarted)
if (shouldReport(queryStarted.runId)) {
listener.onQueryStarted(queryStarted)
}
case queryProgress: QueryProgressEvent =>
listener.onQueryProgress(queryProgress)
if (shouldReport(queryProgress.progress.runId)) {
listener.onQueryProgress(queryProgress)
}
case queryTerminated: QueryTerminatedEvent =>
listener.onQueryTerminated(queryTerminated)
if (shouldReport(queryTerminated.runId)) {
listener.onQueryTerminated(queryTerminated)
activeQueryRunIds.synchronized { activeQueryRunIds -= queryTerminated.runId }
}
case _ =>
}
}

}
memory.scala (MemoryStream)
@@ -70,11 +70,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)

def schema: StructType = encoder.schema

def toDS()(implicit sqlContext: SQLContext): Dataset[A] = {
def toDS(): Dataset[A] = {
Dataset(sqlContext.sparkSession, logicalPlan)
}

def toDF()(implicit sqlContext: SQLContext): DataFrame = {
def toDF(): DataFrame = {
Dataset.ofRows(sqlContext.sparkSession, logicalPlan)
}

StreamTest.scala
@@ -231,8 +231,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = {

val stream = _stream.toDF()
val sparkSession = stream.sparkSession // use the session in DF, not the default session
var pos = 0
var currentPlan: LogicalPlan = stream.logicalPlan
var currentStream: StreamExecution = null
var lastStream: StreamExecution = null
val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for
@@ -319,7 +319,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
""".stripMargin)
}

val testThread = Thread.currentThread()
val metadataRoot = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
var manualClockExpectedTime = -1L
try {
@@ -337,14 +336,16 @@

additionalConfs.foreach(pair => {
val value =
if (spark.conf.contains(pair._1)) Some(spark.conf.get(pair._1)) else None
if (sparkSession.conf.contains(pair._1)) {
Some(sparkSession.conf.get(pair._1))
} else None
resetConfValues(pair._1) = value
spark.conf.set(pair._1, pair._2)
sparkSession.conf.set(pair._1, pair._2)
})

lastStream = currentStream
currentStream =
spark
sparkSession
.streams
.startQuery(
None,
@@ -518,8 +519,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {

// Rollback prev configuration values
resetConfValues.foreach {
case (key, Some(value)) => spark.conf.set(key, value)
case (key, None) => spark.conf.unset(key)
case (key, Some(value)) => sparkSession.conf.set(key, value)
case (key, None) => sparkSession.conf.unset(key)
}
}
}
StreamingQueryListenerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.streaming
import java.util.UUID

import scala.collection.mutable
import scala.concurrent.duration._

import org.scalactic.TolerantNumerics
import org.scalatest.concurrent.AsyncAssertions.Waiter
@@ -30,6 +31,7 @@ import org.scalatest.PrivateMethodTester._

import org.apache.spark.SparkException
import org.apache.spark.scheduler._
import org.apache.spark.sql.{Encoder, SparkSession}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.StreamingQueryListener._
@@ -45,7 +47,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
after {
spark.streams.active.foreach(_.stop())
assert(spark.streams.active.isEmpty)
assert(addedListeners.isEmpty)
assert(addedListeners().isEmpty)
// Make sure we don't leak any events to the next test
spark.sparkContext.listenerBus.waitUntilEmpty(10000)
}
@@ -148,7 +150,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
assert(isListenerActive(listener1) === false)
assert(isListenerActive(listener2) === true)
} finally {
addedListeners.foreach(spark.streams.removeListener)
addedListeners().foreach(spark.streams.removeListener)
}
}

@@ -251,6 +253,57 @@
}
}

test("listener only posts events from queries started in the related sessions") {
val session1 = spark.newSession()
val session2 = spark.newSession()
val collector1 = new EventCollector
val collector2 = new EventCollector

def runQuery(session: SparkSession): Unit = {
collector1.reset()
collector2.reset()
val mem = MemoryStream[Int](implicitly[Encoder[Int]], session.sqlContext)
testStream(mem.toDS)(
AddData(mem, 1, 2, 3),
CheckAnswer(1, 2, 3)
)
session.sparkContext.listenerBus.waitUntilEmpty(5000)
}

def assertEventsCollected(collector: EventCollector): Unit = {
assert(collector.startEvent !== null)
assert(collector.progressEvents.nonEmpty)
assert(collector.terminationEvent !== null)
}

def assertEventsNotCollected(collector: EventCollector): Unit = {
assert(collector.startEvent === null)
assert(collector.progressEvents.isEmpty)
assert(collector.terminationEvent === null)
}

assert(session1.ne(session2))
assert(session1.streams.ne(session2.streams))

withListenerAdded(collector1, session1) {
assert(addedListeners(session1).nonEmpty)

withListenerAdded(collector2, session2) {
assert(addedListeners(session2).nonEmpty)

// query on session1 should send events only to collector1
runQuery(session1)
assertEventsCollected(collector1)
assertEventsNotCollected(collector2)

// query on session2 should send events only to collector2
runQuery(session2)
assertEventsCollected(collector2)
assertEventsNotCollected(collector1)
}
}
}

testQuietly("ReplayListenerBus should ignore broken event jsons generated in 2.0.0") {
// query-event-logs-version-2.0.0.txt has all types of events generated by
// Structured Streaming in Spark 2.0.0.
@@ -298,21 +351,23 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
}
}

private def withListenerAdded(listener: StreamingQueryListener)(body: => Unit): Unit = {
private def withListenerAdded(
listener: StreamingQueryListener,
session: SparkSession = spark)(body: => Unit): Unit = {
try {
failAfter(streamingTimeout) {
spark.streams.addListener(listener)
session.streams.addListener(listener)
body
}
} finally {
spark.streams.removeListener(listener)
session.streams.removeListener(listener)
}
}

private def addedListeners(): Array[StreamingQueryListener] = {
private def addedListeners(session: SparkSession = spark): Array[StreamingQueryListener] = {
val listenerBusMethod =
PrivateMethod[StreamingQueryListenerBus]('listenerBus)
val listenerBus = spark.streams invokePrivate listenerBusMethod()
val listenerBus = session.streams invokePrivate listenerBusMethod()
listenerBus.listeners.toArray.map(_.asInstanceOf[StreamingQueryListener])
}

