Commit

clean up code
giwa committed Sep 20, 2014
1 parent 3166d31 commit f198d14
Showing 2 changed files with 19 additions and 20 deletions.
11 changes: 3 additions & 8 deletions python/pyspark/streaming/context.py
@@ -72,21 +72,19 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
# The callback server is needed only by Spark Streaming; therefore the callback server
# is started in StreamingContext.
SparkContext._gateway.restart_callback_server()
self._clean_up_trigger()
self._set_clean_up_trigger()
self._jvm = self._sc._jvm
self._jssc = self._initialize_context(self._sc._jsc, duration._jduration)

# Initialize StreamingContext in a function to allow subclass-specific initialization
def _initialize_context(self, jspark_context, jduration):
return self._jvm.JavaStreamingContext(jspark_context, jduration)

def _clean_up_trigger(self):
def _set_clean_up_trigger(self):
"""Kill py4j callback server properly using signal lib"""

def clean_up_handler(*args):
# Make sure the callback server is stopped.
# This needs improvement: how should the callback server be terminated properly?
SparkContext._gateway._shutdown_callback_server()
SparkContext._gateway.shutdown()
sys.exit(0)

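Note: the hunk is cut off before the handler is registered. A minimal sketch of how such a handler is typically wired up with Python's signal module follows; the exact set of signals registered is an assumption, since the diff does not show it.

import signal
import sys

# Sketch only (inside _set_clean_up_trigger): register clean_up_handler for common
# termination signals so the py4j callback server is shut down before the process exits.
# (Which signals are actually handled is an assumption, not shown in this diff.)
for sig in (signal.SIGINT, signal.SIGTERM):
    signal.signal(sig, clean_up_handler)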
@@ -132,18 +130,15 @@ def stop(self, stopSparkContext=True, stopGraceFully=False):
Stop the execution of the streams immediately (does not wait for all received data
to be processed).
"""

try:
self._jssc.stop(stopSparkContext, stopGraceFully)
finally:
# Stop the callback server
SparkContext._gateway._shutdown_callback_server()
SparkContext._gateway.shutdown()

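For reference, a minimal usage sketch of stop(); ssc is an assumed StreamingContext instance, and the keyword arguments simply mirror the signature above.

# Hypothetical usage (ssc is a StreamingContext):
ssc.stop(stopSparkContext=False, stopGraceFully=True)  # keep the SparkContext alive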
def _testInputStream(self, test_inputs, numSlices=None):
"""
This function is only for unit tests.
It requires a sequence as input, and returns the i_th element at the i_th batch
It requires a list as input, and returns the i_th element at the i_th batch
under manual clock.
"""
test_rdds = list()
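A hedged sketch of how _testInputStream might be driven from a unit test; the variable names and the foreachRDD-based collection pattern are assumptions, not part of this commit.

# Hypothetical unit-test usage (ssc is a StreamingContext running under a manual clock):
test_inputs = [[1, 2, 3], [4, 5], [6]]      # the i_th list becomes the i_th batch
stream = ssc._testInputStream(test_inputs)
collected = []
stream.foreachRDD(lambda rdd, time: collected.append(rdd.collect()))
ssc.start()
# ... advance the manual clock; collected should then hold one list per batch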
28 changes: 16 additions & 12 deletions python/pyspark/streaming/dstream.py
@@ -207,7 +207,7 @@ def _defaultReducePartitions(self):
"""
Returns the default number of partitions to use during reduce tasks (e.g., groupBy).
If spark.default.parallelism is set, then we'll use the value from SparkContext
defaultParallelism, otherwise we'll use the number of partitions in this RDD.
defaultParallelism, otherwise we'll use the number of partitions in this RDD
This mirrors the behavior of the Scala Partitioner#defaultPartitioner, intended to reduce
the likelihood of OOMs. Once PySpark adopts Partitioner-based APIs, this behavior will
@@ -222,7 +222,8 @@ def getNumPartitions(self):
"""
Return the number of partitions in RDD
"""
# TODO: remove hardcoding. RDD has NumPartitions but DStream does not.
# TODO: remove hardcoding. RDD has NumPartitions. How do we get the number of partitions
# through a DStream?
return 2

def foreachRDD(self, func):
@@ -243,6 +244,10 @@ def pyprint(self):
operator, so this DStream will be registered as an output stream and there materialized.
"""
def takeAndPrint(rdd, time):
"""
Closure that takes elements from the RDD and prints the first 10 of them.
This closure is called by the py4j callback server.
"""
taken = rdd.take(11)
print "-------------------------------------------"
print "Time: %s" % (str(time))
@@ -307,17 +312,11 @@ def checkpoint(self, interval):
Mark this DStream for checkpointing. It will be saved to a file inside the
checkpoint directory set with L{SparkContext.setCheckpointDir()}
(I am not sure this part applies to DStream) and
all references to its parent RDDs will be removed. This function must
be called before any job has been executed on this RDD. It is strongly
recommended that this RDD is persisted in memory, otherwise saving it
on a file will require recomputation.
interval must be pyspark.streaming.duration
@param interval: Time interval after which the generated RDD will be checkpointed;
interval has to be a pyspark.streaming.duration.Duration
"""
self.is_checkpointed = True
self._jdstream.checkpoint(interval)
self._jdstream.checkpoint(interval._jduration)
return self

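A minimal usage sketch for checkpoint(); it assumes a Duration helper such as Seconds exists in pyspark.streaming.duration, which this diff does not show.

# Hypothetical usage (Seconds is assumed to construct a Duration):
from pyspark.streaming.duration import Seconds

checkpointed_stream = dstream.checkpoint(Seconds(10))  # interval is a Duration, not an int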
def groupByKey(self, numPartitions=None):
@@ -369,6 +368,10 @@ def saveAsTextFiles(self, prefix, suffix=None):
Save this DStream as a text file, using string representations of elements.
"""
def saveAsTextFile(rdd, time):
"""
Closure to save the elements of each RDD in this DStream as a text file.
This closure is called by the py4j callback server.
"""
path = rddToFileName(prefix, suffix, time)
rdd.saveAsTextFile(path)

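A short usage sketch; the output naming produced by rddToFileName is assumed to follow Spark's usual prefix-&lt;batch time&gt;[.suffix] convention, which this diff does not show.

# Hypothetical usage (counts is a DStream): writes one text output per batch,
# e.g. /tmp/wordcount-1411200000000.txt for a batch at that timestamp (assumed naming).
counts.saveAsTextFiles("/tmp/wordcount", "txt")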
@@ -410,9 +413,10 @@ def get_output(rdd, time):
# TODO: implement countByWindow
# TODO: implement reduceByWindow

# Following operation has dependency to transform
# transform Operation
# TODO: implement transform
# TODO: implement transformWith
# Following operation has dependency with transform
# TODO: implement union
# TODO: implement repartition
# TODO: implement cogroup
