diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index a647c9ec734df..00a1ec6f31fec 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -184,7 +184,7 @@ def _check_serialzers(self, rdds):
                 # reset them to sc.serializer
                 rdds[i] = rdds[i].map(lambda x: x, preservesPartitioning=True)
 
-    def queueStream(self, queue, oneAtATime=False, default=None):
+    def queueStream(self, queue, oneAtATime=True, default=None):
         """
         Create an input stream from an queue of RDDs or list. In each batch,
         it will process either one or all of the RDDs returned by the queue.
@@ -200,9 +200,12 @@ def queueStream(self, queue, oneAtATime=False, default=None):
         self._check_serialzers(rdds)
         jrdds = ListConverter().convert([r._jrdd for r in rdds],
                                         SparkContext._gateway._gateway_client)
-        jdstream = self._jvm.PythonDataInputStream(self._jssc, jrdds, oneAtATime,
-                                                   default and default._jrdd)
-        return DStream(jdstream.asJavaDStream(), self, rdds[0]._jrdd_deserializer)
+        queue = self._jvm.PythonDStream.toRDDQueue(jrdds)
+        if default:
+            jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd)
+        else:
+            jdstream = self._jssc.queueStream(queue, oneAtATime)
+        return DStream(jdstream, self, rdds[0]._jrdd_deserializer)
 
     def transform(self, dstreams, transformFunc):
         """
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
index c0a1aa71840a5..d7dd0a0c5c88b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -18,6 +18,7 @@ package org.apache.spark.streaming.api.python
 
 import java.util.{ArrayList => JArrayList}
+import scala.collection.JavaConversions._
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.api.java._
@@ -65,6 +66,16 @@ abstract class PythonDStream(parent: DStream[_]) extends DStream[Array[Byte]] (p
   val asJavaDStream = JavaDStream.fromDStream(this)
 }
 
+object PythonDStream {
+
+  // convert list of RDD into queue of RDDs, for ssc.queueStream()
+  def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = {
+    val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]]
+    rdds.forall(queue.add(_))
+    queue
+  }
+}
+
 /**
  * Transformed DStream in Python.
  *
@@ -243,46 +254,4 @@ class PythonForeachDStream(
   ) {
 
   this.register()
-}
-
-
-/**
- * similar to QueueInputStream
- */
-class PythonDataInputStream(
-    ssc_ : JavaStreamingContext,
-    inputRDDs: JArrayList[JavaRDD[Array[Byte]]],
-    oneAtAtime: Boolean,
-    defaultRDD: JavaRDD[Array[Byte]]
-  ) extends InputDStream[Array[Byte]](JavaStreamingContext.toStreamingContext(ssc_)) {
-
-  val emptyRDD = if (defaultRDD != null) {
-    Some(defaultRDD.rdd)
-  } else {
-    Some(ssc.sparkContext.emptyRDD[Array[Byte]])
-  }
-
-  def start() {}
-
-  def stop() {}
-
-  def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
-    val index = ((validTime - zeroTime) / slideDuration - 1).toInt
-    if (oneAtAtime) {
-      if (index == 0) {
-        val rdds = inputRDDs.toArray.map(_.asInstanceOf[JavaRDD[Array[Byte]]].rdd).toSeq
-        Some(ssc.sparkContext.union(rdds))
-      } else {
-        emptyRDD
-      }
-    } else {
-      if (index < inputRDDs.size()) {
-        Some(inputRDDs.get(index).rdd)
-      } else {
-        emptyRDD
-      }
-    }
-  }
-
-  val asJavaDStream = JavaDStream.fromDStream(this)
-}
+}
\ No newline at end of file
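Note on the change: instead of the hand-rolled PythonDataInputStream, the Python queueStream now converts its RDD list into a java.util.Queue via the new PythonDStream.toRDDQueue helper and delegates to the JVM's existing StreamingContext.queueStream, and the oneAtATime default flips to True to match the Scala API. A minimal sketch of how the patched API would be driven from PySpark follows; the app name, batch interval, sample data, and the pprint() output operator are illustrative assumptions, not part of this patch:

    import time

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    sc = SparkContext("local[2]", "queueStreamDemo")
    ssc = StreamingContext(sc, 1)  # 1-second batch interval (illustrative)

    # Each RDD in the list becomes one entry in the queue. With the new
    # default oneAtATime=True, every batch dequeues exactly one RDD; with
    # oneAtATime=False, a batch unions everything currently queued and
    # falls back to `default` once the queue is empty.
    rdd_queue = [sc.parallelize(range(i * 10, (i + 1) * 10)) for i in range(3)]
    stream = ssc.queueStream(rdd_queue, oneAtATime=True)
    stream.pprint()  # assumes an output operator is available on this branch

    ssc.start()
    time.sleep(6)  # let a few 1-second batches run
    ssc.stop()

Delegating to the JVM queueStream also means the Python side inherits the standard QueueInputDStream semantics (one RDD per batch, or a union of the queued RDDs with an optional default) rather than the index-based replay implemented in the removed compute() above.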