[SPARK-16931][PYTHON][SQL] Add Python wrapper for bucketBy
## What changes were proposed in this pull request?

Adds Python wrappers for `DataFrameWriter.bucketBy` and `DataFrameWriter.sortBy` ([SPARK-16931](https://issues.apache.org/jira/browse/SPARK-16931))
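
For reference, a minimal usage sketch of the new writer API (mirroring the doctests and tests added below; the `SparkSession`, the example DataFrame, and the table name `bucketed_example` are illustrative assumptions):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1, "foo", 3.0), (2, "foo", 5.0), (3, "bar", -1.0)], ["x", "y", "z"])

# Bucket the output by "x" into 3 buckets, sort each bucket by "z", and
# persist as a table (applicable to file-based sources with saveAsTable).
(df.write.format("parquet")
    .bucketBy(3, "x")
    .sortBy("z")
    .mode("overwrite")
    .saveAsTable("bucketed_example"))
```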

## How was this patch tested?

Unit tests covering new feature.

__Note__: Based on the work of GregBowyer (f49b9a2).

CC HyukjinKwon

Author: zero323 <[email protected]>
Author: Greg Bowyer <[email protected]>

Closes #17077 from zero323/SPARK-16931.
zero323 authored and cloud-fan committed May 8, 2017
1 parent 1f73d35 commit f53a820
Showing 2 changed files with 111 additions and 0 deletions.
python/pyspark/sql/readwriter.py: 57 additions & 0 deletions
@@ -563,6 +563,63 @@ def partitionBy(self, *cols):
        self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols))
        return self

    @since(2.3)
    def bucketBy(self, numBuckets, col, *cols):
        """Buckets the output by the given columns. If specified,
        the output is laid out on the file system similar to Hive's bucketing scheme.

        :param numBuckets: the number of buckets to save
        :param col: a name of a column, or a list of names.
        :param cols: additional names (optional). If `col` is a list, `cols` should be empty.

        .. note:: Applicable for file-based data sources in combination with
                  :py:meth:`DataFrameWriter.saveAsTable`.

        >>> (df.write.format('parquet')
        ...     .bucketBy(100, 'year', 'month')
        ...     .mode("overwrite")
        ...     .saveAsTable('bucketed_table'))
        """
        if not isinstance(numBuckets, int):
            raise TypeError("numBuckets should be an int, got {0}.".format(type(numBuckets)))

        if isinstance(col, (list, tuple)):
            if cols:
                raise ValueError("col is a {0} but cols are not empty".format(type(col)))

            col, cols = col[0], col[1:]

        if not all(isinstance(c, basestring) for c in cols) or not isinstance(col, basestring):
            raise TypeError("all names should be `str`")

        self._jwrite = self._jwrite.bucketBy(numBuckets, col, _to_seq(self._spark._sc, cols))
        return self

    @since(2.3)
    def sortBy(self, col, *cols):
        """Sorts the output in each bucket by the given columns on the file system.

        :param col: a name of a column, or a list of names.
        :param cols: additional names (optional). If `col` is a list, `cols` should be empty.

        >>> (df.write.format('parquet')
        ...     .bucketBy(100, 'year', 'month')
        ...     .sortBy('day')
        ...     .mode("overwrite")
        ...     .saveAsTable('sorted_bucketed_table'))
        """
        if isinstance(col, (list, tuple)):
            if cols:
                raise ValueError("col is a {0} but cols are not empty".format(type(col)))

            col, cols = col[0], col[1:]

        if not all(isinstance(c, basestring) for c in cols) or not isinstance(col, basestring):
            raise TypeError("all names should be `str`")

        self._jwrite = self._jwrite.sortBy(col, _to_seq(self._spark._sc, cols))
        return self

    @since(1.4)
    def save(self, path=None, format=None, mode=None, partitionBy=None, **options):
        """Saves the contents of the :class:`DataFrame` to a data source.
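As a side note on the argument handling in `bucketBy`/`sortBy` above: column names may be passed either as varargs or as a single list/tuple, and invalid combinations fail fast. A small illustrative sketch (the `df` and the table names are hypothetical):

```python
# Both call styles are accepted (names as varargs, or a single list/tuple).
df.write.bucketBy(4, "x", "y").sortBy("z").mode("overwrite").saveAsTable("t_varargs")
df.write.bucketBy(4, ["x", "y"]).sortBy(["z"]).mode("overwrite").saveAsTable("t_list")

# Invalid combinations are rejected before touching the JVM writer:
# df.write.bucketBy(4, ["x"], "y")   # raises ValueError (col is a list but cols are not empty)
# df.write.bucketBy("4", "x")        # raises TypeError (numBuckets should be an int)
```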
python/pyspark/sql/tests.py: 54 additions & 0 deletions
@@ -211,6 +211,12 @@ def test_sqlcontext_reuses_sparksession(self):
        sqlContext2 = SQLContext(self.sc)
        self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession)

    def tearDown(self):
        super(SQLTests, self).tearDown()

        # tear down test_bucketed_write state
        self.spark.sql("DROP TABLE IF EXISTS pyspark_bucket")

    def test_row_should_be_read_only(self):
        row = Row(a=1, b=2)
        self.assertEqual(1, row.a)
@@ -2196,6 +2202,54 @@ def test_BinaryType_serialization(self):
        df = self.spark.createDataFrame(data, schema=schema)
        df.collect()

    def test_bucketed_write(self):
        data = [
            (1, "foo", 3.0), (2, "foo", 5.0),
            (3, "bar", -1.0), (4, "bar", 6.0),
        ]
        df = self.spark.createDataFrame(data, ["x", "y", "z"])

        def count_bucketed_cols(names, table="pyspark_bucket"):
            """Given a sequence of column names and a table name,
            query the catalog and return the number of columns which
            are used for bucketing
            """
            cols = self.spark.catalog.listColumns(table)
            num = len([c for c in cols if c.name in names and c.isBucket])
            return num

        # Test write with one bucketing column
        df.write.bucketBy(3, "x").mode("overwrite").saveAsTable("pyspark_bucket")
        self.assertEqual(count_bucketed_cols(["x"]), 1)
        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

        # Test write with two bucketing columns
        df.write.bucketBy(3, "x", "y").mode("overwrite").saveAsTable("pyspark_bucket")
        self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

        # Test write with bucket and sort
        df.write.bucketBy(2, "x").sortBy("z").mode("overwrite").saveAsTable("pyspark_bucket")
        self.assertEqual(count_bucketed_cols(["x"]), 1)
        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

        # Test write with a list of columns
        df.write.bucketBy(3, ["x", "y"]).mode("overwrite").saveAsTable("pyspark_bucket")
        self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

        # Test write with bucket and sort with a list of columns
        (df.write.bucketBy(2, "x")
            .sortBy(["y", "z"])
            .mode("overwrite").saveAsTable("pyspark_bucket"))
        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

        # Test write with bucket and sort with multiple columns
        (df.write.bucketBy(2, "x")
            .sortBy("y", "z")
            .mode("overwrite").saveAsTable("pyspark_bucket"))
        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))


class HiveSparkSubmitTests(SparkSubmitTests):

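For completeness, the bucketing metadata checked by `count_bucketed_cols` above can also be inspected directly through the catalog; a brief sketch (assumes the bucketed table already exists):

```python
# Columns used for bucketing are flagged with isBucket in the catalog listing.
cols = spark.catalog.listColumns("pyspark_bucket")
bucket_cols = [c.name for c in cols if c.isBucket]
print(bucket_cols)  # e.g. ['x']
```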
