[SPARK-16931][PYTHON][SQL] Add Python wrapper for bucketBy #17077
Changes from all commits
@@ -563,6 +563,63 @@ def partitionBy(self, *cols):
        self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols))
        return self

    @since(2.3)
    def bucketBy(self, numBuckets, col, *cols):
        """Buckets the output by the given columns. If specified,
        the output is laid out on the file system similar to Hive's bucketing scheme.

        :param numBuckets: the number of buckets to save
        :param col: a name of a column, or a list of names.
        :param cols: additional names (optional). If `col` is a list it should be empty.

        .. note:: Applicable for file-based data sources in combination with
                  :py:meth:`DataFrameWriter.saveAsTable`.
[Review comment] This is not accurate. We also can use `…`
[Reply] @gatorsmile Can we?
[Reply] uh. Yes. Bucket info is not part of the file/directory names, unlike partitioning info.
        >>> (df.write.format('parquet')
        ...     .bucketBy(100, 'year', 'month')
        ...     .mode("overwrite")
        ...     .saveAsTable('bucketed_table'))
        """
        if not isinstance(numBuckets, int):
            raise TypeError("numBuckets should be an int, got {0}.".format(type(numBuckets)))

        if isinstance(col, (list, tuple)):
            if cols:
                raise ValueError("col is a {0} but cols are not empty".format(type(col)))

            col, cols = col[0], col[1:]

        if not all(isinstance(c, basestring) for c in cols) or not(isinstance(col, basestring)):
            raise TypeError("all names should be `str`")

        self._jwrite = self._jwrite.bucketBy(numBuckets, col, _to_seq(self._spark._sc, cols))
        return self
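For context, here is a minimal end-to-end sketch of the new wrapper, assuming an active SparkSession named `spark`; the table name `example_bucketed` and the column names are made up for illustration. As in the tests further down, the bucketing spec ends up in the session catalog and can be read back through `spark.catalog.listColumns`:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Hypothetical example data; column names are assumptions for illustration.
    df = spark.createDataFrame(
        [(2012, 1, 10.0), (2012, 2, 20.0), (2013, 1, 30.0)],
        ["year", "month", "value"])

    # Write a bucketed table; bucketBy takes effect in combination with saveAsTable.
    (df.write.format("parquet")
        .bucketBy(4, "year", "month")
        .mode("overwrite")
        .saveAsTable("example_bucketed"))

    # The bucketing spec lives in the catalog, so it can be checked there.
    bucket_cols = [c.name for c in spark.catalog.listColumns("example_bucketed")
                   if c.isBucket]
    print(bucket_cols)  # ['year', 'month']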
    @since(2.3)
    def sortBy(self, col, *cols):
        """Sorts the output in each bucket by the given columns on the file system.

        :param col: a name of a column, or a list of names.
        :param cols: additional names (optional). If `col` is a list it should be empty.

        >>> (df.write.format('parquet')
        ...     .bucketBy(100, 'year', 'month')
        ...     .sortBy('day')
        ...     .mode("overwrite")
        ...     .saveAsTable('sorted_bucketed_table'))
        """
        if isinstance(col, (list, tuple)):
            if cols:
                raise ValueError("col is a {0} but cols are not empty".format(type(col)))

            col, cols = col[0], col[1:]

        if not all(isinstance(c, basestring) for c in cols) or not(isinstance(col, basestring)):
            raise TypeError("all names should be `str`")

        self._jwrite = self._jwrite.sortBy(col, _to_seq(self._spark._sc, cols))
        return self
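Both wrappers share the same argument handling: the columns can be given either as varargs or as a single list/tuple, but not both at once, and invalid input fails fast on the Python side before the JVM writer is touched. A short sketch of the behavior implemented above, reusing the hypothetical `df` from the earlier example:

    # Equivalent ways to name the bucketing columns:
    w1 = df.write.bucketBy(4, "year", "month")    # varargs form
    w2 = df.write.bucketBy(4, ["year", "month"])  # single-list form

    # numBuckets must be an int, otherwise a TypeError is raised up front.
    try:
        df.write.bucketBy("4", "year")
    except TypeError as e:
        print(e)  # e.g. "numBuckets should be an int, got <class 'str'>."

    # Mixing a list with additional names raises a ValueError.
    try:
        df.write.bucketBy(4, ["year"], "month")
    except ValueError as e:
        print(e)  # e.g. "col is a <class 'list'> but cols are not empty"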
    @since(1.4)
    def save(self, path=None, format=None, mode=None, partitionBy=None, **options):
        """Saves the contents of the :class:`DataFrame` to a data source.
@@ -211,6 +211,12 @@ def test_sqlcontext_reuses_sparksession(self):
        sqlContext2 = SQLContext(self.sc)
        self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession)

    def tearDown(self):
        super(SQLTests, self).tearDown()

        # tear down test_bucketed_write state
        self.spark.sql("DROP TABLE IF EXISTS pyspark_bucket")

    def test_row_should_be_read_only(self):
        row = Row(a=1, b=2)
        self.assertEqual(1, row.a)
@@ -2187,6 +2193,54 @@ def test_BinaryType_serialization(self):
        df = self.spark.createDataFrame(data, schema=schema)
        df.collect()

    def test_bucketed_write(self):
        data = [
            (1, "foo", 3.0), (2, "foo", 5.0),
            (3, "bar", -1.0), (4, "bar", 6.0),
        ]
        df = self.spark.createDataFrame(data, ["x", "y", "z"])

        def count_bucketed_cols(names, table="pyspark_bucket"):
            """Given a sequence of column names and a table name,
            query the catalog and return the number of columns which are
            used for bucketing
            """
            cols = self.spark.catalog.listColumns(table)
            num = len([c for c in cols if c.name in names and c.isBucket])
            return num

        # Test write with one bucketing column
        df.write.bucketBy(3, "x").mode("overwrite").saveAsTable("pyspark_bucket")
        self.assertEqual(count_bucketed_cols(["x"]), 1)
        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

        # Test write with two bucketing columns
        df.write.bucketBy(3, "x", "y").mode("overwrite").saveAsTable("pyspark_bucket")
        self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

        # Test write with bucket and sort
        df.write.bucketBy(2, "x").sortBy("z").mode("overwrite").saveAsTable("pyspark_bucket")
        self.assertEqual(count_bucketed_cols(["x"]), 1)
        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

        # Test write with a list of columns
        df.write.bucketBy(3, ["x", "y"]).mode("overwrite").saveAsTable("pyspark_bucket")
        self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

        # Test write with bucket and sort with a list of columns
        (df.write.bucketBy(2, "x")
            .sortBy(["y", "z"])
            .mode("overwrite").saveAsTable("pyspark_bucket"))
        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

        # Test write with bucket and sort with multiple columns
        (df.write.bucketBy(2, "x")
            .sortBy("y", "z")
            .mode("overwrite").saveAsTable("pyspark_bucket"))
[Review comment] @zero323, should we drop the table before or after this test?
[Reply] I don't think that dropping before is necessary. We overwrite on each write and name clashes are unlikely. We can drop it after the tests but I am not sure how to do it right.
        self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))
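For completeness, the stored bucketing spec can also be inspected through SQL rather than `catalog.listColumns`. A hedged sketch: the `Num Buckets`/`Bucket Columns`/`Sort Columns` row labels below match the formatted `DESCRIBE` output of recent Spark versions, but the exact layout is version-dependent:

    # Assumes a bucketed table such as the one written by test_bucketed_write.
    for row in spark.sql("DESCRIBE EXTENDED pyspark_bucket").collect():
        if row.col_name in ("Num Buckets", "Bucket Columns", "Sort Columns"):
            print(row.col_name, "->", row.data_type)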

class HiveSparkSubmitTests(SparkSubmitTests):
[Review comment] Nit: `columns.If` -> `columns. If`