From 7f96294ecd90b489b5a9ee88abe99a7715cbb41f Mon Sep 17 00:00:00 2001
From: giwa
Date: Mon, 11 Aug 2014 03:21:22 -0700
Subject: [PATCH] added basic operation test cases

---
 .../main/python/streaming/test_oprations.py  | 19 ++--
 python/pyspark/streaming/context.py          | 43 +++++----
 python/pyspark/streaming/dstream.py          |  8 +-
 python/pyspark/streaming_tests.py            | 95 +++++++++++++++----
 .../streaming/api/python/PythonDStream.scala |  2 -
 5 files changed, 113 insertions(+), 54 deletions(-)

diff --git a/examples/src/main/python/streaming/test_oprations.py b/examples/src/main/python/streaming/test_oprations.py
index 24ebe23d63166..70a62058286e9 100644
--- a/examples/src/main/python/streaming/test_oprations.py
+++ b/examples/src/main/python/streaming/test_oprations.py
@@ -9,22 +9,23 @@
     conf = SparkConf()
     conf.setAppName("PythonStreamingNetworkWordCount")
     ssc = StreamingContext(conf=conf, duration=Seconds(1))
-
-    test_input = ssc._testInputStream([1,2,3])
-    class buff:
+    class Buff:
+        result = list()
         pass
+    Buff.result = list()
+
+    test_input = ssc._testInputStream([range(1,4), range(4,7), range(7,10)])
 
     fm_test = test_input.map(lambda x: (x, 1))
-    fm_test.test_output(buff)
+    fm_test.pyprint()
+    fm_test._test_output(Buff.result)
 
     ssc.start()
     while True:
         ssc.awaitTermination(50)
-        try:
-            buff.result
+        if len(Buff.result) == 3:
             break
-        except AttributeError:
-            pass
 
     ssc.stop()
-    print buff.result
+    print Buff.result
+
diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index 0d7665d645be8..be142fd4f327b 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -123,14 +123,14 @@ def textFileStream(self, directory):
         """
         return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())
 
-    def stop(self, stopSparkContext=True):
+    def stop(self, stopSparkContext=True, stopGraceFully=False):
         """
         Stop the execution of the streams immediately (does not wait for all received data
         to be processed).
         """
 
         try:
-            self._jssc.stop(stopSparkContext)
+            self._jssc.stop(stopSparkContext, stopGraceFully)
         finally:
             # Stop the callback server
             SparkContext._gateway.shutdown()
@@ -141,27 +141,34 @@ def checkpoint(self, directory):
         """
         self._jssc.checkpoint(directory)
 
-    def _testInputStream(self, test_input, numSlices=None):
-
+    def _testInputStream(self, test_inputs, numSlices=None):
+        """
+        Generate multiple files to make a "stream" on the Scala side for testing.
+        Scala chooses one of the files and generates an RDD using PythonRDD.readRDDFromFile.
+        """
         numSlices = numSlices or self._sc.defaultParallelism
         # Calling the Java parallelize() method with an ArrayList is too slow,
         # because it sends O(n) Py4J commands.  As an alternative, serialized
         # objects are written to a file and loaded through textFile().
-        tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
-
-        # Make sure we distribute data evenly if it's smaller than self.batchSize
-        if "__len__" not in dir(test_input):
-            c = list(test_input)    # Make it a list so we can compute its length
-        batchSize = min(len(test_input) // numSlices, self._sc._batchSize)
-        if batchSize > 1:
-            serializer = BatchedSerializer(self._sc._unbatched_serializer,
-                                           batchSize)
-        else:
-            serializer = self._sc._unbatched_serializer
-        serializer.dump_stream(test_input, tempFile)
-
+        tempFiles = list()
+        for test_input in test_inputs:
+            tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
+
+            # Make sure we distribute data evenly if it's smaller than self.batchSize
+            if "__len__" not in dir(test_input):
+                c = list(test_input)    # Make it a list so we can compute its length
+            batchSize = min(len(test_input) // numSlices, self._sc._batchSize)
+            if batchSize > 1:
+                serializer = BatchedSerializer(self._sc._unbatched_serializer,
+                                               batchSize)
+            else:
+                serializer = self._sc._unbatched_serializer
+            serializer.dump_stream(test_input, tempFile)
+            tempFiles.append(tempFile.name)
+
+        jtempFiles = ListConverter().convert(tempFiles, SparkContext._gateway._gateway_client)
         jinput_stream = self._jvm.PythonTestInputStream(self._jssc,
-                                                        tempFile.name,
+                                                        jtempFiles,
                                                         numSlices).asJavaDStream()
         return DStream(jinput_stream, self, PickleSerializer())
diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index 47196196466db..0f0a1847535ce 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -217,7 +217,6 @@ def pyprint(self):
         """
         def takeAndPrint(rdd, time):
-            print "take and print ==================="
             taken = rdd.take(11)
             print "-------------------------------------------"
             print "Time: %s" % (str(time))
@@ -242,13 +241,10 @@ def _test_output(self, buff):
         Store data from the dstream in a buffer to verify the result in a test case
         """
         def get_output(rdd, time):
-            taken = rdd.take(11)
-            buff.result = taken
+            taken = rdd.collect()
+            buff.append(taken)
         self.foreachRDD(get_output)
 
-    def output(self):
-        self._jdstream.outputToFile()
-
 class PipelinedDStream(DStream):
     def __init__(self, prev, func, preservesPartitioning=False):
diff --git a/python/pyspark/streaming_tests.py b/python/pyspark/streaming_tests.py
index 0660be10b027b..d2e638a7d2acc 100644
--- a/python/pyspark/streaming_tests.py
+++ b/python/pyspark/streaming_tests.py
@@ -35,76 +35,133 @@
 import time
 import unittest
 import zipfile
+import operator
 
+from pyspark.context import SparkContext
 from pyspark.streaming.context import StreamingContext
 from pyspark.streaming.duration import *
 
 SPARK_HOME = os.environ["SPARK_HOME"]
 
-class buff:
+class StreamOutput:
     """
-    Buffer for store the output from stream
+    A class to store the output from a stream
     """
-    result = None
+    result = list()
 
 class PySparkStreamingTestCase(unittest.TestCase):
     def setUp(self):
-        print "set up"
         class_name = self.__class__.__name__
         self.ssc = StreamingContext(appName=class_name, duration=Seconds(1))
 
     def tearDown(self):
-        print "tear donw"
-        self.ssc.stop()
-        time.sleep(10)
+        # Do not call StreamingContext.stop directly, because we do not want to
+        # shut down the callback server and the Py4J client
+        self.ssc._jssc.stop()
+        self.ssc._sc.stop()
+        # Why does it take a long time to terminate StreamingContext and SparkContext?
+        # Should we change the sleep time if this depends on machine spec?
+        time.sleep(5)
+
+    @classmethod
+    def tearDownClass(cls):
+        time.sleep(5)
+        SparkContext._gateway._shutdown_callback_server()
 
 class TestBasicOperationsSuite(PySparkStreamingTestCase):
+    """
+    The input and output of this TestBasicOperationsSuite are equivalent to
+    those of the Scala TestBasicOperationsSuite.
+    """
     def setUp(self):
         PySparkStreamingTestCase.setUp(self)
-        buff.result = None
+        StreamOutput.result = list()
         self.timeout = 10  # seconds
 
     def tearDown(self):
         PySparkStreamingTestCase.tearDown(self)
 
+    @classmethod
+    def tearDownClass(cls):
+        PySparkStreamingTestCase.tearDownClass()
+
     def test_map(self):
+        """Basic operation test for DStream.map"""
         test_input = [range(1,5), range(5,9), range(9, 13)]
         def test_func(dstream):
             return dstream.map(lambda x: str(x))
-        expected = map(str, test_input)
-        output = self.run_stream(test_input, test_func)
-        self.assertEqual(output, expected)
+        expected_output = map(lambda x: map(lambda y: str(y), x), test_input)
+        output = self._run_stream(test_input, test_func, expected_output)
+        self.assertEqual(expected_output, output)
 
     def test_flatMap(self):
+        """Basic operation test for DStream.flatMap"""
         test_input = [range(1,5), range(5,9), range(9, 13)]
         def test_func(dstream):
             return dstream.flatMap(lambda x: (x, x * 2))
-        # Maybe there is a better way to create the flatmap expectation
-        excepted = map(lambda x: list(chain.from_iterable((map(lambda y:[y, y*2], x)))),
+        expected_output = map(lambda x: list(chain.from_iterable((map(lambda y: [y, y * 2], x)))),
                        test_input)
-        output = self.run_stream(test_input, test_func)
+        output = self._run_stream(test_input, test_func, expected_output)
+        self.assertEqual(expected_output, output)
+
+    def test_filter(self):
+        """Basic operation test for DStream.filter"""
+        test_input = [range(1,5), range(5,9), range(9, 13)]
+        def test_func(dstream):
+            return dstream.filter(lambda x: x % 2 == 0)
+        expected_output = map(lambda x: filter(lambda y: y % 2 == 0, x), test_input)
+        output = self._run_stream(test_input, test_func, expected_output)
+        self.assertEqual(expected_output, output)
+
+    def test_count(self):
+        """Basic operation test for DStream.count"""
+        test_input = [[], [1], range(1, 3), range(1,4), range(1,5)]
+        def test_func(dstream):
+            return dstream.count()
+        expected_output = map(lambda x: [len(x)], test_input)
+        output = self._run_stream(test_input, test_func, expected_output)
+        self.assertEqual(expected_output, output)
+
+    def test_reduce(self):
+        """Basic operation test for DStream.reduce"""
+        test_input = [range(1,5), range(5,9), range(9, 13)]
+        def test_func(dstream):
+            return dstream.reduce(operator.add)
+        expected_output = map(lambda x: [reduce(operator.add, x)], test_input)
+        output = self._run_stream(test_input, test_func, expected_output)
+        self.assertEqual(expected_output, output)
+
+    def test_reduceByKey(self):
+        """Basic operation test for DStream.reduceByKey"""
+        test_input = [["a", "a", "b"], ["", ""], []]
+        def test_func(dstream):
+            return dstream.map(lambda x: (x, 1)).reduceByKey(operator.add)
+        expected_output = [[("a", 2), ("b", 1)], [("", 2)], []]
+        output = self._run_stream(test_input, test_func, expected_output)
+        self.assertEqual(expected_output, output)
 
-    def run_stream(self, test_input, test_func):
+    def _run_stream(self, test_input, test_func, expected_output):
+        """Start the stream and return the output"""
         # Generate an input stream from the user-defined input
         test_input_stream = self.ssc._testInputStream(test_input)
         # Apply the test function to the stream
         test_stream = test_func(test_input_stream)
         # Add a job to get output from the stream
-        test_stream._test_output(buff)
+        test_stream._test_output(StreamOutput.result)
         self.ssc.start()
 
         start_time = time.time()
+        # Loop until we get the result from the stream
         while True:
             current_time = time.time()
             # check timeout
             if (current_time - start_time) > self.timeout:
-                self.ssc.stop()
                 break
             self.ssc.awaitTermination(50)
-            if buff.result is not None:
+            if len(expected_output) == len(StreamOutput.result):
                 break
-        return buff.result
+        return StreamOutput.result
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
index 734c2535ef8a3..5ef31b0f7bb3c 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -56,8 +56,6 @@ class PythonDStream[T: ClassTag](
   override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
     parent.getOrCompute(validTime) match{
       case Some(rdd) =>
-        logInfo("RDD ID in python DStream ===========")
-        logInfo("RDD id " + rdd.id)
         val pythonRDD = new PythonRDD(rdd, command, envVars, pythonIncludes, preservePartitoning,
           pythonExec, broadcastVars, accumulator)
         Some(pythonRDD.asJavaRDD.rdd)
       case None => None
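
The heart of the patch is _testInputStream: each batch of test input is pickled into its own temporary file, and the JVM-side PythonTestInputStream then replays one file per batch interval. Below is a minimal, Spark-free sketch of that write-then-replay cycle; make_batch_files and replay_batches are illustrative names, not functions from the patch.

    # Sketch of _testInputStream's temp-file mechanism (standard library only).
    import os
    import pickle
    import tempfile

    def make_batch_files(test_inputs):
        # Serialize each input batch into its own temporary file, mirroring
        # the per-batch NamedTemporaryFile loop in _testInputStream.
        names = []
        for batch in test_inputs:
            fd, name = tempfile.mkstemp(suffix=".pkl")
            with os.fdopen(fd, "wb") as f:
                pickle.dump(list(batch), f)
            names.append(name)
        return names

    def replay_batches(file_names):
        # Read the files back one at a time, standing in for the Scala
        # PythonTestInputStream, which emits one file's RDD per batch interval.
        for name in file_names:
            with open(name, "rb") as f:
                yield pickle.load(f)
            os.remove(name)

    if __name__ == "__main__":
        files = make_batch_files([range(1, 4), range(4, 7), range(7, 10)])
        for batch in replay_batches(files):
            print(batch)   # [1, 2, 3], then [4, 5, 6], then [7, 8, 9]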
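The test harness in _run_stream rests on a simple convention: the _test_output callback appends each batch's collected output to a shared buffer, and the test thread polls that buffer until either every expected batch has arrived or a timeout expires, so a wedged stream cannot hang the suite. Here is a standalone sketch of that polling pattern with a plain thread standing in for the streaming job; collect_with_timeout and producer are illustrative names, not part of the patch.

    # Sketch of the buffer-and-poll pattern used by _test_output/_run_stream.
    import threading
    import time

    def collect_with_timeout(buff, expected_batches, timeout=10.0, interval=0.05):
        # Poll the shared buffer until enough batches arrive or the deadline
        # passes; on timeout, whatever arrived so far is returned (possibly partial).
        deadline = time.time() + timeout
        while time.time() < deadline and len(buff) < expected_batches:
            time.sleep(interval)
        return buff

    if __name__ == "__main__":
        output = []

        def producer():
            for batch in ([1, 2, 3], [4, 5, 6], [7, 8, 9]):
                time.sleep(0.1)        # one "batch interval" per batch
                output.append(batch)   # what the foreachRDD callback does

        threading.Thread(target=producer).start()
        print(collect_with_timeout(output, expected_batches=3))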