add more tests
davies committed Sep 29, 2014
1 parent c40c52d commit 6ebceca
Showing 3 changed files with 137 additions and 61 deletions.
8 changes: 4 additions & 4 deletions python/pyspark/streaming/dstream.py
@@ -286,7 +286,7 @@ def saveAsTextFiles(self, prefix, suffix=None):
Save this DStream as a text file, using string representations of elements.
"""

-def saveAsTextFile(rdd, time):
+def saveAsTextFile(time, rdd):
"""
Closure to save element in RDD in DStream as Pickled data in file.
This closure is called by py4j callback server.
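The swap from (rdd, time) to (time, rdd) matches the order in which these callbacks are invoked from the py4j callback server: batch time first, then that batch's RDD. A minimal sketch of a handler with the same shape (the path format is illustrative, not the exact naming scheme used):

def save_batch(time, rdd):
    # hypothetical handler: batch time first, then the RDD for that batch
    path = "prefix-%s.suffix" % time
    rdd.saveAsTextFile(path)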
@@ -303,7 +303,7 @@ def saveAsPickleFiles(self, prefix, suffix=None):
is 10.
"""

-def saveAsPickleFile(rdd, time):
+def saveAsPickleFile(time, rdd):
"""
Closure to save element in RDD in the DStream as Pickled data in file.
This closure is called by py4j callback server.
@@ -388,7 +388,7 @@ def leftOuterJoin(self, other, numPartitions=None):
Hash partitioning is used to generate the RDDs with `numPartitions`
partitions.
"""
-return self.transformWith(lambda a, b: a.leftOuterJion(b, numPartitions), other)
+return self.transformWith(lambda a, b: a.leftOuterJoin(b, numPartitions), other)

def rightOuterJoin(self, other, numPartitions=None):
"""
@@ -502,7 +502,7 @@ def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=Non
@param numPartitions number of partitions of each RDD in the new DStream.
"""
keyed = self.map(lambda x: (x, 1))
-counted = keyed.reduceByKeyAndWindow(lambda a, b: a + b, lambda a, b: a - b,
+counted = keyed.reduceByKeyAndWindow(operator.add, operator.sub,
windowDuration, slideDuration, numPartitions)
return counted.filter(lambda (k, v): v > 0).count()
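Switching to operator.add and operator.sub gives reduceByKeyAndWindow both a reduce function and its inverse, so each window is updated incrementally: subtract the batch that slid out, add the batch that slid in, instead of re-reducing the whole window. A pure-Python sketch of that bookkeeping (no Spark required; the batch counts are made up):

import operator

batches = [3, 1, 4, 1, 5]      # per-batch counts arriving over time
window = 3                     # window length, in batches
totals, total = [], 0
for i, arriving in enumerate(batches):
    total = operator.add(total, arriving)                 # batch entering the window
    if i >= window:
        total = operator.sub(total, batches[i - window])  # batch leaving the window
    totals.append(total)
print(totals)  # [3, 4, 8, 6, 10]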

156 changes: 118 additions & 38 deletions python/pyspark/streaming/tests.py
@@ -15,17 +15,12 @@
# limitations under the License.
#

"""
Unit tests for Python SparkStreaming; additional tests are implemented as doctests in
individual modules.
Callback server is sometimes unstable sometimes, which cause error in test case.
But this is very rare case.
"""
import os
from itertools import chain
import time
import operator
import unittest
+import tempfile

from pyspark.context import SparkContext
from pyspark.streaming.context import StreamingContext
@@ -45,16 +40,20 @@ def setUp(self):
def tearDown(self):
self.ssc.stop()

-def _test_func(self, input, func, expected, sort=False):
+def _test_func(self, input, func, expected, sort=False, input2=None):
"""
@param input: dataset for the test. This should be list of lists.
@param func: wrapped function. This function should return PythonDStream object.
@param expected: expected output for this testcase.
"""
# Generate input stream with user-defined input.
input_stream = self.ssc.queueStream(input)
+input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None
# Apply test function to stream.
-stream = func(input_stream)
+if input2:
+    stream = func(input_stream, input_stream2)
+else:
+    stream = func(input_stream)

result = stream.collect()
self.ssc.start()
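The hunk is folded here before the wait logic, so the following is only a sketch of what `_test_func` presumably does next, mirroring the polling loop removed from `test_union` below: poll the collected output until it reaches the expected length or a timeout expires. All names are hypothetical:

import time

def wait_for_output(result, expected, timeout):
    # Poll instead of awaitTermination: hitting the py4j server on every
    # check can error out, so sleep briefly between polls.
    deadline = time.time() + timeout
    while len(result) < len(expected) and time.time() < deadline:
        time.sleep(0.05)
    return result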

@@ -92,7 +91,7 @@ def test_take(self):
def test_first(self):
input = [range(10)]
dstream = self.ssc.queueStream(input)
-self.assertEqual(0, dstream)
+self.assertEqual(0, dstream.first())

def test_map(self):
"""Basic operation test for DStream.map."""
@@ -238,55 +237,122 @@ def add(a, b):
[("a", "11"), ("b", "1"), ("", "111")]]
self._test_func(input, func, expected, sort=True)

+def test_repartition(self):
+    input = [range(1, 5), range(5, 9)]
+    rdds = [self.sc.parallelize(r, 2) for r in input]

+    def func(dstream):
+        return dstream.repartitions(1).glom()
+    expected = [[[1, 2, 3, 4]], [[5, 6, 7, 8]]]
+    self._test_func(rdds, func, expected)
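glom() materializes each partition as a list, so a stream repartitioned down to one partition collects to a single nested list per RDD. A quick check with the plain RDD API, assuming an existing local SparkContext named sc (note the test calls repartitions(); the standard RDD method is repartition()):

rdd = sc.parallelize(range(1, 5), 2)
print(rdd.glom().collect())                  # [[1, 2], [3, 4]]
print(rdd.repartition(1).glom().collect())   # [[1, 2, 3, 4]] (element order may vary)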

def test_union(self):
-input1 = [range(3), range(5), range(1), range(6)]
-input2 = [range(3, 6), range(5, 6), range(1, 6)]
+input1 = [range(3), range(5), range(6)]
+input2 = [range(3, 6), range(5, 6)]

-d1 = self.ssc.queueStream(input1)
-d2 = self.ssc.queueStream(input2)
-d = d1.union(d2)
-result = d.collect()
-expected = [range(6), range(6), range(6), range(6)]
+def func(d1, d2):
+    return d1.union(d2)

-self.ssc.start()
-start_time = time.time()
-# Loop until get the expected the number of the result from the stream.
-while True:
-    current_time = time.time()
-    # Check time out.
-    if (current_time - start_time) > self.timeout * 2:
-        break
-    # StreamingContext.awaitTermination is not used to wait because
-    # if py4j server is called every 50 milliseconds, it gets an error.
-    time.sleep(0.05)
-    # Check if the output is the same length of expected output.
-    if len(expected) == len(result):
-        break
-self.assertEqual(expected, result)
+expected = [range(6), range(6), range(6)]
+self._test_func(input1, func, expected, input2=input2)

+def test_cogroup(self):
+    input = [[(1, 1), (2, 1), (3, 1)],
+             [(1, 1), (1, 1), (1, 1), (2, 1)],
+             [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1)]]
+    input2 = [[(1, 2)],
+              [(4, 1)],
+              [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 2)]]

+    def func(d1, d2):
+        return d1.cogroup(d2).mapValues(lambda vs: tuple(map(list, vs)))

+    expected = [[(1, ([1], [2])), (2, ([1], [])), (3, ([1], []))],
+                [(1, ([1, 1, 1], [])), (2, ([1], [])), (4, ([], [1]))],
+                [("a", ([1, 1], [1, 1])), ("b", ([1], [1])), ("", ([1, 1], [1, 2]))]]
+    self._test_func(input, func, expected, sort=True, input2=input2)
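The mapValues step exists because cogroup yields lazy iterables (ResultIterable in PySpark), which never compare equal to the plain lists in `expected`. The same normalization on ordinary Python data:

grouped = {1: (iter([1]), iter([2]))}   # stand-in for one cogrouped record
normalized = {k: tuple(map(list, vs)) for k, vs in grouped.items()}
assert normalized == {1: ([1], [2])}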

+def test_join(self):
+    input = [[('a', 1), ('b', 2)]]
+    input2 = [[('b', 3), ('c', 4)]]

+    def func(a, b):
+        return a.join(b)

+    expected = [[('b', (2, 3))]]
+    self._test_func(input, func, expected, True, input2)

+def test_left_outer_join(self):
+    input = [[('a', 1), ('b', 2)]]
+    input2 = [[('b', 3), ('c', 4)]]

+    def func(a, b):
+        return a.leftOuterJoin(b)

+    expected = [[('a', (1, None)), ('b', (2, 3))]]
+    self._test_func(input, func, expected, True, input2)

+def test_right_outer_join(self):
+    input = [[('a', 1), ('b', 2)]]
+    input2 = [[('b', 3), ('c', 4)]]

+    def func(a, b):
+        return a.rightOuterJoin(b)

+    expected = [[('b', (2, 3)), ('c', (None, 4))]]
+    self._test_func(input, func, expected, True, input2)

+def test_full_outer_join(self):
+    input = [[('a', 1), ('b', 2)]]
+    input2 = [[('b', 3), ('c', 4)]]

+    def func(a, b):
+        return a.fullOuterJoin(b)

+    expected = [[('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]]
+    self._test_func(input, func, expected, True, input2)
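The expected outputs of the four join variants can be cross-checked with a small pure-Python model (valid here because every key holds at most one value per side; a missing side becomes None in the outer variants):

def full_outer_join(a, b):
    da, db = dict(a), dict(b)
    return sorted((k, (da.get(k), db.get(k))) for k in set(da) | set(db))

left = [('a', 1), ('b', 2)]
right = [('b', 3), ('c', 4)]
print(full_outer_join(left, right))
# [('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]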


class TestWindowFunctions(PySparkStreamingTestCase):

-timeout = 15
+timeout = 20

def test_window(self):
input = [range(1), range(2), range(3), range(4), range(5)]

def func(dstream):
return dstream.window(3, 1).count()

expected = [[1], [3], [6], [9], [12], [9], [5]]
self._test_func(input, func, expected)

def test_count_by_window(self):
-input = [range(1), range(2), range(3), range(4), range(5), range(6)]
+input = [range(1), range(2), range(3), range(4), range(5)]

def func(dstream):
-    return dstream.countByWindow(4, 1)
+    return dstream.countByWindow(3, 1)

-expected = [[1], [3], [6], [9], [12], [15], [11], [6]]
+expected = [[1], [3], [6], [9], [12], [9], [5]]
self._test_func(input, func, expected)

def test_count_by_window_large(self):
input = [range(1), range(2), range(3), range(4), range(5), range(6)]

def func(dstream):
-    return dstream.countByWindow(6, 1)
+    return dstream.countByWindow(5, 1)

expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]]
self._test_func(input, func, expected)
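The expected lists in these window tests can be reproduced with a short simulation: countByWindow(w, 1) emits, on every slide, the total element count of the last w batches, and keeps emitting while the window drains after the final batch:

def count_by_window(batch_sizes, w):
    steps = len(batch_sizes) + w - 1   # the window keeps sliding until it empties
    return [[sum(batch_sizes[max(0, end - w):end])] for end in range(1, steps + 1)]

print(count_by_window([1, 2, 3, 4, 5, 6], 5))
# [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]]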

def test_count_by_value_and_window(self):
input = [range(1), range(2), range(3), range(4), range(5), range(6)]

def func(dstream):
return dstream.countByValueAndWindow(6, 1)

expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]]
self._test_func(input, func, expected)

def test_group_by_key_and_window(self):
input = [[('a', i)] for i in range(5)]

@@ -359,6 +425,20 @@ def test_queueStream(self):
time.sleep(1)
self.assertEqual(input, result[:3])

+# TODO: test textFileStream
+# def test_textFileStream(self):
+#     input = [range(i) for i in range(3)]
+#     dstream = self.ssc.queueStream(input)
+#     d = os.path.join(tempfile.gettempdir(), str(id(self)))
+#     if not os.path.exists(d):
+#         os.makedirs(d)
+#     dstream.saveAsTextFiles(os.path.join(d, 'test'))
+#     dstream2 = self.ssc.textFileStream(d)
+#     result = dstream2.collect()
+#     self.ssc.start()
+#     time.sleep(2)
+#     self.assertEqual(input, result[:3])

def test_union(self):
input = [range(i) for i in range(3)]
dstream = self.ssc.queueStream(input)
34 changes: 15 additions & 19 deletions streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -93,7 +93,8 @@ private[spark] object PythonDStream {
}

// helper function for ssc.transform()
-def callTransform(ssc: JavaStreamingContext, jdsteams: JList[JavaDStream[_]], pyfunc: PythonRDDFunction)
+def callTransform(ssc: JavaStreamingContext, jdsteams: JList[JavaDStream[_]],
+    pyfunc: PythonRDDFunction)
:JavaDStream[Array[Byte]] = {
val func = new RDDFunction(pyfunc)
ssc.transform(jdsteams, func)
@@ -210,9 +211,9 @@ class PythonReducedWindowedDStream(parent: DStream[Array[Byte]],

override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
val currentTime = validTime
-val currentWindow = new Interval(currentTime - windowDuration + parent.slideDuration,
+val current = new Interval(currentTime - windowDuration,
currentTime)
-val previousWindow = currentWindow - slideDuration
+val previous = current - slideDuration

// _____________________________
// | previous window _________|___________________
@@ -225,35 +226,30 @@
// old RDDs new RDDs
//

-// Get the RDD of the reduced value of the previous window
-val previousWindowRDD = getOrCompute(previousWindow.endTime)
+val previousRDD = getOrCompute(previous.endTime)

-if (pinvReduceFunc != null && previousWindowRDD.isDefined
+if (pinvReduceFunc != null && previousRDD.isDefined
// for small window, reduce once will be better than twice
-&& windowDuration > slideDuration * 5) {
+&& windowDuration >= slideDuration * 5) {

// subtract the values from old RDDs
-val oldRDDs =
-  parent.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration)
-val subbed = if (oldRDDs.size > 0) {
-  invReduceFunc(previousWindowRDD, Some(ssc.sc.union(oldRDDs)), validTime)
+val oldRDDs = parent.slice(previous.beginTime + parent.slideDuration, current.beginTime)
+val subtracted = if (oldRDDs.size > 0) {
+  invReduceFunc(previousRDD, Some(ssc.sc.union(oldRDDs)), validTime)
} else {
-  previousWindowRDD
+  previousRDD
}

// add the RDDs of the reduced values in "new time steps"
-val newRDDs =
-  parent.slice(previousWindow.endTime, currentWindow.endTime - parent.slideDuration)

+val newRDDs = parent.slice(previous.endTime + parent.slideDuration, current.endTime)
if (newRDDs.size > 0) {
-  reduceFunc(subbed, Some(ssc.sc.union(newRDDs)), validTime)
+  reduceFunc(subtracted, Some(ssc.sc.union(newRDDs)), validTime)
} else {
-  subbed
+  subtracted
}
} else {
// Get the RDDs of the reduced values in current window
-val currentRDDs =
-  parent.slice(currentWindow.beginTime, currentWindow.endTime - parent.slideDuration)
+val currentRDDs = parent.slice(current.beginTime + parent.slideDuration, current.endTime)
if (currentRDDs.size > 0) {
reduceFunc(None, Some(ssc.sc.union(currentRDDs)), validTime)
} else {
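The renamed intervals make the slice arithmetic easier to check: slice() is inclusive of batch times, so both the old range (batches to subtract) and the new range (batches to add) start one batch interval past the previous window's corresponding edge. A pure-Python sketch of that arithmetic, with the batch interval fixed at 1 and illustrative numbers:

def window_ranges(t, window, slide, batch=1):
    cur_begin, cur_end = t - window, t                    # `current` above
    prev_begin, prev_end = cur_begin - slide, cur_end - slide
    old = list(range(prev_begin + batch, cur_begin + 1))  # slice is inclusive
    new = list(range(prev_end + batch, cur_end + 1))
    return old, new

print(window_ranges(t=10, window=6, slide=2))
# ([3, 4], [9, 10]) -> subtract batches 3-4, add batches 9-10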
