[SPARK-16931][PYTHON][SQL] Add Python wrapper for bucketBy #17077

Closed
wants to merge 13 commits
57 changes: 57 additions & 0 deletions python/pyspark/sql/readwriter.py
@@ -563,6 +563,63 @@ def partitionBy(self, *cols):
self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols))
return self

@since(2.3)
def bucketBy(self, numBuckets, col, *cols):
"""Buckets the output by the given columns.If specified,
Member: Nit: columns.If -> columns. If

the output is laid out on the file system similar to Hive's bucketing scheme.

:param numBuckets: the number of buckets to save
:param col: a name of a column, or a list of names.
:param cols: additional names (optional). If `col` is a list it should be empty.

.. note:: Applicable for file-based data sources in combination with
:py:meth:`DataFrameWriter.saveAsTable`.
Member: This is not accurate. We can also use save to store bucketed tables without saving their metadata in the metastore.

Member Author: @gatorsmile Can we?

➜  spark git:(master) git rev-parse HEAD   
2cf83c47838115f71419ba5b9296c69ec1d746cd
➜  spark git:(master) bin/spark-shell 
Spark context Web UI available at http://192.168.1.101:4041
Spark context available as 'sc' (master = local[*], app id = local-1494184109262).
Spark session available as 'spark'.
Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /___/ .__/\_,_/_/ /_/\_\   version 2.3.0-SNAPSHOT
      /_/
         
Using Scala version 2.11.8 (OpenJDK 64-Bit Server VM, Java 1.8.0_121)
Type in expressions to have them evaluated.
Type :help for more information.

scala> Seq(("a", 1, 3)).toDF("x", "y", "z").write.bucketBy(3, "x", "y").format("parquet").save("/tmp/foo")
org.apache.spark.sql.AnalysisException: 'save' does not support bucketing right now;
  at org.apache.spark.sql.DataFrameWriter.assertNotBucketed(DataFrameWriter.scala:305)
  at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:231)
  at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:217)
  ... 48 elided


Member: Uh, yes. Bucket info is not part of the file/directory names, unlike partitioning info.
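
For completeness, the same check surfaces through the PySpark wrapper added in this PR. This is a minimal sketch, not part of the diff; it assumes Spark 2.3-era behavior where save() rejects bucketed writes, and the /tmp/foo path is just an example:

from pyspark.sql import SparkSession
from pyspark.sql.utils import AnalysisException

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("a", 1, 3)], ["x", "y", "z"])

try:
    # save() rejects bucketed writes, matching the Scala transcript above
    df.write.bucketBy(3, "x", "y").format("parquet").save("/tmp/foo")
except AnalysisException as e:
    print(e)  # 'save' does not support bucketing right now;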


>>> (df.write.format('parquet')
... .bucketBy(100, 'year', 'month')
... .mode("overwrite")
... .saveAsTable('bucketed_table'))
"""
if not isinstance(numBuckets, int):
raise TypeError("numBuckets should be an int, got {0}.".format(type(numBuckets)))

if isinstance(col, (list, tuple)):
if cols:
raise ValueError("col is a {0} but cols are not empty".format(type(col)))

col, cols = col[0], col[1:]

if not all(isinstance(c, basestring) for c in cols) or not(isinstance(col, basestring)):
raise TypeError("all names should be `str`")

self._jwrite = self._jwrite.bucketBy(numBuckets, col, _to_seq(self._spark._sc, cols))
return self

@since(2.3)
def sortBy(self, col, *cols):
"""Sorts the output in each bucket by the given columns on the file system.

:param col: a name of a column, or a list of names.
:param cols: additional names (optional). If `col` is a list it should be empty.

>>> (df.write.format('parquet')
... .bucketBy(100, 'year', 'month')
... .sortBy('day')
... .mode("overwrite")
... .saveAsTable('sorted_bucketed_table'))
"""
if isinstance(col, (list, tuple)):
if cols:
raise ValueError("col is a {0} but cols are not empty".format(type(col)))

col, cols = col[0], col[1:]

if not all(isinstance(c, basestring) for c in cols) or not(isinstance(col, basestring)):
raise TypeError("all names should be `str`")

self._jwrite = self._jwrite.sortBy(col, _to_seq(self._spark._sc, cols))
return self

@since(1.4)
def save(self, path=None, format=None, mode=None, partitionBy=None, **options):
"""Saves the contents of the :class:`DataFrame` to a data source.
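Taken together, the bucketBy and sortBy wrappers above are meant to be used with saveAsTable. The following is a minimal usage sketch, not part of the diff; the table name bucketed_example is hypothetical, and an active SparkSession is assumed:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1, "foo", 3.0), (2, "bar", 5.0)], ["x", "y", "z"])

# Column names can be passed as varargs ...
(df.write.format("parquet")
    .bucketBy(4, "x", "y")
    .sortBy("z")
    .mode("overwrite")
    .saveAsTable("bucketed_example"))

# ... or as a single list, mirroring partitionBy
(df.write.format("parquet")
    .bucketBy(4, ["x", "y"])
    .sortBy(["z"])
    .mode("overwrite")
    .saveAsTable("bucketed_example"))

Per the checks in the implementation, a non-int numBuckets raises TypeError, and passing a list together with extra varargs raises ValueError.
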
54 changes: 54 additions & 0 deletions python/pyspark/sql/tests.py
@@ -211,6 +211,12 @@ def test_sqlcontext_reuses_sparksession(self):
sqlContext2 = SQLContext(self.sc)
self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession)

def tearDown(self):
super(SQLTests, self).tearDown()

# tear down test_bucketed_write state
self.spark.sql("DROP TABLE IF EXISTS pyspark_bucket")

def test_row_should_be_read_only(self):
row = Row(a=1, b=2)
self.assertEqual(1, row.a)
@@ -2187,6 +2193,54 @@ def test_BinaryType_serialization(self):
df = self.spark.createDataFrame(data, schema=schema)
df.collect()

def test_bucketed_write(self):
data = [
(1, "foo", 3.0), (2, "foo", 5.0),
(3, "bar", -1.0), (4, "bar", 6.0),
]
df = self.spark.createDataFrame(data, ["x", "y", "z"])

def count_bucketed_cols(names, table="pyspark_bucket"):
"""Given a sequence of column names and a table name
query the catalog and return number o columns which are
used for bucketing
"""
cols = self.spark.catalog.listColumns(table)
num = len([c for c in cols if c.name in names and c.isBucket])
return num

# Test write with one bucketing column
df.write.bucketBy(3, "x").mode("overwrite").saveAsTable("pyspark_bucket")
self.assertEqual(count_bucketed_cols(["x"]), 1)
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

# Test write two bucketing columns
df.write.bucketBy(3, "x", "y").mode("overwrite").saveAsTable("pyspark_bucket")
self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

# Test write with bucket and sort
df.write.bucketBy(2, "x").sortBy("z").mode("overwrite").saveAsTable("pyspark_bucket")
self.assertEqual(count_bucketed_cols(["x"]), 1)
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

# Test write with a list of columns
df.write.bucketBy(3, ["x", "y"]).mode("overwrite").saveAsTable("pyspark_bucket")
self.assertEqual(count_bucketed_cols(["x", "y"]), 2)
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

# Test write with bucket and sort with a list of columns
(df.write.bucketBy(2, "x")
.sortBy(["y", "z"])
.mode("overwrite").saveAsTable("pyspark_bucket"))
self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))

# Test write with bucket and sort with multiple columns
(df.write.bucketBy(2, "x")
.sortBy("y", "z")
.mode("overwrite").saveAsTable("pyspark_bucket"))
Member: @zero323, should we drop the table before or after this test?

Member Author: I don't think that dropping the table beforehand is necessary. We overwrite on each write, and name clashes are unlikely.

We could drop it after the tests, but I am not sure how to do that cleanly. SQLTests is already overgrown, and I am not sure we should add a tearDown just for this, yet putting a DROP TABLE in the test itself doesn't look right either.

self.assertSetEqual(set(data), set(self.spark.table("pyspark_bucket").collect()))


class HiveSparkSubmitTests(SparkSubmitTests):

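
As a follow-up to test_bucketed_write above, the bucket metadata written by saveAsTable can be inspected directly through the catalog. This is a sketch assuming the pyspark_bucket table created by the test still exists in the session:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# catalog.listColumns exposes an isBucket flag per column,
# which is what count_bucketed_cols relies on in the test
for c in spark.catalog.listColumns("pyspark_bucket"):
    print(c.name, c.isBucket)

# The bucket spec also shows up in the table's extended description
spark.sql("DESCRIBE EXTENDED pyspark_bucket").show(truncate=False)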