diff --git a/dev/run-tests.py b/dev/run-tests.py
index 1f0d218514f92..90535fd3b7b03 100755
--- a/dev/run-tests.py
+++ b/dev/run-tests.py
@@ -294,7 +294,8 @@ def build_spark_sbt(hadoop_version):
     sbt_goals = ["package",
                  "assembly/assembly",
                  "streaming-kafka-assembly/assembly",
-                 "streaming-flume-assembly/assembly"]
+                 "streaming-flume-assembly/assembly",
+                 "streaming-mqtt-assembly/assembly"]
     profiles_and_goals = build_profiles + sbt_goals
 
     print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: ",
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 993583e2f4119..45a03c7ea7447 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -170,6 +170,7 @@ def contains_file(self, filename):
     dependencies=[streaming],
     source_file_regexes=[
         "external/mqtt",
+        "external/mqtt-assembly",
     ],
     sbt_test_goals=[
         "streaming-mqtt/test",
@@ -290,7 +291,7 @@ def contains_file(self, filename):
 
 pyspark_streaming = Module(
     name="pyspark-streaming",
-    dependencies=[pyspark_core, streaming, streaming_kafka, streaming_flume_assembly],
+    dependencies=[pyspark_core, streaming, streaming_kafka, streaming_flume_assembly, streaming_mqtt],
     source_file_regexes=[
         "python/pyspark/streaming"
     ],
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index e72d5580dae55..33d835ba1c381 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -683,7 +683,7 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea
 {:.no_toc}
 
 Python API As of Spark {{site.SPARK_VERSION_SHORT}},
-out of these sources, *only* Kafka and Flume are available in the Python API. We will add more advanced sources in the Python API in future.
+out of these sources, *only* Kafka, Flume and MQTT are available in the Python API. We will add more advanced sources in the Python API in future.
 
 This category of sources require interfacing with external non-Spark libraries, some of them with
 complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts
diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml
index 0e41e5781784b..a28dd3603503a 100644
--- a/external/mqtt/pom.xml
+++ b/external/mqtt/pom.xml
@@ -72,7 +72,6 @@
     <dependency>
      <groupId>org.apache.activemq</groupId>
      <artifactId>activemq-core</artifactId>
      <version>5.7.0</version>
-     <scope>test</scope>
    </dependency>
    <dependency>
diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala
new file mode 100644
index 0000000000000..555c26aad1811
--- /dev/null
+++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.mqtt
+
+import java.net.{ServerSocket, URI}
+import java.util.concurrent.{TimeUnit, CountDownLatch}
+
+import scala.language.postfixOps
+
+import org.apache.activemq.broker.{BrokerService, TransportConnector}
+import org.apache.commons.lang3.RandomUtils
+import org.eclipse.paho.client.mqttv3._
+import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
+
+import org.apache.spark.streaming.{StreamingContext, Milliseconds}
+import org.apache.spark.streaming.scheduler.StreamingListener
+import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted
+import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SparkConf}
+
+/**
+ * Shared utilities for the Scala and Python MQTT unit tests.
+ */
+private class MQTTTestUtils extends Logging {
+
+  private val persistenceDir = Utils.createTempDir()
+  private val brokerHost = "localhost"
+  private var brokerPort = findFreePort()
+
+  private var broker: BrokerService = _
+  private var connector: TransportConnector = _
+
+  def brokerUri: String = {
+    s"$brokerHost:$brokerPort"
+  }
+
+  def setup(): Unit = {
+    broker = new BrokerService()
+    broker.setDataDirectoryFile(Utils.createTempDir())
+    connector = new TransportConnector()
+    connector.setName("mqtt")
+    connector.setUri(new URI("mqtt://" + brokerUri))
+    broker.addConnector(connector)
+    broker.start()
+  }
+
+  def teardown(): Unit = {
+    if (broker != null) {
+      broker.stop()
+      broker = null
+    }
+    if (connector != null) {
+      connector.stop()
+      connector = null
+    }
+  }
+
+  private def findFreePort(): Int = {
+    val candidatePort = RandomUtils.nextInt(1024, 65536)
+    Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
+      val socket = new ServerSocket(trialPort)
+      socket.close()
+      (null, trialPort)
+    }, new SparkConf())._2
+  }
+
+  def publishData(topic: String, data: String): Unit = {
+    var client: MqttClient = null
+    try {
+      val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath)
+      client = new MqttClient("tcp://" + brokerUri, MqttClient.generateClientId(), persistence)
+      client.connect()
+      if (client.isConnected) {
+        val msgTopic = client.getTopic(topic)
+        val message = new MqttMessage(data.getBytes("utf-8"))
+        message.setQos(1)
+        message.setRetained(true)
+
+        for (i <- 0 to 10) {
+          try {
+            msgTopic.publish(message)
+          } catch {
+            case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT =>
+              // wait for Spark streaming to consume something from the message queue
+              Thread.sleep(50)
+          }
+        }
+      }
+    } finally {
+      // guard against a failed MqttClient constructor, which would leave client null
+      if (client != null) {
+        client.disconnect()
+        client.close()
+        client = null
+      }
+    }
+  }
+
+  /**
+   * Block until at least one receiver has started or timeout occurs.
+   */
+  def waitForReceiverToStart(ssc: StreamingContext): Unit = {
+    val latch = new CountDownLatch(1)
+    ssc.addStreamingListener(new StreamingListener {
+      override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) {
+        latch.countDown()
+      }
+    })
+
+    assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.")
+  }
+}
diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
index c4bf5aa7869bb..4d352cba96a3a 100644
--- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
+++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
@@ -17,61 +17,45 @@
 
 package org.apache.spark.streaming.mqtt
 
-import java.net.{URI, ServerSocket}
-import java.util.concurrent.CountDownLatch
-import java.util.concurrent.TimeUnit
-
 import scala.concurrent.duration._
 import scala.language.postfixOps
 
-import org.apache.activemq.broker.{TransportConnector, BrokerService}
-import org.apache.commons.lang3.RandomUtils
-import org.eclipse.paho.client.mqttv3._
-import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
-
-import org.scalatest.BeforeAndAfter
+import org.scalatest.BeforeAndAfterAll
 import org.scalatest.concurrent.Eventually
 
-import org.apache.spark.streaming.{Milliseconds, StreamingContext}
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.dstream.ReceiverInputDStream
-import org.apache.spark.streaming.scheduler.StreamingListener
-import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted
 import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.util.Utils
-
-class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter {
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.{Milliseconds, StreamingContext}
 
-  private val batchDuration = Milliseconds(500)
-  private val master = "local[2]"
-  private val framework = this.getClass.getSimpleName
-  private val freePort = findFreePort()
-  private val brokerUri = "//localhost:" + freePort
-  private val topic = "def"
-  private val persistenceDir = Utils.createTempDir()
+class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfterAll {
 
+  private val topic = "topic"
   private var ssc: StreamingContext = _
-  private var broker: BrokerService = _
-  private var connector: TransportConnector = _
+  private var mqttTestUtils: MQTTTestUtils = _
 
-  before {
-    ssc = new StreamingContext(master, framework, batchDuration)
-    setupMQTT()
+  override def beforeAll(): Unit = {
+    mqttTestUtils = new MQTTTestUtils
+    mqttTestUtils.setup()
   }
 
-  after {
+  override def afterAll(): Unit = {
     if (ssc != null) {
       ssc.stop()
       ssc = null
     }
-    Utils.deleteRecursively(persistenceDir)
-    tearDownMQTT()
+
+    if (mqttTestUtils != null) {
+      mqttTestUtils.teardown()
+      mqttTestUtils = null
+    }
   }
 
   test("mqtt input stream") {
+    val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
+    ssc = new StreamingContext(sparkConf, Milliseconds(500))
     val sendMessage = "MQTT demo for spark streaming"
     val receiveStream =
-      MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY)
+      MQTTUtils.createStream(ssc, "tcp://" + mqttTestUtils.brokerUri, topic, StorageLevel.MEMORY_ONLY)
     @volatile var receiveMessage: List[String] = List()
     receiveStream.foreachRDD { rdd =>
       if (rdd.collect.length > 0) {
@@ -83,85 +67,13 @@ class MQTTStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter
 
     // wait for the receiver to start before publishing data, or we risk failing
     // the test nondeterministically. See SPARK-4631
-    waitForReceiverToStart()
+    mqttTestUtils.waitForReceiverToStart(ssc)
+
+    mqttTestUtils.publishData(topic, sendMessage)
 
-    publishData(sendMessage)
     eventually(timeout(10000 milliseconds), interval(100 milliseconds)) {
       assert(sendMessage.equals(receiveMessage(0)))
     }
     ssc.stop()
   }
-
-  private def setupMQTT() {
-    broker = new BrokerService()
-    broker.setDataDirectoryFile(Utils.createTempDir())
-    connector = new TransportConnector()
-    connector.setName("mqtt")
-    connector.setUri(new URI("mqtt:" + brokerUri))
-    broker.addConnector(connector)
-    broker.start()
-  }
-
-  private def tearDownMQTT() {
-    if (broker != null) {
-      broker.stop()
-      broker = null
-    }
-    if (connector != null) {
-      connector.stop()
-      connector = null
-    }
-  }
-
-  private def findFreePort(): Int = {
-    val candidatePort = RandomUtils.nextInt(1024, 65536)
-    Utils.startServiceOnPort(candidatePort, (trialPort: Int) => {
-      val socket = new ServerSocket(trialPort)
-      socket.close()
-      (null, trialPort)
-    }, new SparkConf())._2
-  }
-
-  def publishData(data: String): Unit = {
-    var client: MqttClient = null
-    try {
-      val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath)
-      client = new MqttClient("tcp:" + brokerUri, MqttClient.generateClientId(), persistence)
-      client.connect()
-      if (client.isConnected) {
-        val msgTopic = client.getTopic(topic)
-        val message = new MqttMessage(data.getBytes("utf-8"))
-        message.setQos(1)
-        message.setRetained(true)
-
-        for (i <- 0 to 10) {
-          try {
-            msgTopic.publish(message)
-          } catch {
-            case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT =>
-              // wait for Spark streaming to consume something from the message queue
-              Thread.sleep(50)
-          }
-        }
-      }
-    } finally {
-      client.disconnect()
-      client.close()
-      client = null
-    }
-  }
-
-  /**
-   * Block until at least one receiver has started or timeout occurs.
-   */
-  private def waitForReceiverToStart() = {
-    val latch = new CountDownLatch(1)
-    ssc.addStreamingListener(new StreamingListener {
-      override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) {
-        latch.countDown()
-      }
-    })
-
-    assert(latch.await(10, TimeUnit.SECONDS), "Timeout waiting for receiver to start.")
-  }
 }
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 188c8ff12067e..3012cd1e1a1b7 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -39,6 +39,7 @@
 from pyspark.streaming.context import StreamingContext
 from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition
 from pyspark.streaming.flume import FlumeUtils
+from pyspark.streaming.mqtt import MQTTUtils
 
 
 class PySparkStreamingTestCase(unittest.TestCase):
@@ -826,6 +827,52 @@ def test_flume_polling(self):
     def test_flume_polling_multiple_hosts(self):
         self._testMultipleTimes(self._testFlumePollingMultipleHosts)
 
+
+class MQTTStreamTests(PySparkStreamingTestCase):
+    timeout = 20  # seconds
+    duration = 1
+
+    def setUp(self):
+        super(MQTTStreamTests, self).setUp()
+
+        utilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
+            .loadClass("org.apache.spark.streaming.mqtt.MQTTTestUtils")
+        self._MQTTTestUtils = utilsClz.newInstance()
+        self._MQTTTestUtils.setup()
+
+    def tearDown(self):
+        if self._MQTTTestUtils is not None:
+            self._MQTTTestUtils.teardown()
+            self._MQTTTestUtils = None
+
+        super(MQTTStreamTests, self).tearDown()
+
+    def _randomTopic(self):
+        return "topic-%d" % random.randint(0, 10000)
+
+    def _validateStreamResult(self, sendData, stream):
+        result = []
+
+        def get_output(_, rdd):
+            for data in rdd.collect():
+                result.append(data)
+
+        stream.foreachRDD(get_output)
+        self.ssc.start()
+        self.wait_for(result, 1)
+        # publishData may deliver the retained message more than once, so only
+        # check the first record
+        self.assertEqual([sendData], result[:1])
+
+    def test_mqtt_stream(self):
+        """Test the Python MQTT stream API."""
+        topic = self._randomTopic()
+        sendData = "MQTT demo for spark streaming"
+
+        stream = MQTTUtils.createStream(self.ssc, "tcp://" + self._MQTTTestUtils.brokerUri(), topic)
+        # the message is published with the MQTT "retained" flag, so the receiver
+        # will still see it even though it subscribes after the publish
+        self._MQTTTestUtils.publishData(topic, sendData)
+        self._validateStreamResult(sendData, stream)
+
+
 def search_kafka_assembly_jar():
     SPARK_HOME = os.environ["SPARK_HOME"]
@@ -862,10 +909,28 @@ def search_flume_assembly_jar():
     else:
         return jars[0]
 
+
+def search_mqtt_assembly_jar():
+    SPARK_HOME = os.environ["SPARK_HOME"]
+    mqtt_assembly_dir = os.path.join(SPARK_HOME, "external/mqtt-assembly")
+    jars = glob.glob(
+        os.path.join(mqtt_assembly_dir, "target/scala-*/spark-streaming-mqtt-assembly-*.jar"))
+    if not jars:
+        raise Exception(
" % mqtt_assembly_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-mqtt-assembly/assembly' or " + "'build/mvn package' before running this test") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming MQTT assembly JARs in %s; please " + "remove all but one") % mqtt_assembly_dir) + else: + return jars[0] + if __name__ == "__main__": kafka_assembly_jar = search_kafka_assembly_jar() flume_assembly_jar = search_flume_assembly_jar() - jars = "%s,%s" % (kafka_assembly_jar, flume_assembly_jar) + mqtt_assembly_jar = search_mqtt_assembly_jar() + jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, mqtt_assembly_jar) os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars unittest.main()