diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 79dafb0a4ef27..3218bed5c74fc 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -700,12 +700,14 @@ def groupBy(self, f, numPartitions=None):
         return self.map(lambda x: (f(x), x)).groupByKey(numPartitions)
 
     @ignore_unicode_prefix
-    def pipe(self, command, env={}):
+    def pipe(self, command, env={}, checkCode=False):
         """
         Return an RDD created by piping elements to a forked external process.
 
         >>> sc.parallelize(['1', '2', '', '3']).pipe('cat').collect()
         [u'1', u'2', u'', u'3']
+
+        :param checkCode: whether or not to check the return value of the shell command.
         """
         def func(iterator):
             pipe = Popen(
@@ -717,7 +719,17 @@ def pipe_objs(out):
                     out.write(s.encode('utf-8'))
                 out.close()
             Thread(target=pipe_objs, args=[pipe.stdin]).start()
-            return (x.rstrip(b'\n').decode('utf-8') for x in iter(pipe.stdout.readline, b''))
+
+            def check_return_code():
+                pipe.wait()
+                if checkCode and pipe.returncode:
+                    raise Exception("Pipe function `%s' exited "
+                                    "with error code %d" % (command, pipe.returncode))
+                else:
+                    for i in range(0):
+                        yield i
+            return (x.rstrip(b'\n').decode('utf-8') for x in
+                    chain(iter(pipe.stdout.readline, b''), check_return_code()))
         return self.mapPartitions(func)
 
     def foreach(self, f):
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 17256dfc95744..c5c0add49d02c 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -885,6 +885,18 @@ def test_sortByKey_uses_all_partitions_not_only_first_and_last(self):
         for size in sizes:
             self.assertGreater(size, 0)
 
+    def test_pipe_functions(self):
+        data = ['1', '2', '3']
+        rdd = self.sc.parallelize(data)
+        with QuietTest(self.sc):
+            self.assertEqual([], rdd.pipe('cc').collect())
+            self.assertRaises(Py4JJavaError, rdd.pipe('cc', checkCode=True).collect)
+        result = rdd.pipe('cat').collect()
+        result.sort()
+        [self.assertEqual(x, y) for x, y in zip(data, result)]
+        self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect)
+        self.assertEqual([], rdd.pipe('grep 4').collect())
+
 
 class ProfilerTests(PySparkTestCase):