diff --git a/dev/run-tests.py b/dev/run-tests.py
index 1f0d218514f92..90535fd3b7b03 100755
--- a/dev/run-tests.py
+++ b/dev/run-tests.py
@@ -294,7 +294,8 @@ def build_spark_sbt(hadoop_version):
sbt_goals = ["package",
"assembly/assembly",
"streaming-kafka-assembly/assembly",
- "streaming-flume-assembly/assembly"]
+ "streaming-flume-assembly/assembly",
+ "streaming-mqtt-assembly/assembly"]
profiles_and_goals = build_profiles + sbt_goals
print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: ",
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 993583e2f4119..45a03c7ea7447 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -170,6 +170,7 @@ def contains_file(self, filename):
dependencies=[streaming],
source_file_regexes=[
"external/mqtt",
+ "external/mqtt-assembly",
],
sbt_test_goals=[
"streaming-mqtt/test",
@@ -290,7 +291,7 @@ def contains_file(self, filename):
pyspark_streaming = Module(
name="pyspark-streaming",
- dependencies=[pyspark_core, streaming, streaming_kafka, streaming_flume_assembly],
+ dependencies=[pyspark_core, streaming, streaming_kafka, streaming_flume_assembly, streaming_mqtt],
source_file_regexes=[
"python/pyspark/streaming"
],
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index e72d5580dae55..33d835ba1c381 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -683,7 +683,7 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea
{:.no_toc}
Python API As of Spark {{site.SPARK_VERSION_SHORT}},
-out of these sources, *only* Kafka and Flume are available in the Python API. We will add more advanced sources in the Python API in future.
+out of these sources, *only* Kafka, Flume and MQTT are available in the Python API. We will add more advanced sources to the Python API in the future.
This category of sources require interfacing with external non-Spark libraries, some of them with
complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts
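For context, the end-user API this doc change advertises looks roughly like the following; a minimal sketch, assuming a reachable MQTT broker at `tcp://localhost:1883` and an illustrative topic name:

```python
from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from pyspark.streaming.mqtt import MQTTUtils

# Count words arriving on an MQTT topic. Broker URL and topic are placeholders;
# the streaming-mqtt assembly jar must be on the classpath (e.g. spark-submit --jars).
sc = SparkContext(appName="PythonMQTTWordCount")
ssc = StreamingContext(sc, 1)

lines = MQTTUtils.createStream(ssc, "tcp://localhost:1883", "foo")
counts = lines.flatMap(lambda line: line.split(" ")) \
    .map(lambda word: (word, 1)) \
    .reduceByKey(lambda a, b: a + b)
counts.pprint()

ssc.start()
ssc.awaitTermination()
```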
diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml
index 0e41e5781784b..a28dd3603503a 100644
--- a/external/mqtt/pom.xml
+++ b/external/mqtt/pom.xml
@@ -72,7 +72,6 @@
org.apache.activemq
activemq-core
5.7.0
- test
diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala
new file mode 100644
index 0000000000000..555c26aad1811
--- /dev/null
+++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.mqtt
+
+import java.net.{ServerSocket, URI}
+import java.util.concurrent.{CountDownLatch, TimeUnit}
+
+import org.apache.activemq.broker.{BrokerService, TransportConnector}
+import org.apache.commons.lang3.RandomUtils
+import org.eclipse.paho.client.mqttv3._
+import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
+
+import org.apache.spark.streaming.StreamingContext
+import org.apache.spark.streaming.scheduler.StreamingListener
+import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted
+import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SparkConf}
+
+/**
+ * Shared utility functions for both the Scala and Python MQTT unit tests.
+ */
+private class MQTTTestUtils extends Logging {
+
+ private val persistenceDir = Utils.createTempDir()
+ private val brokerHost = "localhost"
+  private val brokerPort = findFreePort()
+
+ private var broker: BrokerService = _
+ private var connector: TransportConnector = _
+
+ def brokerUri: String = {
+ s"$brokerHost:$brokerPort"
+ }
+
+ def setup(): Unit = {
+ broker = new BrokerService()
+ broker.setDataDirectoryFile(Utils.createTempDir())
+ connector = new TransportConnector()
+ connector.setName("mqtt")
+ connector.setUri(new URI("mqtt://" + brokerUri))
+ broker.addConnector(connector)
+ broker.start()
+ }
+
+ def teardown(): Unit = {
+ if (broker != null) {
+ broker.stop()
+ broker = null
+ }
+ if (connector != null) {
+ connector.stop()
+ connector = null
+    }
+    Utils.deleteRecursively(persistenceDir)
+  }
+
+ private def findFreePort(): Int = {
+ val candidatePort = RandomUtils.nextInt(1024, 65536)
+ Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
+ val socket = new ServerSocket(trialPort)
+ socket.close()
+ (null, trialPort)
+ }, new SparkConf())._2
+ }
+
+ def publishData(topic: String, data: String): Unit = {
+ var client: MqttClient = null
+ try {
+ val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath)
+ client = new MqttClient("tcp://" + brokerUri, MqttClient.generateClientId(), persistence)
+ client.connect()
+ if (client.isConnected) {
+ val msgTopic = client.getTopic(topic)
+ val message = new MqttMessage(data.getBytes("utf-8"))
+ message.setQos(1)
+ message.setRetained(true)
+
+ for (i <- 0 to 10) {
+ try {
+ msgTopic.publish(message)
+ } catch {
+ case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT =>
+ // wait for Spark streaming to consume something from the message queue
+ Thread.sleep(50)
+ }
+ }
+ }
+    } finally {
+      if (client != null) {
+        client.disconnect()
+        client.close()
+        client = null
+      }
+    }
+ }
+
+ /**
+ * Block until at least one receiver has started or timeout occurs.
+ */
+  def waitForReceiverToStart(ssc: StreamingContext): Unit = {
+ val latch = new CountDownLatch(1)
+ ssc.addStreamingListener(new StreamingListener {
+      override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = {
+ latch.countDown()
+ }
+ })
+
+ assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.")
+ }
+}
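Because these helpers live under `src/main` rather than `src/test`, they are packaged into the streaming-mqtt jar and can be driven from the Python tests through the py4j gateway. A minimal sketch of that pattern (assuming a live pyspark `StreamingContext` named `ssc`), mirroring what `tests.py` does further down:

```python
# Load the Scala MQTTTestUtils class through the JVM's context classloader;
# `ssc` is assumed to be a running pyspark.streaming.StreamingContext.
utils_clz = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
    .loadClass("org.apache.spark.streaming.mqtt.MQTTTestUtils")
utils = utils_clz.newInstance()

utils.setup()                             # starts an in-process ActiveMQ broker
print(utils.brokerUri())                  # e.g. "localhost:<some free port>"
utils.publishData("some-topic", "hello")  # retries on MAX_INFLIGHT, see above
utils.teardown()                          # stops the broker and its connector
```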
diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
index c4bf5aa7869bb..4d352cba96a3a 100644
--- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
+++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
@@ -17,61 +17,45 @@
package org.apache.spark.streaming.mqtt
-import java.net.{URI, ServerSocket}
-import java.util.concurrent.CountDownLatch
-import java.util.concurrent.TimeUnit
-
import scala.concurrent.duration._
import scala.language.postfixOps
-import org.apache.activemq.broker.{TransportConnector, BrokerService}
-import org.apache.commons.lang3.RandomUtils
-import org.eclipse.paho.client.mqttv3._
-import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
-
-import org.scalatest.BeforeAndAfter
+import org.scalatest.BeforeAndAfterAll
import org.scalatest.concurrent.Eventually
-import org.apache.spark.streaming.{Milliseconds, StreamingContext}
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.dstream.ReceiverInputDStream
-import org.apache.spark.streaming.scheduler.StreamingListener
-import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted
import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.util.Utils
-
-class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter {
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.{Milliseconds, StreamingContext}
- private val batchDuration = Milliseconds(500)
- private val master = "local[2]"
- private val framework = this.getClass.getSimpleName
- private val freePort = findFreePort()
- private val brokerUri = "//localhost:" + freePort
- private val topic = "def"
- private val persistenceDir = Utils.createTempDir()
+class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll {
+ private val topic = "topic"
private var ssc: StreamingContext = _
- private var broker: BrokerService = _
- private var connector: TransportConnector = _
+  private var mqttTestUtils: MQTTTestUtils = _
- before {
- ssc = new StreamingContext(master, framework, batchDuration)
- setupMQTT()
+ override def beforeAll(): Unit = {
+    mqttTestUtils = new MQTTTestUtils
+    mqttTestUtils.setup()
}
- after {
+ override def afterAll(): Unit = {
if (ssc != null) {
ssc.stop()
ssc = null
}
- Utils.deleteRecursively(persistenceDir)
- tearDownMQTT()
+
+    if (mqttTestUtils != null) {
+      mqttTestUtils.teardown()
+      mqttTestUtils = null
+ }
}
test("mqtt input stream") {
+ val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
+ ssc = new StreamingContext(sparkConf, Milliseconds(500))
val sendMessage = "MQTT demo for spark streaming"
val receiveStream =
- MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY)
+      MQTTUtils.createStream(ssc, "tcp://" + mqttTestUtils.brokerUri, topic,
+        StorageLevel.MEMORY_ONLY)
@volatile var receiveMessage: List[String] = List()
receiveStream.foreachRDD { rdd =>
if (rdd.collect.length > 0) {
@@ -83,85 +67,13 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter
// wait for the receiver to start before publishing data, or we risk failing
// the test nondeterministically. See SPARK-4631
- waitForReceiverToStart()
+    mqttTestUtils.waitForReceiverToStart(ssc)
+
+    mqttTestUtils.publishData(topic, sendMessage)
- publishData(sendMessage)
eventually(timeout(10000 milliseconds), interval(100 milliseconds)) {
assert(sendMessage.equals(receiveMessage(0)))
}
ssc.stop()
}
-
- private def setupMQTT() {
- broker = new BrokerService()
- broker.setDataDirectoryFile(Utils.createTempDir())
- connector = new TransportConnector()
- connector.setName("mqtt")
- connector.setUri(new URI("mqtt:" + brokerUri))
- broker.addConnector(connector)
- broker.start()
- }
-
- private def tearDownMQTT() {
- if (broker != null) {
- broker.stop()
- broker = null
- }
- if (connector != null) {
- connector.stop()
- connector = null
- }
- }
-
- private def findFreePort(): Int = {
- val candidatePort = RandomUtils.nextInt(1024, 65536)
- Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
- val socket = new ServerSocket(trialPort)
- socket.close()
- (null, trialPort)
- }, new SparkConf())._2
- }
-
- def publishData(data: String): Unit = {
- var client: MqttClient = null
- try {
- val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath)
- client = new MqttClient("tcp:" + brokerUri, MqttClient.generateClientId(), persistence)
- client.connect()
- if (client.isConnected) {
- val msgTopic = client.getTopic(topic)
- val message = new MqttMessage(data.getBytes("utf-8"))
- message.setQos(1)
- message.setRetained(true)
-
- for (i <- 0 to 10) {
- try {
- msgTopic.publish(message)
- } catch {
- case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT =>
- // wait for Spark streaming to consume something from the message queue
- Thread.sleep(50)
- }
- }
- }
- } finally {
- client.disconnect()
- client.close()
- client = null
- }
- }
-
- /**
- * Block until at least one receiver has started or timeout occurs.
- */
- private def waitForReceiverToStart() = {
- val latch = new CountDownLatch(1)
- ssc.addStreamingListener(new StreamingListener {
- override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) {
- latch.countDown()
- }
- })
-
- assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.")
- }
}
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 188c8ff12067e..3012cd1e1a1b7 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -39,6 +39,7 @@
from pyspark.streaming.context import StreamingContext
from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition
from pyspark.streaming.flume import FlumeUtils
+from pyspark.streaming.mqtt import MQTTUtils
class PySparkStreamingTestCase(unittest.TestCase):
@@ -826,6 +827,52 @@ def test_flume_polling(self):
def test_flume_polling_multiple_hosts(self):
self._testMultipleTimes(self._testFlumePollingMultipleHosts)
+class MQTTStreamTests(PySparkStreamingTestCase):
+ timeout = 20 # seconds
+ duration = 1
+
+ def setUp(self):
+ super(MQTTStreamTests, self).setUp()
+
+ utilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
+ .loadClass("org.apache.spark.streaming.mqtt.MQTTTestUtils")
+        self._MQTTTestUtils = utilsClz.newInstance()
+ self._MQTTTestUtils.setup()
+
+ def tearDown(self):
+ if self._MQTTTestUtils is not None:
+ self._MQTTTestUtils.teardown()
+ self._MQTTTestUtils = None
+
+ super(MQTTStreamTests, self).tearDown()
+
+ def _randomTopic(self):
+ return "topic-%d" % random.randint(0, 10000)
+
+    def _validateStreamResult(self, sendData, topic, stream):
+        result = []
+
+        def get_output(_, rdd):
+            for data in rdd.collect():
+                result.append(data)
+
+        stream.foreachRDD(get_output)
+        self.ssc.start()
+
+        # wait for the receiver to start before publishing data, or we risk failing
+        # the test nondeterministically (see SPARK-4631)
+        self._MQTTTestUtils.waitForReceiverToStart(self.ssc._jssc.ssc())
+        self._MQTTTestUtils.publishData(topic, sendData)
+
+        self.wait_for(result, 1)
+        self.assertEqual(sendData, result[0])
+
+    def test_mqtt_stream(self):
+        """Test the Python MQTT stream API."""
+        topic = self._randomTopic()
+        sendData = "MQTT demo for spark streaming"
+
+        stream = MQTTUtils.createStream(self.ssc, "tcp://" + self._MQTTTestUtils.brokerUri(),
+                                        topic)
+        self._validateStreamResult(sendData, topic, stream)
+
def search_kafka_assembly_jar():
SPARK_HOME = os.environ["SPARK_HOME"]
@@ -862,10 +909,28 @@ def search_flume_assembly_jar():
else:
return jars[0]
+def search_mqtt_assembly_jar():
+ SPARK_HOME = os.environ["SPARK_HOME"]
+ mqtt_assembly_dir = os.path.join(SPARK_HOME, "external/mqtt-assembly")
+ jars = glob.glob(
+ os.path.join(mqtt_assembly_dir, "target/scala-*/spark-streaming-mqtt-assembly-*.jar"))
+ if not jars:
+ raise Exception(
+ ("Failed to find Spark Streaming MQTT assembly jar in %s. " % mqtt_assembly_dir) +
+ "You need to build Spark with "
+ "'build/sbt assembly/assembly streaming-mqtt-assembly/assembly' or "
+ "'build/mvn package' before running this test")
+ elif len(jars) > 1:
+ raise Exception(("Found multiple Spark Streaming MQTT assembly JARs in %s; please "
+ "remove all but one") % mqtt_assembly_dir)
+ else:
+ return jars[0]
+
if __name__ == "__main__":
kafka_assembly_jar = search_kafka_assembly_jar()
flume_assembly_jar = search_flume_assembly_jar()
- jars = "%s,%s" % (kafka_assembly_jar, flume_assembly_jar)
+ mqtt_assembly_jar = search_mqtt_assembly_jar()
+ jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar)
os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars
unittest.main()
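`PYSPARK_SUBMIT_ARGS` is consumed by PySpark's launcher when the JVM gateway starts, so it must be set before the first SparkContext is created; that is why the jars are wired up here, ahead of `unittest.main()`. A sketch of the same mechanism for ad-hoc runs (the jar path is illustrative and depends on the Scala version you built against):

```python
import os

# Must happen before any SparkContext exists; the path below is an assumption.
os.environ["PYSPARK_SUBMIT_ARGS"] = (
    "--jars external/mqtt-assembly/target/scala-2.10/"
    "spark-streaming-mqtt-assembly-1.5.0-SNAPSHOT.jar pyspark-shell"
)

from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from pyspark.streaming.mqtt import MQTTUtils

sc = SparkContext(appName="MQTTAssemblyCheck")
ssc = StreamingContext(sc, 1)
# Fails with a class-loading error if the assembly jar was not picked up.
stream = MQTTUtils.createStream(ssc, "tcp://localhost:1883", "some-topic")
```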