Commit d328aca: fix serializer in queueStream

davies committed Oct 1, 2014
1 parent 6f0da2f
Showing 3 changed files with 20 additions and 12 deletions.
24 changes: 16 additions & 8 deletions python/pyspark/streaming/context.py
@@ -238,29 +238,37 @@ def textFileStream(self, directory):

     def _check_serialzers(self, rdds):
         # make sure they have same serializer
-        if len(set(rdd._jrdd_deserializer for rdd in rdds)):
+        if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1:
             for i in range(len(rdds)):
                 # reset them to sc.serializer
                 rdds[i] = rdds[i].map(lambda x: x, preservesPartitioning=True)
 
-    def queueStream(self, queue, oneAtATime=True, default=None):
+    def queueStream(self, rdds, oneAtATime=True, default=None):
         """
         Create an input stream from a queue of RDDs or lists. In each batch,
         it will process either one or all of the RDDs returned by the queue.
         NOTE: changes to the queue after the stream is created will not be recognized.
-        @param queue Queue of RDDs
-        @tparam T Type of objects in the RDD
+        @param rdds Queue of RDDs
         @param oneAtATime pick one rdd each time or pick all of them once.
         @param default The default rdd if no more in rdds
         """
-        if queue and not isinstance(queue[0], RDD):
-            rdds = [self._sc.parallelize(input) for input in queue]
-        else:
-            rdds = queue
+        if default and not isinstance(default, RDD):
+            default = self._sc.parallelize(default)
+
+        if not rdds and default:
+            rdds = [rdds]
+
+        if rdds and not isinstance(rdds[0], RDD):
+            rdds = [self._sc.parallelize(input) for input in rdds]
+        self._check_serialzers(rdds)
 
         jrdds = ListConverter().convert([r._jrdd for r in rdds],
                                         SparkContext._gateway._gateway_client)
         queue = self._jvm.PythonDStream.toRDDQueue(jrdds)
         if default:
+            default = default._reserialize(rdds[0]._jrdd_deserializer)
             jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd)
         else:
             jdstream = self._jssc.queueStream(queue, oneAtATime)
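For orientation, a minimal sketch of how the reworked queueStream is meant to be called. The setup names (sc, ssc) and the sample data are assumptions for illustration, not part of this commit; the point is that plain lists are parallelized into RDDs internally, and that a non-RDD default is parallelized and then reserialized to match the queued RDDs.

    # Hypothetical usage sketch, assuming a local SparkContext; not from the commit.
    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    sc = SparkContext("local[2]", "queueStreamSketch")
    ssc = StreamingContext(sc, 1)  # 1-second batch interval

    # Plain lists are accepted: queueStream parallelizes each into an RDD.
    queue = [[1, 2, 3], [4, 5], [6]]

    # oneAtATime=True consumes one RDD per batch; after the queue is
    # exhausted, the default RDD is used on every batch. This commit makes
    # the default get reserialized to match the queued RDDs' serializer.
    stream = ssc.queueStream(queue, oneAtATime=True, default=sc.parallelize([0]))
    stream.pprint()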
2 changes: 1 addition & 1 deletion python/pyspark/streaming/dstream.py
@@ -292,7 +292,7 @@ def transformWith(self, func, other, keepSerializer=False):
             oldfunc = func
             func = lambda t, a, b: oldfunc(a, b)
         assert func.func_code.co_argcount == 3, "func should take two or three arguments"
-        jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer)
+        jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer, other._jrdd_deserializer)
         dstream = self.ctx._jvm.PythonTransformed2DStream(self._jdstream.dstream(),
                                                           other._jdstream.dstream(), jfunc)
         jrdd_serializer = self._jrdd_deserializer if keepSerializer else self.ctx.serializer
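The extra constructor argument matters because the two parent streams need not share a serializer: RDDFunction now receives each parent's deserializer instead of applying self's to both. A short sketch of the call path this protects, with the stream names (stream_a, stream_b) assumed for illustration:

    # stream_a and stream_b are assumed DStreams whose RDDs may carry
    # different serializers (e.g. one built by queueStream over plain
    # lists, the other produced with sc.serializer). Per the assert above,
    # func may take (rdd_a, rdd_b) or (time, rdd_a, rdd_b).
    unioned = stream_a.transformWith(lambda a, b: a.union(b), stream_b)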
6 changes: 3 additions & 3 deletions python/pyspark/streaming/tests.py
@@ -508,16 +508,16 @@ def setup():
             conf = SparkConf().set("spark.default.parallelism", 1)
             sc = SparkContext(conf=conf)
             ssc = StreamingContext(sc, .2)
-            rdd = sc.parallelize(range(10), 1)
+            rdd = sc.parallelize(range(1), 1)
             dstream = ssc.queueStream([rdd], default=rdd)
-            result[0] = self._collect(dstream.countByWindow(1, .2))
+            result[0] = self._collect(dstream.countByWindow(1, 0.2))
             return ssc
         tmpd = tempfile.mkdtemp("test_streaming_cps")
         ssc = StreamingContext.getOrCreate(tmpd, setup)
         ssc.start()
         ssc.awaitTermination(4)
         ssc.stop()
-        expected = [[i * 10 + 10] for i in range(5)] + [[50]] * 5
+        expected = [[i * 1 + 1] for i in range(5)] + [[5]] * 5
         self.assertEqual(expected, result[0][:10])
 
         ssc = StreamingContext.getOrCreate(tmpd, setup)
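The updated expectation follows from the test's window arithmetic: one element per 0.2-second batch and a 1-second window means the count ramps up by one per batch until the window holds five batches, then stays at five. A standalone check of that arithmetic, using only values visible in the diff:

    # 1s window / 0.2s slide = 5 batches per full window; 1 element per batch
    ramp = [[i * 1 + 1] for i in range(5)]  # filling windows: [[1], [2], [3], [4], [5]]
    steady = [[5]] * 5                      # full windows always count 5 elements
    assert ramp + steady == [[1], [2], [3], [4], [5]] + [[5]] * 5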
