From f671cdb57475cac5a0418898c42a02df91c83ed5 Mon Sep 17 00:00:00 2001 From: giwa Date: Tue, 5 Aug 2014 00:09:38 -0700 Subject: [PATCH] WIP: added PythonTestInputStream --- .../main/python/streaming/test_oprations.py | 14 +++-------- python/pyspark/streaming/context.py | 25 +++++++++++++++++++ python/pyspark/streaming/dstream.py | 1 + .../api/java/JavaStreamingContext.scala | 3 +++ .../streaming/api/python/PythonDStream.scala | 13 ++++++---- 5 files changed, 41 insertions(+), 15 deletions(-) diff --git a/examples/src/main/python/streaming/test_oprations.py b/examples/src/main/python/streaming/test_oprations.py index 084902b6a2f0d..3338a766b9cc3 100644 --- a/examples/src/main/python/streaming/test_oprations.py +++ b/examples/src/main/python/streaming/test_oprations.py @@ -6,20 +6,14 @@ from pyspark.streaming.duration import * if __name__ == "__main__": - if len(sys.argv) != 3: - print >> sys.stderr, "Usage: wordcount " - exit(-1) conf = SparkConf() conf.setAppName("PythonStreamingNetworkWordCount") ssc = StreamingContext(conf=conf, duration=Seconds(1)) - lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2])) - words = lines.flatMap(lambda line: line.split(" ")) -# ssc.checkpoint("checkpoint") - mapped_words = words.map(lambda word: (word, 1)) - count = mapped_words.reduceByKey(add) + test_input = ssc._testInputStream([1,1,1,1]) + mapped = test_input.map(lambda x: (x, 1)) + mapped.pyprint() - count.pyprint() ssc.start() - ssc.awaitTermination() +# ssc.awaitTermination() # ssc.stop() diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index dfaa5cfbbae27..d544eab9b8fc7 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -17,6 +17,7 @@ import sys from signal import signal, SIGTERM, SIGINT +from tempfile import NamedTemporaryFile from pyspark.conf import SparkConf from pyspark.files import SparkFiles @@ -138,3 +139,27 @@ def checkpoint(self, directory): """ """ self._jssc.checkpoint(directory) + + def _testInputStream(self, test_input, numSlices=None): + + numSlices = numSlices or self._sc.defaultParallelism + # Calling the Java parallelize() method with an ArrayList is too slow, + # because it sends O(n) Py4J commands. As an alternative, serialized + # objects are written to a file and loaded through textFile(). 
+        tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
+        # Make sure we distribute data evenly if it's smaller than self.batchSize
+        if "__len__" not in dir(test_input):
+            test_input = list(test_input)    # Make it a list so we can compute its length
+        batchSize = min(len(test_input) // numSlices, self._sc._batchSize)
+        if batchSize > 1:
+            serializer = BatchedSerializer(self._sc._unbatched_serializer,
+                                           batchSize)
+        else:
+            serializer = self._sc._unbatched_serializer
+        serializer.dump_stream(test_input, tempFile)
+        tempFile.close()
+        print tempFile.name
+        jinput_stream = self._jvm.PythonTestInputStream(self._jssc,
+                                                        tempFile.name,
+                                                        numSlices).asJavaDStream()
+        return DStream(jinput_stream, self, UTF8Deserializer())
diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index 3026254f8fab6..77c9a22239c69 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -141,6 +141,7 @@ def _mergeCombiners(iterator):
                     combiners[k] = v
                 else:
                     combiners[k] = mergeCombiners(combiners[k], v)
+            return combiners.iteritems()
         return shuffled._mapPartitions(_mergeCombiners)

diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index 18605cac7006c..b51d5ff0be9fc 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -546,6 +546,9 @@ class JavaStreamingContext(val ssc: StreamingContext) {
  * JavaStreamingContext object contains a number of utility functions.
  */
 object JavaStreamingContext {
+  implicit def fromStreamingContext(ssc: StreamingContext): JavaStreamingContext = new JavaStreamingContext(ssc)
+
+  implicit def toStreamingContext(jssc: JavaStreamingContext): StreamingContext = jssc.ssc
 
   /**
    * Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
index 861def33671f1..96440b15d0285 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -23,6 +23,7 @@ import scala.reflect.ClassTag
 
 import org.apache.spark._
 import org.apache.spark.rdd.RDD
+import org.apache.spark.api.java._
 import org.apache.spark.api.python._
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.streaming.{StreamingContext, Duration, Time}
@@ -130,10 +131,10 @@ class PythonTransformedDStream(
 
 /**
  * This is a input stream just for the unitest. This is equivalent to a checkpointable,
  * replayable, reliable message queue like Kafka. It requires a sequence as input, and
- * returns the i_th element at the i_th batch unde manual clock.
+ * returns the i_th element at the i_th batch under manual clock.
 */
-class PythonTestInputStream(ssc_ : StreamingContext, filename: String, numPartitions: Int)
-  extends InputDStream[Array[Byte]](ssc_) {
+class PythonTestInputStream(ssc_ : JavaStreamingContext, filename: String, numPartitions: Int)
+  extends InputDStream[Array[Byte]](JavaStreamingContext.toStreamingContext(ssc_)) {
 
   def start() {}
 
@@ -141,7 +142,7 @@ class PythonTestInputStream(ssc_ : StreamingContext, filename: String, numPartit
 
   def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
     logInfo("Computing RDD for time " + validTime)
-    val index = ((validTime - zeroTime) / slideDuration - 1).toInt
+    //val index = ((validTime - zeroTime) / slideDuration - 1).toInt
     //val selectedInput = if (index < input.size) input(index) else Seq[T]()
 
     // lets us test cases where RDDs are not created
@@ -149,8 +150,10 @@ class PythonTestInputStream(ssc_ : StreamingContext, filename: String, numPartit
     //  return None
 
     //val rdd = ssc.sc.makeRDD(selectedInput, numPartitions)
-    val rdd = PythonRDD.readRDDFromFile(ssc.sc, filename, numPartitions).rdd
+    val rdd = PythonRDD.readRDDFromFile(JavaSparkContext.fromSparkContext(ssc_.sparkContext), filename, numPartitions).rdd
     logInfo("Created RDD " + rdd.id + " with " + filename)
     Some(rdd)
   }
+
+  val asJavaDStream = JavaDStream.fromDStream(this)
 }
\ No newline at end of file
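
The _testInputStream hunk in context.py sidesteps Py4J by pickling the whole test input into a NamedTemporaryFile and letting the JVM read the file back. Below is a minimal sketch of that hand-off, assuming the context's unbatched serializer is a plain PickleSerializer; the names fake_input/tmp and the batch size of 2 are illustrative, not from the patch.

    from tempfile import NamedTemporaryFile

    from pyspark.serializers import BatchedSerializer, PickleSerializer

    fake_input = [1, 1, 1, 1]

    # Driver side: dump the test data into a local file instead of pushing
    # each element through Py4J one command at a time.
    serializer = BatchedSerializer(PickleSerializer(), 2)
    tmp = NamedTemporaryFile(delete=False)
    serializer.dump_stream(fake_input, tmp)
    tmp.close()

    # The JVM side (PythonTestInputStream + PythonRDD.readRDDFromFile) re-reads
    # this file as an RDD of serialized records; here we just deserialize it
    # locally to show the round trip is lossless.
    with open(tmp.name, "rb") as f:
        print(list(serializer.load_stream(f)))  # [1, 1, 1, 1]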
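The one-line dstream.py change adds the return that combineByKey's _mergeCombiners closure was missing: a function passed to mapPartitions has to hand back an iterable of results, otherwise every partition silently produces None. A pure-Python stand-in, no Spark required; merge_combiners and the sample partition are made up for illustration.

    def merge_combiners(partition, merge=lambda a, b: a + b):
        combiners = {}
        for k, v in partition:
            if k not in combiners:
                combiners[k] = v
            else:
                combiners[k] = merge(combiners[k], v)
        # The line the patch adds: without returning the merged pairs,
        # mapPartitions would get None back from every partition.
        return iter(combiners.items())   # iteritems() on the Python 2 of the patch

    print(sorted(merge_combiners([("a", 1), ("b", 1), ("a", 1)])))
    # [('a', 2), ('b', 1)]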
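PythonTestInputStream itself is meant to be a replayable source: under the manual clock, compute(validTime) currently ignores the batch index (the index arithmetic is commented out) and simply re-reads the same serialized file every interval. A rough Python analogue of that behaviour follows; ReplayableTestSource is a hypothetical stand-in, not part of the patch.

    import pickle
    from tempfile import NamedTemporaryFile


    class ReplayableTestSource(object):   # hypothetical stand-in
        """Re-reads the same serialized file for every batch, the way
        PythonTestInputStream.compute(validTime) does under a manual clock."""

        def __init__(self, filename):
            self.filename = filename

        def compute(self, batch_index):
            # batch_index is ignored for now, mirroring the commented-out
            # index arithmetic in the Scala compute() above.
            with open(self.filename, "rb") as f:
                return pickle.load(f)


    tmp = NamedTemporaryFile(delete=False)
    pickle.dump([1, 1, 1, 1], tmp)
    tmp.close()

    source = ReplayableTestSource(tmp.name)
    print(source.compute(0))  # [1, 1, 1, 1]
    print(source.compute(1))  # the same data, replayed for the next batch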