Skip to content

Commit

Permalink
Move Py4jCallbackConnectionCleaner to Streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
zsxwing committed Jan 6, 2016
1 parent 5d871ea commit 329a78b
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 61 deletions.
61 changes: 0 additions & 61 deletions python/pyspark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,64 +54,6 @@
}


class Py4jCallbackConnectionCleaner(object):

"""
A cleaner to clean up callback connections that are not closed by Py4j. See SPARK-12617.
It will scan all callback connections every 30 seconds and close the dead connections.
"""

def __init__(self, gateway):
self._gateway = gateway
self._stopped = False
self._timer = None
self._lock = RLock()

def start(self):
if self._stopped:
return

def clean_closed_connections():
from py4j.java_gateway import quiet_close, quiet_shutdown

callback_server = self._gateway._callback_server
with callback_server.lock:
try:
closed_connections = []
for connection in callback_server.connections:
if not connection.isAlive():
quiet_close(connection.input)
quiet_shutdown(connection.socket)
quiet_close(connection.socket)
closed_connections.append(connection)

for closed_connection in closed_connections:
callback_server.connections.remove(closed_connection)
except Exception:
import traceback
traceback.print_exc()

self._start_timer(clean_closed_connections)

self._start_timer(clean_closed_connections)

def _start_timer(self, f):
from threading import Timer

with self._lock:
if not self._stopped:
self._timer = Timer(30.0, f)
self._timer.daemon = True
self._timer.start()

def stop(self):
with self._lock:
self._stopped = True
if self._timer:
self._timer.cancel()
self._timer = None


class SparkContext(object):

"""
Expand All @@ -126,7 +68,6 @@ class SparkContext(object):
_active_spark_context = None
_lock = RLock()
_python_includes = None # zip and egg files that need to be added to PYTHONPATH
_py4j_cleaner = None

PACKAGE_EXTENSIONS = ('.zip', '.egg', '.jar')

Expand Down Expand Up @@ -303,8 +244,6 @@ def _ensure_initialized(cls, instance=None, gateway=None):
if not SparkContext._gateway:
SparkContext._gateway = gateway or launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
_py4j_cleaner = Py4jCallbackConnectionCleaner(SparkContext._gateway)
_py4j_cleaner.start()

if instance:
if (SparkContext._active_spark_context and
Expand Down
63 changes: 63 additions & 0 deletions python/pyspark/streaming/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import os
import sys
from threading import RLock, Timer

from py4j.java_gateway import java_import, JavaObject

Expand All @@ -32,6 +33,63 @@
__all__ = ["StreamingContext"]


class Py4jCallbackConnectionCleaner(object):

"""
A cleaner to clean up callback connections that are not closed by Py4j. See SPARK-12617.
It will scan all callback connections every 30 seconds and close the dead connections.
"""

def __init__(self, gateway):
self._gateway = gateway
self._stopped = False
self._timer = None
self._lock = RLock()

def start(self):
if self._stopped:
return

def clean_closed_connections():
from py4j.java_gateway import quiet_close, quiet_shutdown

callback_server = self._gateway._callback_server
if callback_server:
with callback_server.lock:
try:
closed_connections = []
for connection in callback_server.connections:
if not connection.isAlive():
quiet_close(connection.input)
quiet_shutdown(connection.socket)
quiet_close(connection.socket)
closed_connections.append(connection)

for closed_connection in closed_connections:
callback_server.connections.remove(closed_connection)
except Exception:
import traceback
traceback.print_exc()

self._start_timer(clean_closed_connections)

self._start_timer(clean_closed_connections)

def _start_timer(self, f):
with self._lock:
if not self._stopped:
self._timer = Timer(30.0, f)
self._timer.daemon = True
self._timer.start()

def stop(self):
with self._lock:
self._stopped = True
if self._timer:
self._timer.cancel()
self._timer = None


class StreamingContext(object):
"""
Main entry point for Spark Streaming functionality. A StreamingContext
Expand All @@ -47,6 +105,9 @@ class StreamingContext(object):
# Reference to a currently active StreamingContext
_activeContext = None

# A cleaner to clean leak sockets of callback server every 30 seconds
_py4j_cleaner = None

def __init__(self, sparkContext, batchDuration=None, jssc=None):
"""
Create a new StreamingContext.
Expand Down Expand Up @@ -95,6 +156,8 @@ def _ensure_initialized(cls):
jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
# update the port of CallbackClient with real port
gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port)
_py4j_cleaner = Py4jCallbackConnectionCleaner(gw)
_py4j_cleaner.start()

# register serializer for TransformFunction
# it happens before creating SparkContext when loading from checkpointing
Expand Down

0 comments on commit 329a78b

Please sign in to comment.