recover from checkpoint
davies committed Oct 1, 2014
1 parent fa7261b commit 6f0da2f
Showing 9 changed files with 136 additions and 38 deletions.
@@ -42,7 +42,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils

private[spark] class PythonRDD(
parent: RDD[_],
@transient parent: RDD[_],
command: Array[Byte],
envVars: JMap[String, String],
pythonIncludes: JList[String],
@@ -61,9 +61,9 @@ private[spark] class PythonRDD(
val bufferSize = conf.getInt("spark.buffer.size", 65536)
val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)

override def getPartitions = parent.partitions
override def getPartitions = firstParent.partitions

override val partitioner = if (preservePartitoning) parent.partitioner else None
override val partitioner = if (preservePartitoning) firstParent.partitioner else None

override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val startTime = System.currentTimeMillis
@@ -241,7 +241,7 @@ private[spark] class PythonRDD(
dataOut.writeInt(command.length)
dataOut.write(command)
// Data values
PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
dataOut.flush()
} catch {
@@ -84,7 +84,7 @@ private[spark] class ParallelCollectionPartition[T: ClassTag](

private[spark] class ParallelCollectionRDD[T: ClassTag](
@transient sc: SparkContext,
@transient data: Seq[T],
data: Seq[T],
numSlices: Int,
locationPrefs: Map[Int, Seq[String]])
extends RDD[T](sc, Nil) {
8 changes: 8 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -82,6 +82,14 @@ abstract class RDD[T: ClassTag](
def this(@transient oneParent: RDD[_]) =
this(oneParent.context , List(new OneToOneDependency(oneParent)))

// setContext after loading from checkpointing
private[spark] def setContext(s: SparkContext) = {
if (sc != null && sc != s) {
throw new SparkException("Context is already set in " + this + ", cannot set it again")
}
sc = s
}

private[spark] def conf = sc.conf
// =======================================================================
// Methods that should be implemented by subclasses of RDD
8 changes: 4 additions & 4 deletions python/pyspark/context.py
@@ -68,7 +68,7 @@ class SparkContext(object):

def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None,
gateway=None):
gateway=None, jsc=None):
"""
Create a new SparkContext. At least the master and app name should be set,
either through the named parameters here or through C{conf}.
@@ -103,14 +103,14 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
SparkContext._ensure_initialized(self, gateway=gateway)
try:
self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
conf)
conf, jsc)
except:
# If an error occurs, clean up in order to allow future SparkContext creation:
self.stop()
raise

def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
conf):
conf, jsc):
self.environment = environment or {}
self._conf = conf or SparkConf(_jvm=self._jvm)
self._batchSize = batchSize # -1 represents an unlimited batch size
@@ -151,7 +151,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
self.environment[varName] = v

# Create the Java SparkContext through Py4J
self._jsc = self._initialize_context(self._conf._jconf)
self._jsc = jsc or self._initialize_context(self._conf._jconf)

# Create a single Accumulator in Java that we'll send all our updates through;
# they will be passed back to us through a TCP server
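This hunk threads a new optional `jsc` argument through `SparkContext.__init__` and `_do_init`, so PySpark can wrap an already-running JVM SparkContext (for example, one recovered from a streaming checkpoint) instead of always creating a fresh one. A minimal sketch of the call site, mirroring the `getOrCreate` path later in this commit; `gw` and `jssc` stand in for an existing py4j gateway and a recovered JavaStreamingContext and are assumptions, not part of the diff:

    # Sketch only: wrap a JVM SparkContext recovered from a checkpoint.
    # `gw` (py4j gateway) and `jssc` (recovered JavaStreamingContext) are assumed.
    from pyspark import SparkConf, SparkContext

    jsc = jssc.sparkContext()                       # JavaSparkContext already running in the JVM
    conf = SparkConf(_jconf=jsc.getConf())          # rebuild the Python-side conf from it
    sc = SparkContext(conf=conf, gateway=gw, jsc=jsc)   # uses the keyword added above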
76 changes: 59 additions & 17 deletions python/pyspark/streaming/context.py
@@ -14,11 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import sys

from py4j.java_collections import ListConverter
from py4j.java_gateway import java_import

from pyspark import RDD
from pyspark import RDD, SparkConf
from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer
from pyspark.context import SparkContext
from pyspark.storagelevel import StorageLevel
@@ -75,41 +77,81 @@ class StreamingContext(object):
respectively. `context.awaitTransformation()` allows the current thread
to wait for the termination of the context by `stop()` or by an exception.
"""
_transformerSerializer = None

def __init__(self, sparkContext, duration):
def __init__(self, sparkContext, duration=None, jssc=None):
"""
Create a new StreamingContext.
@param sparkContext: L{SparkContext} object.
@param duration: number of seconds.
"""

self._sc = sparkContext
self._jvm = self._sc._jvm
self._start_callback_server()
self._jssc = self._initialize_context(self._sc, duration)
self._jssc = jssc or self._initialize_context(self._sc, duration)

def _initialize_context(self, sc, duration):
self._ensure_initialized()
return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration))

def _jduration(self, seconds):
"""
Create Duration object given number of seconds
"""
return self._jvm.Duration(int(seconds * 1000))

def _start_callback_server(self):
gw = self._sc._gateway
@classmethod
def _ensure_initialized(cls):
SparkContext._ensure_initialized()
gw = SparkContext._gateway
# start callback server
# getattr will fallback to JVM
if "_callback_server" not in gw.__dict__:
_daemonize_callback_server()
gw._start_callback_server(gw._python_proxy_port)
gw._python_proxy_port = gw._callback_server.port # update port with real port

def _initialize_context(self, sc, duration):
java_import(self._jvm, "org.apache.spark.streaming.*")
java_import(self._jvm, "org.apache.spark.streaming.api.java.*")
java_import(self._jvm, "org.apache.spark.streaming.api.python.*")
java_import(gw.jvm, "org.apache.spark.streaming.*")
java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")
# register serializer for RDDFunction
ser = RDDFunctionSerializer(self._sc, CloudPickleSerializer())
self._jvm.PythonDStream.registerSerializer(ser)
return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration))
# it happens before creating SparkContext when loading from checkpointing
cls._transformerSerializer = RDDFunctionSerializer(SparkContext._active_spark_context,
CloudPickleSerializer(), gw)
gw.jvm.PythonDStream.registerSerializer(cls._transformerSerializer)

def _jduration(self, seconds):
@classmethod
def getOrCreate(cls, path, setupFunc):
"""
Create Duration object given number of seconds
Get the StreamingContext from checkpoint file at `path`, or setup
it by `setupFunc`.
:param path: directory of checkpoint
:param setupFunc: a function used to create StreamingContext and
setup DStreams.
:return: a StreamingContext
"""
return self._jvm.Duration(int(seconds * 1000))
if not os.path.exists(path) or not os.path.isdir(path) or not os.listdir(path):
ssc = setupFunc()
ssc.checkpoint(path)
return ssc

cls._ensure_initialized()
gw = SparkContext._gateway

try:
jssc = gw.jvm.JavaStreamingContext(path)
except Exception:
print >>sys.stderr, "failed to load StreamingContext from checkpoint"
raise

jsc = jssc.sparkContext()
conf = SparkConf(_jconf=jsc.getConf())
sc = SparkContext(conf=conf, gateway=gw, jsc=jsc)
# update ctx in serializer
SparkContext._active_spark_context = sc
cls._transformerSerializer.ctx = sc
return StreamingContext(sc, None, jssc)

@property
def sparkContext(self):
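The new `getOrCreate` classmethod is the user-facing entry point for checkpoint recovery: if the checkpoint directory is missing or empty it runs `setupFunc` and turns on checkpointing, otherwise it rebuilds the JVM JavaStreamingContext from disk and wraps it together with a recovered SparkContext. A hedged usage sketch, closely modeled on the `test_get_or_create` test later in this commit; the checkpoint directory and the 200 ms batch duration are illustrative choices, and a real job would register an output operation before starting:

    # Sketch of driver code using StreamingContext.getOrCreate as added above.
    import tempfile
    from pyspark import SparkConf, SparkContext
    from pyspark.streaming.context import StreamingContext

    checkpoint_dir = tempfile.mkdtemp("ssc_checkpoint_demo")

    def setup():
        # Called only when no usable checkpoint exists in checkpoint_dir;
        # getOrCreate itself calls ssc.checkpoint(checkpoint_dir) afterwards.
        conf = SparkConf().set("spark.default.parallelism", 1)
        sc = SparkContext(conf=conf)
        ssc = StreamingContext(sc, .2)
        rdd = sc.parallelize(range(10), 1)
        counts = ssc.queueStream([rdd], default=rdd).countByWindow(1, .2)
        # ... register output operations on `counts` here ...
        return ssc

    # First run builds the context via setup(); later runs recover it from disk.
    ssc = StreamingContext.getOrCreate(checkpoint_dir, setup)
    ssc.start()
    ssc.awaitTermination(4)
    ssc.stop()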
33 changes: 33 additions & 0 deletions python/pyspark/streaming/tests.py
@@ -493,5 +493,38 @@ def func(rdds):
self.assertEqual([2, 3, 1], self._take(dstream, 3))


class TestCheckpoint(PySparkStreamingTestCase):

def setUp(self):
pass

def tearDown(self):
pass

def test_get_or_create(self):
result = [0]

def setup():
conf = SparkConf().set("spark.default.parallelism", 1)
sc = SparkContext(conf=conf)
ssc = StreamingContext(sc, .2)
rdd = sc.parallelize(range(10), 1)
dstream = ssc.queueStream([rdd], default=rdd)
result[0] = self._collect(dstream.countByWindow(1, .2))
return ssc
tmpd = tempfile.mkdtemp("test_streaming_cps")
ssc = StreamingContext.getOrCreate(tmpd, setup)
ssc.start()
ssc.awaitTermination(4)
ssc.stop()
expected = [[i * 10 + 10] for i in range(5)] + [[50]] * 5
self.assertEqual(expected, result[0][:10])

ssc = StreamingContext.getOrCreate(tmpd, setup)
ssc.start()
ssc.awaitTermination(2)
ssc.stop()


if __name__ == "__main__":
unittest.main()
24 changes: 16 additions & 8 deletions python/pyspark/streaming/util.py
@@ -18,24 +18,31 @@
from datetime import datetime
import traceback

from pyspark.rdd import RDD
from pyspark import SparkContext, RDD


class RDDFunction(object):
"""
This class is for py4j callback.
"""
_emptyRDD = None

def __init__(self, ctx, func, *deserializers):
self.ctx = ctx
self.func = func
self.deserializers = deserializers
emptyRDD = getattr(self.ctx, "_emptyRDD", None)
if emptyRDD is None:
self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache()
self.emptyRDD = emptyRDD

@property
def emptyRDD(self):
if self._emptyRDD is None and self.ctx:
self._emptyRDD = self.ctx.parallelize([]).cache()
return self._emptyRDD

def call(self, milliseconds, jrdds):
try:
if self.ctx is None:
self.ctx = SparkContext._active_spark_context

# extend deserializers with the first one
sers = self.deserializers
if len(sers) < len(jrdds):
@@ -51,20 +58,21 @@ def call(self, milliseconds, jrdds):
traceback.print_exc()

def __repr__(self):
return "RDDFunction(%s)" % (str(self.func))
return "RDDFunction(%s)" % self.func

class Java:
implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction']


class RDDFunctionSerializer(object):
def __init__(self, ctx, serializer):
def __init__(self, ctx, serializer, gateway=None):
self.ctx = ctx
self.serializer = serializer
self.gateway = gateway or self.ctx._gateway

def dumps(self, id):
try:
func = self.ctx._gateway.gateway_property.pool[id]
func = self.gateway.gateway_property.pool[id]
return bytearray(self.serializer.dumps((func.func, func.deserializers)))
except Exception:
traceback.print_exc()
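The `RDDFunction` changes above make the py4j callback safe to deserialize before any SparkContext exists: `ctx` may be `None` at construction time and is re-bound to `SparkContext._active_spark_context` on the first `call`, and the shared empty RDD moves from `__init__` into a lazy property so it is only created once a context is available. A small, illustrative sketch of that deferred-binding pattern; the class below is not part of the commit:

    # Illustrative only: the deferred-binding pattern used by RDDFunction.
    from pyspark import SparkContext

    class LazyContextCallback(object):
        def __init__(self, ctx, func):
            self.ctx = ctx           # may be None when restored from a checkpoint
            self.func = func
            self._empty = None       # created lazily, once a context exists

        @property
        def empty(self):
            if self._empty is None and self.ctx:
                self._empty = self.ctx.parallelize([]).cache()
            return self._empty

        def call(self, *args):
            if self.ctx is None:
                # pick up whatever context the driver has created since
                self.ctx = SparkContext._active_spark_context
            return self.func(self.ctx, *args)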
@@ -77,7 +77,7 @@ private[python] class RDDFunction(@transient var pfunc: PythonRDDFunction)
}

/**
* Inferface for Python Serializer to serialize PythonRDDFunction
* Interface for Python Serializer to serialize PythonRDDFunction
*/
private[python] trait PythonRDDFunctionSerializer {
def dumps(id: String): Array[Byte] //
@@ -91,9 +91,9 @@ private[spark] class RDDFunctionSerializer(pser: PythonRDDFunctionSerializer) {
def serialize(func: PythonRDDFunction): Array[Byte] = {
// get the id of PythonRDDFunction in py4j
val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy])
val f = h.getClass().getDeclaredField("id");
f.setAccessible(true);
val id = f.get(h).asInstanceOf[String];
val f = h.getClass().getDeclaredField("id")
f.setAccessible(true)
val id = f.get(h).asInstanceOf[String]
pser.dumps(id)
}

@@ -17,6 +17,7 @@

package org.apache.spark.streaming.dstream

import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.UnionRDD
import scala.collection.mutable.Queue
@@ -32,6 +33,12 @@ class QueueInputDStream[T: ClassTag](
defaultRDD: RDD[T]
) extends InputDStream[T](ssc) {

private[streaming] override def setContext(s: StreamingContext) {
super.setContext(s)
queue.map(_.setContext(s.sparkContext))
defaultRDD.setContext(s.sparkContext)
}

override def start() { }

override def stop() { }
