diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 04737243f3192..5952e81a4bef3 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -15,7 +15,8 @@ # limitations under the License. # -import time +import sys +from signal import signal, SIGTERM, SIGINT from pyspark.conf import SparkConf from pyspark.files import SparkFiles @@ -63,15 +64,14 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, """ - # launch call back server - if not gateway: - gateway = launch_gateway() -# gateway.restart_callback_server() - # Create the Python Sparkcontext self._sc = SparkContext(master=master, appName=appName, sparkHome=sparkHome, pyFiles=pyFiles, environment=environment, batchSize=batchSize, serializer=serializer, conf=conf, gateway=gateway) + + # Start py4j callback server + SparkContext._gateway.restart_callback_server() + self._clean_up_trigger() self._jvm = self._sc._jvm self._jssc = self._initialize_context(self._sc._jsc, duration._jduration) @@ -79,6 +79,16 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, def _initialize_context(self, jspark_context, jduration): return self._jvm.JavaStreamingContext(jspark_context, jduration) + def _clean_up_trigger(self): + """Kill py4j callback server properly using signal lib""" + + def clean_up_handler(*args): + SparkContext._gateway.shutdown() + sys.exit(0) + + for sig in (SIGINT, SIGTERM): + signal(sig, clean_up_handler) + def start(self): """ Start the execution of the streams.