From 9709a626b262892c23657b43f785ce30ea07514b Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Wed, 27 Mar 2019 02:49:27 -0700 Subject: [PATCH] [BAHIR-183] HDFS based MQTT client persistence --- README.md | 2 +- sql-streaming-mqtt/README.md | 50 +-- ...pache.spark.sql.sources.DataSourceRegister | 3 +- .../sql/streaming/mqtt/CachedMQTTClient.scala | 2 +- .../sql/streaming/mqtt/MQTTStreamSink.scala | 2 +- .../sql/streaming/mqtt/MQTTStreamSource.scala | 6 +- .../bahir/sql/streaming/mqtt/MQTTUtils.scala | 19 +- .../sql/streaming/mqtt/MessageStore.scala | 127 +++++- .../sql/mqtt/HDFSMQTTSourceProvider.scala | 64 --- .../sql/mqtt/HdfsBasedMQTTStreamSource.scala | 398 ------------------ .../mqtt/HDFSBasedMQTTStreamSourceSuite.scala | 198 --------- .../mqtt/HDFSMessageStoreSuite.scala | 80 ++++ .../sql/streaming/mqtt/HDFSTestUtils.scala | 40 ++ .../mqtt/LocalMessageStoreSuite.scala | 4 +- .../mqtt/MQTTStreamSourceSuite.scala | 81 +++- 15 files changed, 351 insertions(+), 725 deletions(-) delete mode 100644 sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HDFSMQTTSourceProvider.scala delete mode 100644 sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HdfsBasedMQTTStreamSource.scala delete mode 100644 sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSBasedMQTTStreamSourceSuite.scala create mode 100644 sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSMessageStoreSuite.scala create mode 100644 sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSTestUtils.scala diff --git a/README.md b/README.md index 93a50ea7..155f1dbb 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ Furthermore, to generate scaladocs for each module: `$ mvn package` -Scaladocs is generated in, `MODULE_NAME/target/site/scaladocs/index.html`. __ Where `MODULE_NAME` is one of, `sql-streaming-mqtt`, `streaming-akka`, `streaming-mqtt`, `streaming-zeromq`, `streaming-twitter`. __ +Scaladocs is generated in `${MODULE_NAME}/target/site/scaladocs/index.html`, where `MODULE_NAME` is one of `sql-streaming-mqtt`, `streaming-akka`, `streaming-mqtt`, `streaming-zeromq` or `streaming-twitter`. ## A note about Apache Spark integration diff --git a/sql-streaming-mqtt/README.md b/sql-streaming-mqtt/README.md index 0fbf63e0..faa6b850 100644 --- a/sql-streaming-mqtt/README.md +++ b/sql-streaming-mqtt/README.md @@ -57,31 +57,31 @@ Setting values for option `localStorage` and `clientId` helps in recovering in c This connector uses [Eclipse Paho Java Client](https://eclipse.org/paho/clients/java/). Client API documentation is located [here](http://www.eclipse.org/paho/files/javadoc/index.html). -| Parameter name | Description | Eclipse Paho reference | -|----------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------| -| `brokerUrl` | URL MQTT client connects to. Specify this parameter or _path_. Example: _tcp://localhost:1883_, _ssl://localhost:1883_. | | -| `persistence` | Defines how incoming messages are stored. If _memory_ is provided as value for this option, recovery on restart is not supported. Otherwise messages are stored on disk and parameter _localStorage_ may define target directory. 
| | -| `topic` | Topic which client subscribes to. | | -| `clientId` | Uniquely identifies client instance. Provide the same value to recover a stopped source client. MQTT sink ignores client identifier, because Spark batch can be distributed across multiple workers whereas MQTT broker does not allow simultaneous connections with same ID from multiple hosts. | | -| `QoS` | The maximum quality of service to subscribe each topic at. Messages published at a lower quality of service will be received at the published QoS. Messages published at a higher quality of service will be received using the QoS specified on the subscribe. | | -| `username` | User name used to authenticate with MQTT server. Do not set it, if server does not require authentication. Leaving empty may lead to errors. | `MqttConnectOptions.setUserName` | -| `password` | User password. | `MqttConnectOptions.setPassword` | -| `cleanSession` | Setting to _true_ starts a clean session, removes all check-pointed messages persisted during previous run. Defaults to `false`. | `MqttConnectOptions.setCleanSession` | -| `connectionTimeout` | Sets the connection timeout, a value of _0_ is interpreted as wait until client connects. | `MqttConnectOptions.setConnectionTimeout` | -| `keepAlive` | Sets the "keep alive" interval in seconds. | `MqttConnectOptions.setKeepAliveInterval` | -| `mqttVersion` | Specify MQTT protocol version. | `MqttConnectOptions.setMqttVersion` | -| `maxInflight` | Sets the maximum inflight requests. Useful for high volume traffic. | `MqttConnectOptions.setMaxInflight` | -| `autoReconnect` | Sets whether the client will automatically attempt to reconnect to the server upon connectivity disruption. | `MqttConnectOptions.setAutomaticReconnect` | -| `ssl.protocol` | SSL protocol. Example: _SSLv3_, _TLS_, _TLSv1_, _TLSv1.2_. | `MqttConnectOptions.setSSLProperties`, `com.ibm.ssl.protocol` | -| `ssl.key.store` | Absolute path to key store file. | `MqttConnectOptions.setSSLProperties`, `com.ibm.ssl.keyStore` | -| `ssl.key.store.password` | Key store password. | `MqttConnectOptions.setSSLProperties`, `com.ibm.ssl.keyStorePassword` | -| `ssl.key.store.type` | Key store type. Example: _JKS_, _JCEKS_, _PKCS12_. | `MqttConnectOptions.setSSLProperties`, `com.ibm.ssl.keyStoreType` | -| `ssl.key.store.provider` | Key store provider. Example: _IBMJCE_. | `MqttConnectOptions.setSSLProperties`, `com.ibm.ssl.keyStoreProvider` | -| `ssl.trust.store` | Absolute path to trust store file. | `MqttConnectOptions.setSSLProperties`, `com.ibm.ssl.trustStore` | -| `ssl.trust.store.password` | Trust store password. | `MqttConnectOptions.setSSLProperties`, `com.ibm.ssl.trustStorePassword` | -| `ssl.trust.store.type` | Trust store type. Example: _JKS_, _JCEKS_, _PKCS12_. | `MqttConnectOptions.setSSLProperties`, `com.ibm.ssl.trustStoreType` | -| `ssl.trust.store.provider` | Trust store provider. Example: _IBMJCEFIPS_. | `MqttConnectOptions.setSSLProperties`, `com.ibm.ssl.trustStoreProvider` | -| `ssl.ciphers` | List of enabled cipher suites. Example: _SSL_RSA_WITH_AES_128_CBC_SHA_. 
| `MqttConnectOptions.setSSLProperties`, `com.ibm.ssl.enabledCipherSuites` | +| Parameter name | Description | Eclipse Paho reference | +|----------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------| +| `brokerUrl` | URL MQTT client connects to. Specify this parameter or _path_. Example: _tcp://localhost:1883_, _ssl://localhost:1883_. | | +| `persistence` | Defines how incoming messages are stored. If _memory_ is provided as value for this option, recovery on restart is not supported. Other options include _file_, which stores in-flight messages on local file system (parameter _localStorage_ may define target directory), and _hdfs_ which leverages HDFS storage (prefix HDFS configuration parameters with "hdfs." so that they are picked up by the connector). | | +| `topic` | Topic which client subscribes to. | | +| `clientId` | Uniquely identifies client instance. Provide the same value to recover a stopped source client. MQTT sink ignores client identifier, because Spark batch can be distributed across multiple workers whereas MQTT broker does not allow simultaneous connections with same ID from multiple hosts. | | +| `QoS` | The maximum quality of service to subscribe each topic at. Messages published at a lower quality of service will be received at the published QoS. Messages published at a higher quality of service will be received using the QoS specified on the subscribe. | | +| `username` | User name used to authenticate with MQTT server. Do not set it, if server does not require authentication. Leaving empty may lead to errors. | `MqttConnectOptions.setUserName` | +| `password` | User password. | `MqttConnectOptions.setPassword` | +| `cleanSession` | Setting to _true_ starts a clean session, removes all check-pointed messages persisted during previous run. Defaults to `false`. | `MqttConnectOptions.setCleanSession` | +| `connectionTimeout` | Sets the connection timeout, a value of _0_ is interpreted as wait until client connects. | `MqttConnectOptions.setConnectionTimeout` | +| `keepAlive` | Sets the "keep alive" interval in seconds. | `MqttConnectOptions.setKeepAliveInterval` | +| `mqttVersion` | Specify MQTT protocol version. | `MqttConnectOptions.setMqttVersion` | +| `maxInflight` | Sets the maximum inflight requests. Useful for high volume traffic. | `MqttConnectOptions.setMaxInflight` | +| `autoReconnect` | Sets whether the client will automatically attempt to reconnect to the server upon connectivity disruption. | `MqttConnectOptions.setAutomaticReconnect` | +| `ssl.protocol` | SSL protocol. Example: _SSLv3_, _TLS_, _TLSv1_, _TLSv1.2_. | `MqttConnectOptions.setSSLProperties`, `com.ibm.ssl.protocol` | +| `ssl.key.store` | Absolute path to key store file. | `MqttConnectOptions.setSSLProperties`, `com.ibm.ssl.keyStore` | +| `ssl.key.store.password` | Key store password. | `MqttConnectOptions.setSSLProperties`, `com.ibm.ssl.keyStorePassword` | +| `ssl.key.store.type` | Key store type. Example: _JKS_, _JCEKS_, _PKCS12_. | `MqttConnectOptions.setSSLProperties`, `com.ibm.ssl.keyStoreType` | +| `ssl.key.store.provider` | Key store provider. 
Example: _IBMJCE_. | `MqttConnectOptions.setSSLProperties`, `com.ibm.ssl.keyStoreProvider` | +| `ssl.trust.store` | Absolute path to trust store file. | `MqttConnectOptions.setSSLProperties`, `com.ibm.ssl.trustStore` | +| `ssl.trust.store.password` | Trust store password. | `MqttConnectOptions.setSSLProperties`, `com.ibm.ssl.trustStorePassword` | +| `ssl.trust.store.type` | Trust store type. Example: _JKS_, _JCEKS_, _PKCS12_. | `MqttConnectOptions.setSSLProperties`, `com.ibm.ssl.trustStoreType` | +| `ssl.trust.store.provider` | Trust store provider. Example: _IBMJCEFIPS_. | `MqttConnectOptions.setSSLProperties`, `com.ibm.ssl.trustStoreProvider` | +| `ssl.ciphers` | List of enabled cipher suites. Example: _SSL_RSA_WITH_AES_128_CBC_SHA_. | `MqttConnectOptions.setSSLProperties`, `com.ibm.ssl.enabledCipherSuites` | ## Environment variables diff --git a/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 1920a6bf..d3899e6a 100644 --- a/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql-streaming-mqtt/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -16,5 +16,4 @@ # org.apache.bahir.sql.streaming.mqtt.MQTTStreamSinkProvider -org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider -org.apache.spark.sql.mqtt.HDFSMQTTSourceProvider \ No newline at end of file +org.apache.bahir.sql.streaming.mqtt.MQTTStreamSourceProvider \ No newline at end of file diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala index 78eae527..80c307bf 100644 --- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala +++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/CachedMQTTClient.scala @@ -66,7 +66,7 @@ private[mqtt] object CachedMQTTClient extends Logging { private def createMqttClient(config: Map[String, String]): (MqttClient, MqttClientPersistence) = { - val (brokerUrl, clientId, _, persistence, mqttConnectOptions, _, _, _, _) = + val (brokerUrl, clientId, _, persistence, mqttConnectOptions, _) = MQTTUtils.parseConfigParams(config) val client = new MqttClient(brokerUrl, clientId, persistence) val callback = new MqttCallbackExtended() { diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala index 23385f4c..1518a9c4 100644 --- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala +++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSink.scala @@ -68,7 +68,7 @@ class MQTTDataWriter(config: mutable.Map[String, String]) extends DataWriter[Int private lazy val publishBackoff: Long = SparkEnv.get.conf.getTimeAsMs("spark.mqtt.client.publish.backoff", "5s") - private lazy val (_, _, topic, _, _, qos, _, _, _) = MQTTUtils.parseConfigParams(config.toMap) + private lazy val (_, _, topic, _, _, qos) = MQTTUtils.parseConfigParams(config.toMap) override def write(record: InternalRow): Unit = { val client = CachedMQTTClient.getOrCreate(config.toMap) diff --git 
a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala index e1314ae2..c9df5957 100644 --- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala +++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSource.scala @@ -100,9 +100,6 @@ class MQTTStreamSource(options: DataSourceOptions, brokerUrl: String, persistenc private var startOffset: OffsetV2 = _ private var endOffset: OffsetV2 = _ - /* Older than last N messages, will not be checked for redelivery. */ - val backLog = options.getInt("autopruning.backlog", 500) - private[mqtt] val store = new LocalMessageStore(persistence) private[mqtt] val messages = new TrieMap[Long, MQTTMessage] @@ -231,7 +228,6 @@ class MQTTStreamSource(options: DataSourceOptions, brokerUrl: String, persistenc /** Stop this source. */ override def stop(): Unit = synchronized { client.disconnect() - persistence.close() client.close() } @@ -250,7 +246,7 @@ class MQTTStreamSourceProvider extends DataSourceV2 } import scala.collection.JavaConverters._ - val (brokerUrl, clientId, topic, persistence, mqttConnectOptions, qos, _, _, _) = + val (brokerUrl, clientId, topic, persistence, mqttConnectOptions, qos) = MQTTUtils.parseConfigParams(collection.immutable.HashMap() ++ parameters.asMap().asScala) new MQTTStreamSource(parameters, brokerUrl, persistence, topic, clientId, diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala index 9df46bcc..ee6734b3 100644 --- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala +++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MQTTUtils.scala @@ -19,6 +19,7 @@ package org.apache.bahir.sql.streaming.mqtt import java.util.Properties +import org.apache.hadoop.conf.Configuration import org.eclipse.paho.client.mqttv3.{MqttClient, MqttClientPersistence, MqttConnectOptions} import org.eclipse.paho.client.mqttv3.persist.{MemoryPersistence, MqttDefaultFilePersistence} @@ -45,7 +46,7 @@ object MQTTUtils extends Logging { ) def parseConfigParams(config: Map[String, String]): - (String, String, String, MqttClientPersistence, MqttConnectOptions, Int, Long, Long, Int) = { + (String, String, String, MqttClientPersistence, MqttConnectOptions, Int) = { def e(s: String) = new IllegalArgumentException(s) val parameters = CaseInsensitiveMap(config) @@ -54,6 +55,14 @@ object MQTTUtils extends Logging { val persistence: MqttClientPersistence = parameters.get("persistence") match { case Some("memory") => new MemoryPersistence() + case Some("hdfs") => + val hadoopConfig = new Configuration + for (parameter <- parameters) { + if (parameter._1.startsWith("hdfs.")) { + hadoopConfig.set(parameter._1.replaceFirst("hdfs.", ""), parameter._2) + } + } + new HdfsMqttClientPersistence( hadoopConfig ) case _ => val localStorage: Option[String] = parameters.get("localStorage") localStorage match { case Some(x) => new MqttDefaultFilePersistence(x) @@ -83,11 +92,6 @@ object MQTTUtils extends Logging { val autoReconnect: Boolean = parameters.getOrElse("autoReconnect", "false").toBoolean val maxInflight: Int = parameters.getOrElse("maxInflight", "60").toInt - val maxBatchMessageNum = parameters.getOrElse("maxBatchMessageNum", s"${Long.MaxValue}").toLong - val maxBatchMessageSize = 
parameters.getOrElse("maxBatchMessageSize", - s"${Long.MaxValue}").toLong - val maxRetryNumber = parameters.getOrElse("maxRetryNum", "3").toInt - val mqttConnectOptions: MqttConnectOptions = new MqttConnectOptions() mqttConnectOptions.setAutomaticReconnect(autoReconnect) mqttConnectOptions.setCleanSession(cleanSession) @@ -109,7 +113,6 @@ object MQTTUtils extends Logging { }) mqttConnectOptions.setSSLProperties(sslProperties) - (brokerUrl, clientId, topic, persistence, mqttConnectOptions, qos, - maxBatchMessageNum, maxBatchMessageSize, maxRetryNumber) + (brokerUrl, clientId, topic, persistence, mqttConnectOptions, qos) } } diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala index 30ec7a60..71e1146a 100644 --- a/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala +++ b/sql-streaming-mqtt/src/main/scala/org/apache/bahir/sql/streaming/mqtt/MessageStore.scala @@ -21,7 +21,11 @@ package org.apache.bahir.sql.streaming.mqtt import java.io._ import java.util +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.Path import org.eclipse.paho.client.mqttv3.{MqttClientPersistence, MqttPersistable, MqttPersistenceException} +import org.eclipse.paho.client.mqttv3.internal.MqttPersistentData import scala.util.Try import org.apache.bahir.utils.Logging @@ -45,24 +49,16 @@ trait MessageStore { } private[mqtt] class MqttPersistableData(bytes: Array[Byte]) extends MqttPersistable { - override def getHeaderLength: Int = bytes.length - override def getHeaderOffset: Int = 0 - override def getPayloadOffset: Int = 0 - override def getPayloadBytes: Array[Byte] = null - override def getHeaderBytes: Array[Byte] = bytes - override def getPayloadLength: Int = 0 } trait Serializer { - def deserialize[T](x: Array[Byte]): T - def serialize[T](x: T): Array[Byte] } @@ -94,17 +90,14 @@ class JavaSerializer extends Serializer with Logging { null } } + } object JavaSerializer { - private lazy val instance = new JavaSerializer() - def getInstance(): JavaSerializer = instance - } - /** * A message store to persist messages received. This is not intended to be thread safe. * It uses `MqttDefaultFilePersistence` for storing messages on disk locally on the client. 
@@ -148,3 +141,113 @@ private[mqtt] class LocalMessageStore(val persistentStore: MqttClientPersistence } } + +private[mqtt] class HdfsMqttClientPersistence(config: Configuration) + extends MqttClientPersistence { + + var rootPath: Path = _ + var fileSystem: FileSystem = _ + + override def open(clientId: String, serverURI: String): Unit = { + try { + rootPath = new Path("mqtt/" + clientId + "/" + serverURI.replaceAll("[^a-zA-Z0-9]", "_")) + fileSystem = FileSystem.get(config) + if (!fileSystem.exists(rootPath)) { + fileSystem.mkdirs(rootPath) + } + } + catch { + case e: Exception => throw new MqttPersistenceException(e) + } + } + + override def close(): Unit = { + try { + fileSystem.close() + } + catch { + case e: Exception => throw new MqttPersistenceException(e) + } + } + + override def put(key: String, persistable: MqttPersistable): Unit = { + try { + val path = getPath(key) + val output = fileSystem.create(path) + output.writeInt(persistable.getHeaderLength) + if (persistable.getHeaderLength > 0) { + output.write(persistable.getHeaderBytes) + } + output.writeInt(persistable.getPayloadLength) + if (persistable.getPayloadLength > 0) { + output.write(persistable.getPayloadBytes) + } + output.close() + } + catch { + case e: Exception => throw new MqttPersistenceException(e) + } + } + + override def get(key: String): MqttPersistable = { + try { + val input = fileSystem.open(getPath(key)) + val headerLength = input.readInt() + val headerBytes: Array[Byte] = new Array[Byte](headerLength) + input.read(headerBytes) + val payloadLength = input.readInt() + val payloadBytes: Array[Byte] = new Array[Byte](payloadLength) + input.read(payloadBytes) + input.close() + new MqttPersistentData( + key, headerBytes, 0, headerBytes.length, payloadBytes, 0, payloadBytes.length + ) + } + catch { + case e: Exception => throw new MqttPersistenceException(e) + } + } + + override def remove(key: String): Unit = { + try { + fileSystem.delete(getPath(key), false) + } + catch { + case e: Exception => throw new MqttPersistenceException(e) + } + } + + override def keys(): util.Enumeration[String] = { + try { + val iterator = fileSystem.listFiles(rootPath, false) + new util.Enumeration[String]() { + override def hasMoreElements: Boolean = iterator.hasNext + override def nextElement(): String = iterator.next().getPath.getName + } + } + catch { + case e: Exception => throw new MqttPersistenceException(e) + } + } + + override def clear(): Unit = { + try { + fileSystem.delete(rootPath, true) + } + catch { + case e: Exception => throw new MqttPersistenceException(e) + } + } + + override def containsKey(key: String): Boolean = { + try { + fileSystem.isFile(getPath(key)) + } + catch { + case e: Exception => throw new MqttPersistenceException(e) + } + } + + private def getPath(key: String): Path = new Path(rootPath + "/" + key) + +} diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HDFSMQTTSourceProvider.scala b/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HDFSMQTTSourceProvider.scala deleted file mode 100644 index f38d8426..00000000 --- a/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HDFSMQTTSourceProvider.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.mqtt - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.streaming.Source -import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider} -import org.apache.spark.sql.types.StructType - -import org.apache.bahir.sql.streaming.mqtt.{MQTTStreamConstants, MQTTUtils} - -/** - * The provider class for creating MQTT source. - * This provider throw IllegalArgumentException if 'brokerUrl' or 'topic' parameter - * is not set in options. - */ -class HDFSMQTTSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging { - - override def sourceSchema(sqlContext: SQLContext, schema: Option[StructType], - providerName: String, parameters: Map[String, String]): (String, StructType) = { - ("hdfs-mqtt", MQTTStreamConstants.SCHEMA_DEFAULT) - } - - override def createSource(sqlContext: SQLContext, metadataPath: String, - schema: Option[StructType], providerName: String, parameters: Map[String, String]): Source = { - - val parsedResult = MQTTUtils.parseConfigParams(parameters) - - new HdfsBasedMQTTStreamSource( - sqlContext, - metadataPath, - parsedResult._1, // brokerUrl - parsedResult._2, // clientId - parsedResult._3, // topic - parsedResult._5, // mqttConnectionOptions - parsedResult._6, // qos - parsedResult._7, // maxBatchMessageNum - parsedResult._8, // maxBatchMessageSize - parsedResult._9 // maxRetryNum - ) - } - - override def shortName(): String = "hdfs-mqtt" -} - -object HDFSMQTTSourceProvider { - val SEP = "##" -} diff --git a/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HdfsBasedMQTTStreamSource.scala b/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HdfsBasedMQTTStreamSource.scala deleted file mode 100644 index fd395570..00000000 --- a/sql-streaming-mqtt/src/main/scala/org/apache/spark/sql/mqtt/HdfsBasedMQTTStreamSource.scala +++ /dev/null @@ -1,398 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.mqtt - -import java.io.IOException -import java.util.Calendar -import java.util.concurrent.locks.{Lock, ReentrantLock} - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, FSDataOutputStream, Path, PathFilter} -import org.eclipse.paho.client.mqttv3._ -import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.types.StructType -import org.apache.spark.unsafe.types.UTF8String - -import org.apache.bahir.sql.streaming.mqtt.{LongOffset, MQTTStreamConstants} - -/** - * A Text based mqtt stream source, it interprets the payload of each incoming message by converting - * the bytes to String using Charset.defaultCharset as charset. Each value is associated with a - * timestamp of arrival of the message on the source. It can be used to operate a window on the - * incoming stream. - * - * @param sqlContext Spark provided, SqlContext. - * @param metadataPath meta data path - * @param brokerUrl url MqttClient connects to. - * @param topic topic MqttClient subscribes to. - * @param clientId clientId, this client is assoicated with. - * Provide the same value to recover a stopped client. - * @param mqttConnectOptions an instance of MqttConnectOptions for this Source. - * @param qos the maximum quality of service to subscribe each topic at. - * Messages published at a lower quality of service will be received - * at the published QoS. Messages published at a higher quality of - * service will be received using the QoS specified on the subscribe. - * @param maxBatchNumber the max message number to process in one batch. - * @param maxBatchSize the max total size in one batch, measured in bytes number. - */ -class HdfsBasedMQTTStreamSource( - sqlContext: SQLContext, - metadataPath: String, - brokerUrl: String, - clientId: String, - topic: String, - mqttConnectOptions: MqttConnectOptions, - qos: Int, - maxBatchNumber: Long = Long.MaxValue, - maxBatchSize: Long = Long.MaxValue, - maxRetryNumber: Int = 3 -) extends Source with Logging { - - import HDFSMQTTSourceProvider.SEP - - override def schema: StructType = MQTTStreamConstants.SCHEMA_DEFAULT - - // Last batch offset file index - private var lastOffset: Long = -1L - - // Current data file index to write messages. - private var currentMessageDataFileOffset: Long = 0L - - // FileSystem instance for storing received messages. - private var fs: FileSystem = _ - private var messageStoreOutputStream: FSDataOutputStream = _ - - // total message number received for current batch. 
- private var messageNumberForCurrentBatch: Int = 0 - // total message size received for - private var messageSizeForCurrentBatch: Int = 0 - - private val minBatchesToRetain = sqlContext.sparkSession.sessionState.conf.minBatchesToRetain - - // the consecutive fail number, cannot exceed the `maxRetryNumber` - private var consecutiveFailNum = 0 - - private var client: MqttClient = _ - - private val lock: Lock = new ReentrantLock() - - private val hadoopConfig: Configuration = if (HdfsBasedMQTTStreamSource.hadoopConfig != null) { - logInfo("using setted hadoop configuration!") - HdfsBasedMQTTStreamSource.hadoopConfig - } else { - logInfo("create a new configuration.") - new Configuration() - } - - private val rootCheckpointPath = { - val path = new Path(metadataPath).getParent.getParent.toUri.toString - logInfo(s"get rootCheckpointPath $path") - path - } - - private val receivedDataPath = s"$rootCheckpointPath/receivedMessages" - - // lazily init latest offset from offset WAL log - private lazy val recoveredLatestOffset = { - // the index of this source, parsing from metadata path - val currentSourceIndex = { - if (!metadataPath.isEmpty) { - metadataPath.substring(metadataPath.lastIndexOf("/") + 1).toInt - } else { - -1 - } - } - if (currentSourceIndex >= 0) { - val offsetLog = new OffsetSeqLog(sqlContext.sparkSession, - new Path(rootCheckpointPath, "offsets").toUri.toString) - // get the latest offset from WAL log - offsetLog.getLatest() match { - case Some((batchId, _)) => - logInfo(s"get latest batch $batchId") - Some(batchId) - case None => - logInfo("no offset avaliable in offset log") - None - } - } else { - logInfo("checkpoint path is not set") - None - } - } - - initialize() - - // Change data file if reach flow control threshold for one batch. - // Not thread safe. 
- private def startWriteNewDataFile(): Unit = { - if (messageStoreOutputStream != null) { - logInfo(s"Need to write a new data file," - + s" close current data file index $currentMessageDataFileOffset") - messageStoreOutputStream.flush() - messageStoreOutputStream.hsync() - messageStoreOutputStream.close() - messageStoreOutputStream = null - } - currentMessageDataFileOffset += 1 - messageSizeForCurrentBatch = 0 - messageNumberForCurrentBatch = 0 - messageStoreOutputStream = null - } - - // not thread safe - private def addReceivedMessageInfo(messageNum: Int, messageSize: Int): Unit = { - messageSizeForCurrentBatch += messageSize - messageNumberForCurrentBatch += messageNum - } - - // not thread safe - private def hasNewMessageForCurrentBatch(): Boolean = { - currentMessageDataFileOffset > lastOffset + 1 || messageNumberForCurrentBatch > 0 - } - - private def withLock[T](body: => T): T = { - lock.lock() - try body - finally lock.unlock() - } - - private def initialize(): Unit = { - - // recover lastOffset from WAL log - if (recoveredLatestOffset.nonEmpty) { - lastOffset = recoveredLatestOffset.get - logInfo(s"Recover lastOffset value ${lastOffset}") - } - - fs = FileSystem.get(hadoopConfig) - - // recover message data file offset from hdfs - val dataPath = new Path(receivedDataPath) - if (fs.exists(dataPath)) { - val fileManager = CheckpointFileManager.create(dataPath, hadoopConfig) - val dataFileIndexs = fileManager.list(dataPath, new PathFilter { - private def isBatchFile(path: Path) = { - try { - path.getName.toLong - true - } catch { - case _: NumberFormatException => false - } - } - - override def accept(path: Path): Boolean = isBatchFile(path) - }).map(_.getPath.getName.toLong) - if (dataFileIndexs.nonEmpty) { - currentMessageDataFileOffset = dataFileIndexs.max + 1 - assert(currentMessageDataFileOffset >= lastOffset + 1, - s"Recovered invalid message data file offset $currentMessageDataFileOffset," - + s"do not match with lastOffset $lastOffset") - logInfo(s"Recovered last message data file offset: ${currentMessageDataFileOffset - 1}, " - + s"start from $currentMessageDataFileOffset") - } else { - logInfo("No old data file exist, start data file index from 0") - currentMessageDataFileOffset = 0 - } - } else { - logInfo(s"Create data dir $receivedDataPath, start data file index from 0") - fs.mkdirs(dataPath) - currentMessageDataFileOffset = 0 - } - - client = new MqttClient(brokerUrl, clientId, new MemoryPersistence()) - - val callback = new MqttCallbackExtended() { - - override def messageArrived(topic: String, message: MqttMessage): Unit = { - withLock[Unit] { - val messageSize = message.getPayload.size - // check if have reached the max number or max size for current batch. 
- if (messageNumberForCurrentBatch + 1 > maxBatchNumber - || messageSizeForCurrentBatch + messageSize > maxBatchSize) { - startWriteNewDataFile() - } - // write message content to data file - if (messageStoreOutputStream == null) { - val path = new Path(s"${receivedDataPath}/${currentMessageDataFileOffset}") - if (fs.createNewFile(path)) { - logInfo(s"Create new message data file ${path.toUri.toString} success!") - } else { - throw new IOException(s"${path.toUri.toString} already exist," - + s"make sure do use unique checkpoint path for each app.") - } - messageStoreOutputStream = fs.append(path) - } - - messageStoreOutputStream.writeBytes(s"${message.getId}${SEP}") - messageStoreOutputStream.writeBytes(s"${topic}${SEP}") - val timestamp = Calendar.getInstance().getTimeInMillis().toString - messageStoreOutputStream.writeBytes(s"${timestamp}${SEP}") - messageStoreOutputStream.write(message.getPayload()) - messageStoreOutputStream.writeBytes("\n") - addReceivedMessageInfo(1, messageSize) - consecutiveFailNum = 0 - logInfo(s"Message arrived, topic: $topic, message payload $message, " - + s"messageId: ${message.getId}, message size: ${messageSize}") - } - } - - override def deliveryComplete(token: IMqttDeliveryToken): Unit = { - // callback for publisher, no need here. - } - - override def connectionLost(cause: Throwable): Unit = { - // auto reconnection is enabled, so just add a log here. - withLock[Unit] { - consecutiveFailNum += 1 - logWarning(s"Connection to mqtt server lost, " - + s"consecutive fail number $consecutiveFailNum", cause) - } - } - - override def connectComplete(reconnect: Boolean, serverURI: String): Unit = { - logInfo(s"Connect complete $serverURI. Is it a reconnect?: $reconnect") - } - } - client.setCallback(callback) - client.connect(mqttConnectOptions) - client.subscribe(topic, qos) - } - - /** Stop this source and free any resources it has allocated. */ - override def stop(): Unit = { - logInfo("Stop mqtt source.") - client.disconnect() - client.close() - withLock[Unit] { - if (messageStoreOutputStream != null) { - messageStoreOutputStream.hflush() - messageStoreOutputStream.hsync() - messageStoreOutputStream.close() - messageStoreOutputStream = null - } - fs.close() - } - } - - /** Returns the maximum available offset for this source. */ - override def getOffset: Option[Offset] = { - withLock[Option[Offset]] { - assert(consecutiveFailNum < maxRetryNumber, - s"Write message data fail continuously for ${maxRetryNumber} times.") - val result = if (!hasNewMessageForCurrentBatch()) { - if (lastOffset == -1) { - // first submit and no message has arrived. - None - } else { - // no message has arrived for this batch. - Some(LongOffset(lastOffset)) - } - } else { - // check if currently write the batch to be executed. - if (currentMessageDataFileOffset == lastOffset + 1) { - startWriteNewDataFile() - } - lastOffset += 1 - Some(LongOffset(lastOffset)) - } - logInfo(s"getOffset result $result") - result - } - } - - /** - * Returns the data that is between the offsets (`start`, `end`]. - * The batch return the data in file ${checkpointPath}/receivedMessages/${end}. - * `Start` and `end` value have the relationship: `end value` = `start valud` + 1, - * if `start` is not None. 
- */ - override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - withLock[Unit]{ - assert(consecutiveFailNum < maxRetryNumber, - s"Write message data fail continuously for ${maxRetryNumber} times.") - } - logInfo(s"getBatch with start = $start, end = $end") - val endIndex = getOffsetValue(end) - if (start.nonEmpty) { - val startIndex = getOffsetValue(start.get) - assert(startIndex + 1 == endIndex, - s"start offset: ${startIndex} and end offset: ${endIndex} do not match") - } - logTrace(s"Create a data frame using hdfs file $receivedDataPath/$endIndex") - val rdd = sqlContext.sparkContext.textFile(s"$receivedDataPath/$endIndex") - .map{case str => - // calculate message in - val idIndex = str.indexOf(SEP) - val messageId = str.substring(0, idIndex).toInt - // get topic - var subStr = str.substring(idIndex + SEP.length) - val topicIndex = subStr.indexOf(SEP) - val topic = UTF8String.fromString(subStr.substring(0, topicIndex)) - // get timestamp - subStr = subStr.substring(topicIndex + SEP.length) - val timestampIndex = subStr.indexOf(SEP) - /* - val timestamp = Timestamp.valueOf( - MQTTStreamConstants.DATE_FORMAT.format(subStr.substring(0, timestampIndex).toLong)) - */ - val timestamp = subStr.substring(0, timestampIndex).toLong - // get playload - subStr = subStr.substring(timestampIndex + SEP.length) - val payload = UTF8String.fromString(subStr).getBytes - InternalRow(messageId, topic, payload, timestamp) - } - sqlContext.internalCreateDataFrame(rdd, MQTTStreamConstants.SCHEMA_DEFAULT, true) - } - - /** - * Remove the data file for the offset. - * - * @param end the end of offset that all data has been committed. - */ - override def commit(end: Offset): Unit = { - val offsetValue = getOffsetValue(end) - if (offsetValue >= minBatchesToRetain) { - val deleteDataFileOffset = offsetValue - minBatchesToRetain - try { - fs.delete(new Path(s"$receivedDataPath/$deleteDataFileOffset"), false) - logInfo(s"Delete committed offset data file $deleteDataFileOffset success!") - } catch { - case e: Exception => - logWarning(s"Delete committed offset data file $deleteDataFileOffset failed. ", e) - } - } - } - - private def getOffsetValue(offset: Offset): Long = { - val offsetValue = offset match { - case o: LongOffset => o.offset - case so: SerializedOffset => - so.json.toLong - } - offsetValue - } -} -object HdfsBasedMQTTStreamSource { - - var hadoopConfig: Configuration = _ -} diff --git a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSBasedMQTTStreamSourceSuite.scala b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSBasedMQTTStreamSourceSuite.scala deleted file mode 100644 index 777db161..00000000 --- a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSBasedMQTTStreamSourceSuite.scala +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.bahir.sql.streaming.mqtt - -import java.io.File - -import scala.collection.JavaConverters._ -import scala.collection.mutable - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hdfs.MiniDFSCluster -import org.apache.hadoop.security.Groups -import org.eclipse.paho.client.mqttv3.MqttException -import org.scalatest.BeforeAndAfter - -import org.apache.spark.{SharedSparkContext, SparkFunSuite} -import org.apache.spark.sql._ -import org.apache.spark.sql.mqtt.{HdfsBasedMQTTStreamSource, HDFSMQTTSourceProvider} -import org.apache.spark.sql.streaming.{DataStreamReader, StreamingQuery} - -import org.apache.bahir.utils.FileHelper - -class HDFSBasedMQTTStreamSourceSuite - extends SparkFunSuite - with SharedSparkContext - with BeforeAndAfter { - - protected var mqttTestUtils: MQTTTestUtils = _ - protected val tempDir: File = new File(System.getProperty("java.io.tmpdir") + "/mqtt-test/") - protected var hadoop: MiniDFSCluster = _ - - before { - tempDir.mkdirs() - if (!tempDir.exists()) { - throw new IllegalStateException("Unable to create temp directories.") - } - tempDir.deleteOnExit() - mqttTestUtils = new MQTTTestUtils(tempDir) - mqttTestUtils.setup() - hadoop = HDFSTestUtils.prepareHadoop() - } - - after { - mqttTestUtils.teardown() - HDFSTestUtils.shutdownHadoop() - FileHelper.deleteFileQuietly(tempDir) - } - - protected val tmpDir: String = tempDir.getAbsolutePath - - protected def writeStreamResults(sqlContext: SQLContext, dataFrame: DataFrame): StreamingQuery = { - import sqlContext.implicits._ - val query: StreamingQuery = dataFrame.selectExpr("CAST(payload AS STRING)").as[String] - .writeStream.format("csv").start(s"$tempDir/t.csv") - while (!query.status.isTriggerActive) { - Thread.sleep(20) - } - query - } - - protected def readBackStreamingResults(sqlContext: SQLContext): mutable.Buffer[String] = { - import sqlContext.implicits._ - val asList = - sqlContext.read - .csv(s"$tmpDir/t.csv").as[String] - .collectAsList().asScala - asList - } - - protected def createStreamingDataFrame(dir: String = tmpDir): (SQLContext, DataFrame) = { - - val sqlContext: SQLContext = SparkSession.builder() - .getOrCreate().sqlContext - - sqlContext.setConf("spark.sql.streaming.checkpointLocation", - s"hdfs://localhost:${hadoop.getNameNodePort}/testCheckpoint") - - val ds: DataStreamReader = - sqlContext.readStream.format("org.apache.spark.sql.mqtt.HDFSMQTTSourceProvider") - .option("topic", "test").option("clientId", "clientId").option("connectionTimeout", "120") - .option("keepAlive", "1200").option("autoReconnect", "false") - .option("cleanSession", "true").option("QoS", "2") - val dataFrame = ds.load("tcp://" + mqttTestUtils.brokerUri) - (sqlContext, dataFrame) - } -} - -object HDFSTestUtils { - - private var hadoop: MiniDFSCluster = _ - - def prepareHadoop(): MiniDFSCluster = { - if (hadoop != null) { - hadoop - } else { - val baseDir = new File(System.getProperty("java.io.tmpdir") + "/hadoop").getAbsoluteFile - System.setProperty("HADOOP_USER_NAME", "test") - val conf = new Configuration - conf.set(MiniDFSCluster.HDFS_MINIDFS_BASEDIR, baseDir.getAbsolutePath) - conf.setBoolean("dfs.namenode.acls.enabled", true) - conf.setBoolean("dfs.permissions", true) - Groups.getUserToGroupsMappingService(conf) - val builder = new MiniDFSCluster.Builder(conf) - hadoop = builder.build - conf.set("fs.defaultFS", "hdfs://localhost:" + hadoop.getNameNodePort + "/") - 
HdfsBasedMQTTStreamSource.hadoopConfig = conf - hadoop - } - } - - def shutdownHadoop(): Unit = { - if (null != hadoop) { - hadoop.shutdown(true) - } - hadoop = null - } -} - -class BasicHDFSBasedMQTTSourceSuite extends HDFSBasedMQTTStreamSourceSuite { - - test("basic usage") { - - val sendMessage = "MQTT is a message queue." - - val (sqlContext: SQLContext, dataFrame: DataFrame) = createStreamingDataFrame() - - val query = writeStreamResults(sqlContext, dataFrame) - mqttTestUtils.publishData("test", sendMessage) - query.processAllAvailable() - query.awaitTermination(10000) - - val resultBuffer: mutable.Buffer[String] = readBackStreamingResults(sqlContext) - - assert(resultBuffer.size == 1) - assert(resultBuffer.head == sendMessage) - } - - test("Send and receive 50 messages.") { - - val sendMessage = "MQTT is a message queue." - - val (sqlContext: SQLContext, dataFrame: DataFrame) = createStreamingDataFrame() - - val q = writeStreamResults(sqlContext, dataFrame) - - mqttTestUtils.publishData("test", sendMessage, 50) - q.processAllAvailable() - q.awaitTermination(10000) - - val resultBuffer: mutable.Buffer[String] = readBackStreamingResults(sqlContext) - - assert(resultBuffer.size == 50) - assert(resultBuffer.head == sendMessage) - } - - test("no server up") { - val provider = new HDFSMQTTSourceProvider - val sqlContext: SQLContext = SparkSession.builder().getOrCreate().sqlContext - intercept[MqttException] { - provider.createSource( - sqlContext, - s"hdfs://localhost:${hadoop.getNameNodePort}/testCheckpoint/0", - Some(MQTTStreamConstants.SCHEMA_DEFAULT), - "org.apache.spark.sql.mqtt.HDFSMQTTSourceProvider", - Map("brokerUrl" -> "tcp://localhost:1881", "topic" -> "test") - ) - } - } - - test("params not provided.") { - val provider = new HDFSMQTTSourceProvider - val sqlContext: SQLContext = SparkSession.builder().getOrCreate().sqlContext - intercept[IllegalArgumentException] { - provider.createSource( - sqlContext, - s"hdfs://localhost:${hadoop.getNameNodePort}/testCheckpoint/0", - Some(MQTTStreamConstants.SCHEMA_DEFAULT), - "org.apache.spark.sql.mqtt.HDFSMQTTSourceProvider", - Map() - ) - } - } -} diff --git a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSMessageStoreSuite.scala b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSMessageStoreSuite.scala new file mode 100644 index 00000000..4c24a337 --- /dev/null +++ b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSMessageStoreSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.bahir.sql.streaming.mqtt + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hdfs.MiniDFSCluster +import org.eclipse.paho.client.mqttv3.MqttClientPersistence +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SharedSparkContext, SparkFunSuite} + + +class HDFSMessageStoreSuite extends SparkFunSuite with SharedSparkContext with BeforeAndAfter { + + private val testData = Seq(1, 2, 3, 4, 5, 6) + private val javaSerializer: JavaSerializer = new JavaSerializer() + private val serializerInstance = javaSerializer + private var config: Configuration = _ + private var hadoop: MiniDFSCluster = _ + private var persistence: MqttClientPersistence = _ + private var store: LocalMessageStore = _ + + override def beforeAll() { + val (hadoopConfig, hadoopInstance) = HDFSTestUtils.prepareHadoop() + config = hadoopConfig + hadoop = hadoopInstance + persistence = new HdfsMqttClientPersistence(config) + persistence.open("temp", "tcp://dummy-url:0000") + store = new LocalMessageStore(persistence, javaSerializer) + } + + override def afterAll() { + store = null + persistence.clear() + persistence.close() + persistence = null + if (hadoop != null) { + hadoop.shutdown(true) + } + hadoop = null + config = null + } + + test("serialize and deserialize") { + val serialized = serializerInstance.serialize(testData) + val deserialized: Seq[Int] = serializerInstance + .deserialize(serialized).asInstanceOf[Seq[Int]] + assert(testData === deserialized) + } + + test("Store, retrieve and remove") { + store.store(1, testData) + var result: Seq[Int] = store.retrieve(1) + assert(testData === result) + store.remove(1) + } + + test("Max offset stored") { + store.store(1, testData) + store.store(10, testData) + val offset = store.maxProcessedOffset + assert(offset == 10) + } + +} diff --git a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSTestUtils.scala b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSTestUtils.scala new file mode 100644 index 00000000..33376dff --- /dev/null +++ b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/HDFSTestUtils.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.bahir.sql.streaming.mqtt + +import java.io.File + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hdfs.MiniDFSCluster +import org.apache.hadoop.security.Groups + +object HDFSTestUtils { + def prepareHadoop(): (Configuration, MiniDFSCluster) = { + val baseDir = new File(System.getProperty("java.io.tmpdir") + "/hadoop").getAbsoluteFile + System.setProperty("HADOOP_USER_NAME", "test") + val config = new Configuration + config.set(MiniDFSCluster.HDFS_MINIDFS_BASEDIR, baseDir.getAbsolutePath) + config.set("dfs.namenode.acls.enabled", "true") + config.set("dfs.permissions", "true") + Groups.getUserToGroupsMappingService(config) + val builder = new MiniDFSCluster.Builder(config) + val hadoop = builder.build + config.set("fs.defaultFS", "hdfs://localhost:" + hadoop.getNameNodePort + "/") + (config, hadoop) + } +} diff --git a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala index 0b6b80b6..9656a84b 100644 --- a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala +++ b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/LocalMessageStoreSuite.scala @@ -33,7 +33,7 @@ class LocalMessageStoreSuite extends SparkFunSuite with BeforeAndAfter { private val javaSerializer: JavaSerializer = new JavaSerializer() private val serializerInstance = javaSerializer - private val tempDir: File = new File(System.getProperty("java.io.tmpdir") + "/mqtt-test2/") + private val tempDir: File = new File(System.getProperty("java.io.tmpdir") + "/mqtt-test-local/") private val persistence: MqttDefaultFilePersistence = new MqttDefaultFilePersistence(tempDir.getAbsolutePath) @@ -52,7 +52,7 @@ class LocalMessageStoreSuite extends SparkFunSuite with BeforeAndAfter { } test("serialize and deserialize") { - val serialized = serializerInstance.serialize(testData) + val serialized = serializerInstance.serialize(testData) val deserialized: Seq[Int] = serializerInstance .deserialize(serialized).asInstanceOf[Seq[Int]] assert(testData === deserialized) diff --git a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala index 39cf0df2..0de05604 100644 --- a/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala +++ b/sql-streaming-mqtt/src/test/scala/org/apache/bahir/sql/streaming/mqtt/MQTTStreamSourceSuite.scala @@ -23,6 +23,8 @@ import java.util.Optional import scala.collection.JavaConverters._ import scala.collection.mutable +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hdfs.MiniDFSCluster import org.eclipse.paho.client.mqttv3.MqttConnectOptions import org.eclipse.paho.client.mqttv3.MqttException import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence @@ -81,8 +83,9 @@ class MQTTStreamSourceSuite extends SparkFunSuite asList } - protected def createStreamingDataFrame(dir: String = tmpDir, - filePersistence: Boolean = false): (SQLContext, DataFrame) = { + protected def createStreamingDataFrame(persistenceMode: String = "memory", + persistenceConfig: Map[String, String] = Map()) + : (SQLContext, DataFrame) = { val sqlContext: SQLContext = SparkSession.builder() .getOrCreate().sqlContext @@ -95,12 +98,8 @@ class MQTTStreamSourceSuite extends SparkFunSuite 
.option("keepAlive", "1200").option("maxInflight", "120").option("autoReconnect", "false") .option("cleanSession", "true").option("QoS", "2") - val dataFrame = if (!filePersistence) { - ds.option("persistence", "memory").load("tcp://" + mqttTestUtils.brokerUri) - } else { - ds.option("persistence", "file").option("localStorage", tmpDir) - .load("tcp://" + mqttTestUtils.brokerUri) - } + val dataFrame = ds.option("persistence", persistenceMode).options(persistenceConfig) + .load("tcp://" + mqttTestUtils.brokerUri) (sqlContext, dataFrame) } @@ -193,6 +192,72 @@ class BasicMQTTSourceSuite extends MQTTStreamSourceSuite { } +class FileMQTTSourceSuite extends MQTTStreamSourceSuite { + + test("Send and receive 50 messages.") { + val sendMessage = "MQTT is a message queue." + val (sqlContext: SQLContext, dataFrame: DataFrame) = + createStreamingDataFrame("file", Map("localStorage" -> tmpDir)) + val q = writeStreamResults(sqlContext, dataFrame) + + mqttTestUtils.publishData("test", sendMessage, 50) + q.processAllAvailable() + q.awaitTermination(10000) + + val resultBuffer: mutable.Buffer[String] = readBackStreamingResults(sqlContext) + + assert(resultBuffer.size == 50) + assert(resultBuffer.head == sendMessage) + } + +} + +class HDFSMQTTSourceSuite extends MQTTStreamSourceSuite { + + private var hadoopConfig: Configuration = _ + private var hadoopInstance: MiniDFSCluster = _ + + override def beforeAll(): Unit = { + super.beforeAll + val (config, instance) = HDFSTestUtils.prepareHadoop() + hadoopConfig = config + hadoopInstance = instance + } + + override def afterAll(): Unit = { + super.afterAll + if (hadoopInstance != null) { + hadoopInstance.shutdown(true) + } + hadoopInstance = null + hadoopConfig = null + } + + test("Send and receive 50 messages.") { + val sendMessage = "MQTT is a message queue." + val (sqlContext: SQLContext, dataFrame: DataFrame) = + createStreamingDataFrame( + "hdfs", Map( + "hdfs.hdfs.minidfs.basedir" -> hadoopConfig.get("hdfs.minidfs.basedir"), + "hdfs.fs.defaultFS" -> hadoopConfig.get("fs.defaultFS"), + "hdfs.dfs.namenode.acls.enabled" -> hadoopConfig.get("dfs.namenode.acls.enabled"), + "hdfs.dfs.permissions" -> hadoopConfig.get("dfs.permissions") + ) + ) + val q = writeStreamResults(sqlContext, dataFrame) + + mqttTestUtils.publishData("test", sendMessage, 50) + q.processAllAvailable() + q.awaitTermination(10000) + + val resultBuffer: mutable.Buffer[String] = readBackStreamingResults(sqlContext) + + assert(resultBuffer.size == 50) + assert(resultBuffer.head == sendMessage) + } + +} + class StressTestMQTTSource extends MQTTStreamSourceSuite { // Run with -Xmx1024m