diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index 6fbd5b6f88089..3f23e65712368 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -1,5 +1,8 @@
 from collections import defaultdict
 from itertools import chain, ifilter, imap
+import operator
+
+import logging
 
 from pyspark.serializers import NoOpSerializer,\
     BatchedSerializer, CloudPickleSerializer, pack_long
@@ -24,6 +27,18 @@ def generatedRDDs(self):
         """
         pass
 
+    def count(self):
+        """
+        Return a new DStream in which each RDD has a single element: the count of elements in that batch.
+        """
+        # TODO: double-check the count implementation; this is different from what PySpark does
+        return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum().map(lambda x: x[1])
+
+    def sum(self):
+        """
+        """
+        return self.mapPartitions(lambda x: [sum(x)]).reduce(operator.add)
+
     def print_(self):
         """
         """
@@ -63,9 +78,9 @@ def reduce(self, func, numPartitions=None):
         """
 
         """
-        return self._combineByKey(lambda x:x, func, func, numPartitions)
+        return self.combineByKey(lambda x:x, func, func, numPartitions)
 
-    def _combineByKey(self, createCombiner, mergeValue, mergeCombiners,
+    def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
                       numPartitions = None):
         """
         """
@@ -74,6 +89,12 @@ def _combineByKey(self, createCombiner, mergeValue, mergeCombiners,
         def combineLocally(iterator):
             combiners = {}
             for x in iterator:
+
+                # TODO: double-check this workaround for the count operation;
+                # this is different from what PySpark does
+                if isinstance(x, int):
+                    x = ("", x)
+
                 (k, v) = x
                 if k not in combiners:
                     combiners[k] = createCombiner(v)
@@ -143,7 +164,7 @@ def _defaultReducePartitions(self):
         else:
             return self.getNumPartitions()
 
-        return self._jdstream.partitions().size()
+        return self._jdstream.partitions().size()
 
     def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
         """
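
Note: the sketch below is only an illustration of how the count() and sum() operators added above would be called from a driver program; it is not part of this patch. It assumes a StreamingContext and a socketTextStream source shaped like the Scala Streaming API (neither appears in this diff); only count(), sum(), map() and print_() come from dstream.py as changed here.

# Usage sketch (not part of this patch). StreamingContext and socketTextStream
# are assumed to exist with the same shape as the Scala API; the DStream
# methods count(), sum(), map() and print_() are the ones shown in the diff.
from pyspark.context import SparkContext
from pyspark.streaming.context import StreamingContext

sc = SparkContext(appName="CountSumExample")
ssc = StreamingContext(sc, 1)                    # 1-second batches (assumed constructor)

lines = ssc.socketTextStream("localhost", 9999)  # assumed source, mirrors the Scala API
nums = lines.map(int)                            # one integer per incoming line

nums.count().print_()                            # per batch: number of records
nums.sum().print_()                              # per batch: sum of the integers

ssc.start()
ssc.awaitTermination()                           # assumed, mirrors the Scala API

On the design choice: the ("", x) wrapping added to combineLocally is there because sum() and count() push plain integers through the key/value-oriented combineByKey path, and the trailing .map(lambda x: x[1]) in count() strips that dummy key back off; the TODO comments flag this as needing a cleaner implementation.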