Implemented DStream.foreachRDD in the Python API using Py4J callback server.
tdas authored and giwa committed Sep 20, 2014
1 parent fe02547 commit 4f07163
Showing 5 changed files with 4 additions and 71 deletions.
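For context, a rough sketch of how the new Python-side foreachRDD is meant to be used. StreamingContext, the batch interval, socketTextStream, and the (rdd, time) callback signature are assumptions for illustration; only DStream.foreachRDD itself is what this commit wires up.

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    sc = SparkContext(appName="ForeachRDDSketch")
    ssc = StreamingContext(sc, 1)  # 1-second batches (assumed constructor)

    lines = ssc.socketTextStream("localhost", 9999)

    def per_batch(rdd, time):
        # Runs in the Python driver once per batch, invoked from the JVM
        # through the Py4J callback server enabled below in java_gateway.py.
        print("batch at %s had %d records" % (time, rdd.count()))

    lines.foreachRDD(per_batch)

    ssc.start()
    ssc.awaitTermination()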
2 changes: 1 addition & 1 deletion python/pyspark/java_gateway.py
@@ -102,7 +102,7 @@ def run(self):
EchoOutputThread(proc.stdout).start()

# Connect to the gateway
- gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False)
+ gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False, start_callback_server=True)

# Import the classes used by PySpark
java_import(gateway.jvm, "org.apache.spark.SparkConf")
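Starting the Py4J callback server is what lets the JVM call back into Python objects. The pattern this enables looks roughly like the sketch below; the fully qualified interface name and the call(jrdd, time) signature are assumptions, only the PythonRDDFunction name itself appears in this commit's Scala code.

    class RDDFunction(object):
        """Py4J callback object wrapping a user-supplied Python function."""

        def __init__(self, func):
            self.func = func

        def call(self, jrdd, time):
            # Invoked from the JVM for every generated RDD; a fuller version
            # would wrap jrdd into a pyspark RDD before calling the user function.
            self.func(jrdd, time)

        class Java:
            # Assumed package; the interface is the PythonRDDFunction referenced
            # by PythonForeachDStream below.
            implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']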
4 changes: 2 additions & 2 deletions python/pyspark/streaming/dstream.py
@@ -445,7 +445,6 @@ def pipeline_func(split, iterator):
self._prev_jdstream = prev._prev_jdstream # maintain the pipeline
self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer
self.is_cached = False
- self.is_checkpointed = False
self._ssc = prev._ssc
self.ctx = prev.ctx
self.prev = prev
@@ -482,4 +481,5 @@ def _jdstream(self):
return self._jdstream_val

def _is_pipelinable(self):
- return not (self.is_cached or self.is_checkpointed)
+ return not (self.is_cached)
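With the callback server running, the Python side of foreachRDD can be little more than wrapping the user's function in a callback object like the RDDFunction sketch above and handing it to the Scala PythonForeachDStream shown further down; the constructor arguments and output-stream registration below are assumptions, not this commit's exact code.

    def foreachRDD(self, func):
        # Hypothetical wiring inside dstream.py: wrap `func` so the JVM can
        # invoke it, then build the Scala DStream that performs the callback.
        wrapped = RDDFunction(func)
        self.ctx._jvm.PythonForeachDStream(self._jdstream.dstream(), wrapped)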

8 changes: 0 additions & 8 deletions
@@ -54,14 +54,6 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
dstream.print()
}

- /**
- * Print the first ten elements of each PythonRDD generated in the PythonDStream. This is an output
- * operator, so this PythonDStream will be registered as an output stream and there materialized.
- * This function is for PythonAPI.
- */
- //TODO move this function to PythonDStream
- def pyprint() = dstream.pyprint()

/**
* Return a new DStream in which each RDD has a single element generated by counting each RDD
* of this DStream.
1 change: 1 addition & 0 deletions
@@ -170,6 +170,7 @@ class PythonTestInputStream3(ssc_ : JavaStreamingContext)

val asJavaDStream = JavaDStream.fromDStream(this)
}

class PythonForeachDStream(
prev: DStream[Array[Byte]],
foreachFunction: PythonRDDFunction
60 changes: 0 additions & 60 deletions
@@ -620,66 +620,6 @@ abstract class DStream[T: ClassTag] (
new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register()
}

- //TODO: move pyprint to PythonDStream and executed by py4j call back function
- /**
- * Print the first ten elements of each PythonRDD generated in this PythonDStream. This is an output
- * operator, so this PythonDStream will be registered as an output stream and there materialized.
- * Since serialized Python object is readable by Python, pyprint writes out binary data to
- * temporary file and run python script to deserialized and print the first ten elements
- *
- * Currently call python script directly. We should avoid this
- */
- private[streaming] def pyprint() {
- def foreachFunc = (rdd: RDD[T], time: Time) => {
- val iter = rdd.take(11).iterator
-
- // Generate a temporary file
- val prefix = "spark"
- val suffix = ".tmp"
- val tempFile = File.createTempFile(prefix, suffix)
- val tempFileStream = new DataOutputStream(new FileOutputStream(tempFile.getAbsolutePath))
- // Write out serialized python object to temporary file
- PythonRDD.writeIteratorToStream(iter, tempFileStream)
- tempFileStream.close()
-
- // pythonExec should be passed from python. Move pyprint to PythonDStream
- val pythonExec = new ProcessBuilder().environment().get("PYSPARK_PYTHON")
-
- val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
- // Call python script to deserialize and print result in stdout
- val pb = new ProcessBuilder(pythonExec, sparkHome + "/python/pyspark/streaming/pyprint.py", tempFile.getAbsolutePath)
- val workerEnv = pb.environment()
-
- // envVars also should be pass from python
- val pythonPath = sparkHome + "/python/" + File.pathSeparator + workerEnv.get("PYTHONPATH")
- workerEnv.put("PYTHONPATH", pythonPath)
- val worker = pb.start()
- val is = worker.getInputStream()
- val isr = new InputStreamReader(is)
- val br = new BufferedReader(isr)
-
- println ("-------------------------------------------")
- println ("Time: " + time)
- println ("-------------------------------------------")
-
- // Print values which is from python std out
- var line = ""
- breakable {
- while (true) {
- line = br.readLine()
- if (line == null) break()
- println(line)
- }
- }
- // Delete temporary file
- tempFile.delete()
- println()
-
- }
- new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register()
- }
-
-
/**
* Return a new DStream in which each RDD contains all the elements in seen in a
* sliding window of time over this DStream. The new DStream generates RDDs with
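The pyprint removed above worked by writing serialized objects to a temporary file and spawning a Python process to deserialize and print them. With foreachRDD callable from Python, the same output can be produced by an ordinary Python function; a sketch of that idea, not code introduced by this commit:

    def pprint(dstream):
        def take_and_print(rdd, time):
            taken = rdd.take(11)
            print("-------------------------------------------")
            print("Time: %s" % time)
            print("-------------------------------------------")
            for record in taken[:10]:
                print(record)
            if len(taken) > 10:
                print("...")
            print("")
        dstream.foreachRDD(take_and_print)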
