diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index 8eea3f4d52782..cfadef941a551 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -56,7 +56,7 @@ def _sum(self):
         """
         Add up the elements in this DStream.
         """
-        return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add)
+        return self._mapPartitions(lambda x: [sum(x)]).reduce(operator.add)
 
     def print_(self, label=None):
         """
@@ -65,6 +65,7 @@ def print_(self, label=None):
         deserialized pickled python object. Please use DStream.pyprint() to print results.
 
         Call DStream.print() and this function will print byte array in the DStream
+
         """
         # a hack to call print function in DStream
         getattr(self._jdstream, "print")(label)
@@ -74,7 +75,7 @@ def filter(self, f):
         Return a new DStream containing only the elements that satisfy predicate.
         """
         def func(iterator): return ifilter(f, iterator)
-        return self.mapPartitions(func)
+        return self._mapPartitions(func)
 
     def flatMap(self, f, preservesPartitioning=False):
         """
@@ -85,7 +86,7 @@ def func(s, iterator): return chain.from_iterable(imap(f, iterator))
         return self._mapPartitionsWithIndex(func, preservesPartitioning)
 
-    def map(self, f, preservesPartitioning=False):
+    def map(self, f):
         """
         Return a new DStream by applying a function to each element of DStream.
         """
@@ -217,13 +218,11 @@ def _defaultReducePartitions(self):
         return 2
 
     def getNumPartitions(self):
-        """
-        Returns the number of partitions in RDD
-        >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
-        >>> rdd.getNumPartitions()
-        2
-        """
-        return self._jdstream.partitions().size()
+        """
+        Return the number of partitions in the RDD.
+        """
+        # TODO: remove this hardcoded value; RDD has getNumPartitions but DStream does not.
+        return 2
 
     def foreachRDD(self, func):
         """
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
index cfa336df8674f..a2b9d581f609c 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
@@ -59,7 +59,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
   * operator, so this PythonDStream will be registered as an output stream and there materialized.
   * This function is for PythonAPI.
   */
-
+  // TODO: move this function to PythonDStream
   def pyprint() = dstream.pyprint()
 
   /**
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index 13032fca15616..994a696a44f5b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -620,37 +620,36 @@ abstract class DStream[T: ClassTag] (
     new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register()
   }
 
-//TODO move pyprint to PythonDStream
+// TODO: move pyprint to PythonDStream and execute it via a Py4J callback function
   /**
   * Print the first ten elements of each PythonRDD generated in this PythonDStream. This is an output
   * operator, so this PythonDStream will be registered as an output stream and there materialized.
  * Since serialized Python object is readable by Python, pyprint writes out binary data to
  * temporary file and run python script to deserialized and print the first ten elements
+   *
+   * Currently this calls the python script directly; we should avoid this.
   */
  private[streaming] def pyprint() {
    def foreachFunc = (rdd: RDD[T], time: Time) => {
      val iter = rdd.take(11).iterator
-      // make a temporary file
+      // Generate a temporary file
      val prefix = "spark"
      val suffix = ".tmp"
      val tempFile = File.createTempFile(prefix, suffix)
      val tempFileStream = new DataOutputStream(new FileOutputStream(tempFile.getAbsolutePath))
-      //write out serialized python object
+      // Write out the serialized python objects to the temporary file
      PythonRDD.writeIteratorToStream(iter, tempFileStream)
      tempFileStream.close()
-      // This value has to be passed from python
-      // Python currently does not do cluster deployment. But what happened
+      // pythonExec should be passed from Python; move pyprint to PythonDStream
      val pythonExec = new ProcessBuilder().environment().get("PYSPARK_PYTHON")
      val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
-      //val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/streaming/pyprint.py", tempFile.getAbsolutePath())) // why this fails to compile???
-      //absolute path to the python script is needed to change because we do not use pysparkstreaming
+      // Call the python script to deserialize and print the results to stdout
      val pb = new ProcessBuilder(pythonExec, sparkHome + "/python/pyspark/streaming/pyprint.py", tempFile.getAbsolutePath)
      val workerEnv = pb.environment()
-      //envVars also need to be pass
-      //workerEnv.putAll(envVars)
+      // envVars should also be passed from Python
      val pythonPath = sparkHome + "/python/" + File.pathSeparator + workerEnv.get("PYTHONPATH")
      workerEnv.put("PYTHONPATH", pythonPath)
      val worker = pb.start()
@@ -662,7 +661,7 @@ abstract class DStream[T: ClassTag] (
      println ("Time: " + time)
      println ("-------------------------------------------")
-      //print value from python std out
+      // Print values read from the python process's stdout
      var line = ""
      breakable {
        while (true) {
@@ -671,7 +670,7 @@ abstract class DStream[T: ClassTag] (
          println(line)
        }
      }
-      //delete temporary file
+      // Delete the temporary file
      tempFile.delete()
      println()
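
The diff shells out to python/pyspark/streaming/pyprint.py but does not include that script. For reference, here is a minimal sketch of what it could look like, assuming the temporary file holds the length-framed pickle stream written by PythonRDD.writeIteratorToStream, which pyspark.serializers.PickleSerializer.load_stream can read; the script structure and the print_first_ten helper are illustrative, not the actual implementation:

# Hypothetical sketch of pyprint.py; not the script referenced in the diff.
# Assumes the temp file contains length-framed pickled objects as written by
# PythonRDD.writeIteratorToStream, readable via PickleSerializer.load_stream.
import sys

from pyspark.serializers import PickleSerializer


def print_first_ten(path):
    ser = PickleSerializer()
    with open(path, "rb") as stream:
        for i, obj in enumerate(ser.load_stream(stream)):
            if i >= 10:
                # pyprint takes 11 elements so the script can signal that more exist
                print("...")
                break
            print(obj)


if __name__ == "__main__":
    print_first_ten(sys.argv[1])

The Scala side would invoke such a script roughly as $PYSPARK_PYTHON $SPARK_HOME/python/pyspark/streaming/pyprint.py <tempFile>, with PYTHONPATH extended to include $SPARK_HOME/python, matching the ProcessBuilder setup in pyprint().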