[SPARK-48761][SQL] Introduce clusterBy DataFrameWriter API for Scala
### What changes were proposed in this pull request?

Introduce a new `clusterBy` DataFrame API in Scala. This PR adds the API to both DataFrameWriter V1 and V2, as well as to Spark Connect.

### Why are the changes needed?

Introduce more ways for users to interact with clustered tables.

### Does this PR introduce _any_ user-facing change?

Yes, it adds a new `clusterBy` DataFrame API in Scala to allow specifying the clustering columns when writing DataFrames.
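
A minimal sketch of how the new API is expected to be used; the `events`/`events_v2` table names and the `date` column below are illustrative placeholders, not part of this PR:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()
// Toy DataFrame with a "date" column, used only for illustration.
val df = spark.range(100).withColumnRenamed("id", "date")

// DataFrameWriter (V1): cluster the saved table by "date".
df.write
  .format("parquet")
  .clusterBy("date")
  .saveAsTable("events")

// DataFrameWriterV2 / CreateTableWriter: the same capability via writeTo.
df.writeTo("events_v2")
  .using("parquet")
  .clusterBy("date")
  .create()
```

With Spark Connect, the same calls are carried in the new clustering-columns field of the write protos and forwarded by the planner, as the server-side changes below show.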

### How was this patch tested?

New unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #47301 from zedtang/clusterby-scala-api.

Authored-by: Jiaheng Tang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
zedtang authored and cloud-fan committed Jul 25, 2024
1 parent 8c625ea commit bafce5d
Showing 18 changed files with 482 additions and 11 deletions.
14 changes: 14 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -471,6 +471,20 @@
],
"sqlState" : "0A000"
},
"CLUSTERING_COLUMNS_MISMATCH" : {
"message" : [
"Specified clustering does not match that of the existing table <tableName>.",
"Specified clustering columns: [<specifiedClusteringString>].",
"Existing clustering columns: [<existingClusteringString>]."
],
"sqlState" : "42P10"
},
"CLUSTERING_NOT_SUPPORTED" : {
"message" : [
"'<operation>' does not support clustering."
],
"sqlState" : "42000"
},
"CODEC_NOT_AVAILABLE" : {
"message" : [
"The codec <codecName> is not available."
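
A hedged sketch of the situation the new `CLUSTERING_COLUMNS_MISMATCH` condition describes, assuming a throwaway table `t` and columns `a`/`b` (all placeholders):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder().getOrCreate()
val df = spark.range(10).toDF("a").withColumn("b", col("a") + 1)

// Create a table clustered by "a".
df.write.clusterBy("a").saveAsTable("t")

// Specifying different clustering columns against the existing table is
// expected to surface CLUSTERING_COLUMNS_MISMATCH (sqlState 42P10).
df.write.mode("append").clusterBy("b").saveAsTable("t")
```
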
@@ -3090,6 +3090,11 @@ class SparkConnectPlanner(
w.partitionBy(names.toSeq: _*)
}

if (writeOperation.getClusteringColumnsCount > 0) {
val names = writeOperation.getClusteringColumnsList.asScala
w.clusterBy(names.head, names.tail.toSeq: _*)
}

if (writeOperation.hasSource) {
w.format(writeOperation.getSource)
}
@@ -3153,6 +3158,11 @@ class SparkConnectPlanner(
w.partitionedBy(names.head, names.tail: _*)
}

if (writeOperation.getClusteringColumnsCount > 0) {
val names = writeOperation.getClusteringColumnsList.asScala
w.clusterBy(names.head, names.tail.toSeq: _*)
}

writeOperation.getMode match {
case proto.WriteOperationV2.Mode.MODE_CREATE =>
if (writeOperation.hasProvider) {
@@ -219,6 +219,7 @@ package object dsl {
mode: Option[String] = None,
sortByColumns: Seq[String] = Seq.empty,
partitionByCols: Seq[String] = Seq.empty,
clusterByCols: Seq[String] = Seq.empty,
bucketByCols: Seq[String] = Seq.empty,
numBuckets: Option[Int] = None): Command = {
val writeOp = WriteOperation.newBuilder()
@@ -242,6 +243,7 @@
}
sortByColumns.foreach(writeOp.addSortColumnNames(_))
partitionByCols.foreach(writeOp.addPartitioningColumns(_))
clusterByCols.foreach(writeOp.addClusteringColumns(_))

if (numBuckets.nonEmpty && bucketByCols.nonEmpty) {
val op = WriteOperation.BucketBy.newBuilder()
@@ -272,13 +274,15 @@
options: Map[String, String] = Map.empty,
tableProperties: Map[String, String] = Map.empty,
partitionByCols: Seq[Expression] = Seq.empty,
clusterByCols: Seq[String] = Seq.empty,
mode: Option[String] = None,
overwriteCondition: Option[Expression] = None): Command = {
val writeOp = WriteOperationV2.newBuilder()
writeOp.setInput(logicalPlan)
tableName.foreach(writeOp.setTableName)
provider.foreach(writeOp.setProvider)
partitionByCols.foreach(writeOp.addPartitioningColumns)
clusterByCols.foreach(writeOp.addClusteringColumns)
options.foreach { case (k, v) =>
writeOp.putOptions(k, v)
}
@@ -596,6 +596,48 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
}
}

test("Write with clustering") {
// Cluster by existing column.
withTable("testtable") {
transform(
localRelation.write(
tableName = Some("testtable"),
tableSaveMethod = Some("save_as_table"),
format = Some("parquet"),
clusterByCols = Seq("id")))
}

// Cluster by non-existing column.
assertThrows[AnalysisException](
transform(
localRelation
.write(
tableName = Some("testtable"),
tableSaveMethod = Some("save_as_table"),
format = Some("parquet"),
clusterByCols = Seq("noid"))))
}

test("Write V2 with clustering") {
// Cluster by existing column.
withTable("testtable") {
transform(
localRelation.writeV2(
tableName = Some("testtable"),
mode = Some("MODE_CREATE"),
clusterByCols = Seq("id")))
}

// Cluster by non-existing column.
assertThrows[AnalysisException](
transform(
localRelation
.writeV2(
tableName = Some("testtable"),
mode = Some("MODE_CREATE"),
clusterByCols = Seq("noid"))))
}

test("Write with invalid bucketBy configuration") {
val cmd = localRelation.write(bucketByCols = Seq("id"), numBuckets = Some(0))
assertThrows[InvalidCommandInput] {
@@ -201,6 +201,22 @@ final class DataFrameWriter[T] private[sql] (ds: Dataset[T]) {
this
}

/**
* Clusters the output by the given columns on the storage. The rows with matching values in the
* specified clustering columns will be consolidated within the same group.
*
* For instance, if you cluster a dataset by date, the data sharing the same date will be stored
* together in a file. This arrangement improves query efficiency when you apply selective
* filters to these clustering columns, thanks to data skipping.
*
* @since 4.0.0
*/
@scala.annotation.varargs
def clusterBy(colName: String, colNames: String*): DataFrameWriter[T] = {
this.clusteringColumns = Option(colName +: colNames)
this
}

/**
* Saves the content of the `DataFrame` at the specified path.
*
@@ -242,6 +258,7 @@ final class DataFrameWriter[T] private[sql] (ds: Dataset[T]) {
source.foreach(builder.setSource)
sortColumnNames.foreach(names => builder.addAllSortColumnNames(names.asJava))
partitioningColumns.foreach(cols => builder.addAllPartitioningColumns(cols.asJava))
clusteringColumns.foreach(cols => builder.addAllClusteringColumns(cols.asJava))

numBuckets.foreach(n => {
val bucketBuilder = proto.WriteOperation.BucketBy.newBuilder()
@@ -509,4 +526,6 @@ final class DataFrameWriter[T] private[sql] (ds: Dataset[T]) {
private var numBuckets: Option[Int] = None

private var sortColumnNames: Option[Seq[String]] = None

private var clusteringColumns: Option[Seq[String]] = None
}
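
The scaladoc above frames clustering as a data-skipping aid; a hedged read-side sketch, reusing the hypothetical `events` table from the earlier example, of the kind of query that benefits:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder().getOrCreate()

// A selective filter on the clustering column can skip data whose "date"
// values cannot match, since rows sharing a "date" are stored together.
spark.table("events").where(col("date") === 42).show()
```
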
@@ -41,6 +41,8 @@ final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])

private var partitioning: Option[Seq[proto.Expression]] = None

private var clustering: Option[Seq[String]] = None

private var overwriteCondition: Option[proto.Expression] = None

override def using(provider: String): CreateTableWriter[T] = {
@@ -77,6 +79,12 @@ final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])
this
}

@scala.annotation.varargs
override def clusterBy(colName: String, colNames: String*): CreateTableWriter[T] = {
this.clustering = Some(colName +: colNames)
this
}

override def create(): Unit = {
executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE)
}
@@ -133,6 +141,7 @@ final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])
provider.foreach(builder.setProvider)

partitioning.foreach(columns => builder.addAllPartitioningColumns(columns.asJava))
clustering.foreach(columns => builder.addAllClusteringColumns(columns.asJava))

options.foreach { case (k, v) =>
builder.putOptions(k, v)
@@ -252,8 +261,22 @@ trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] {
*
* @since 3.4.0
*/
@scala.annotation.varargs
def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T]

/**
* Clusters the output by the given columns on the storage. The rows with matching values in the
* specified clustering columns will be consolidated within the same group.
*
* For instance, if you cluster a dataset by date, the data sharing the same date will be stored
* together in a file. This arrangement improves query efficiency when you apply selective
* filters to these clustering columns, thanks to data skipping.
*
* @since 4.0.0
*/
@scala.annotation.varargs
def clusterBy(colName: String, colNames: String*): CreateTableWriter[T]

/**
* Specifies a provider for the underlying output data source. Spark's default catalog supports
* "parquet", "json", etc.
@@ -85,6 +85,7 @@ class ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach {
.setNumBuckets(2)
.addBucketColumnNames("col1")
.addBucketColumnNames("col2"))
.addClusteringColumns("col3")

val expectedPlan = proto.Plan
.newBuilder()
@@ -95,6 +96,7 @@ class ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach {
.sortBy("col1")
.partitionBy("col99")
.bucketBy(2, "col1", "col2")
.clusterBy("col3")
.parquet("my/test/path")
val actualPlan = service.getAndClearLatestInputPlan()
assert(actualPlan.equals(expectedPlan))
@@ -136,6 +138,7 @@ class ClientDatasetSuite extends ConnectFunSuite with BeforeAndAfterEach {
.setTableName("t1")
.addPartitioningColumns(col("col99").expr)
.setProvider("json")
.addClusteringColumns("col3")
.putTableProperties("key", "value")
.putOptions("key2", "value2")
.setMode(proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE)
@@ -147,6 +150,7 @@

df.writeTo("t1")
.partitionedBy(col("col99"))
.clusterBy("col3")
.using("json")
.tableProperty("key", "value")
.options(Map("key2" -> "value2"))
4 changes: 3 additions & 1 deletion project/MimaExcludes.scala
@@ -100,7 +100,9 @@ object MimaExcludes {
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLContext#implicits._sqlContext"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits._sqlContext"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.session"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SparkSession#implicits._sqlContext")
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SparkSession#implicits._sqlContext"),
// SPARK-48761: Add clusterBy() to CreateTableWriter.
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.CreateTableWriter.clusterBy")
)

// Default exclude rules
@@ -197,10 +197,22 @@ object ClusterBySpec {
ret
}

/**
* Converts the clustering column property to a ClusterBySpec.
*/
def fromProperty(columns: String): ClusterBySpec = {
ClusterBySpec(mapper.readValue[Seq[Seq[String]]](columns).map(FieldReference(_)))
}

/**
* Converts a ClusterBySpec to a clustering column property map entry, with validation
* of the column names against the schema.
*
* @param schema the schema of the table.
* @param clusterBySpec the ClusterBySpec to be converted to a property.
* @param resolver the resolver used to match the column names.
* @return a map entry for the clustering column property.
*/
def toProperty(
schema: StructType,
clusterBySpec: ClusterBySpec,
@@ -209,10 +221,25 @@
normalizeClusterBySpec(schema, clusterBySpec, resolver).toJson
}

/**
* Converts a ClusterBySpec to a clustering column property map entry, without validating
* the column names against the schema.
*
* @param clusterBySpec existing ClusterBySpec to be converted to properties.
* @return a map entry for the clustering column property.
*/
def toPropertyWithoutValidation(clusterBySpec: ClusterBySpec): (String, String) = {
(CatalogTable.PROP_CLUSTERING_COLUMNS -> clusterBySpec.toJson)
}

private def normalizeClusterBySpec(
schema: StructType,
clusterBySpec: ClusterBySpec,
resolver: Resolver): ClusterBySpec = {
if (schema.isEmpty) {
return clusterBySpec
}

val normalizedColumns = clusterBySpec.columnNames.map { columnName =>
val position = SchemaUtils.findColumnPosition(
columnName.fieldNames().toImmutableArraySeq, schema, resolver)
@@ -239,6 +266,10 @@
val normalizedClusterBySpec = normalizeClusterBySpec(schema, clusterBySpec, resolver)
ClusterByTransform(normalizedClusterBySpec.columnNames)
}

def fromColumnNames(names: Seq[String]): ClusterBySpec = {
ClusterBySpec(names.map(FieldReference(_)))
}
}
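
A brief hedged sketch of the helpers added above, assuming `ClusterBySpec` lives in `org.apache.spark.sql.catalyst.catalog` (the import path is not shown in this diff):

```scala
import org.apache.spark.sql.catalyst.catalog.ClusterBySpec

// Build a spec directly from column names; no schema validation is performed.
val spec = ClusterBySpec.fromColumnNames(Seq("a", "b"))

// Convert to the clustering-columns table property without validation, then
// read it back from the JSON-encoded column paths.
val (propKey, propJson) = ClusterBySpec.toPropertyWithoutValidation(spec)
val restored = ClusterBySpec.fromProperty(propJson)
assert(restored == spec) // expected to round-trip, per the definitions above
```
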

/**
@@ -1866,6 +1866,18 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
"existingBucketString" -> existingBucketString))
}

def mismatchedTableClusteringError(
tableName: String,
specifiedClusteringString: String,
existingClusteringString: String): Throwable = {
new AnalysisException(
errorClass = "CLUSTERING_COLUMNS_MISMATCH",
messageParameters = Map(
"tableName" -> tableName,
"specifiedClusteringString" -> specifiedClusteringString,
"existingClusteringString" -> existingClusteringString))
}

def specifyPartitionNotAllowedWhenTableSchemaNotDefinedError(): Throwable = {
new AnalysisException(
errorClass = "_LEGACY_ERROR_TEMP_1165",
@@ -4100,4 +4112,22 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
messageParameters = Map("functionName" -> functionName)
)
}

def operationNotSupportClusteringError(operation: String): Throwable = {
new AnalysisException(
errorClass = "CLUSTERING_NOT_SUPPORTED",
messageParameters = Map("operation" -> operation))
}

def clusterByWithPartitionedBy(): Throwable = {
new AnalysisException(
errorClass = "SPECIFY_CLUSTER_BY_WITH_PARTITIONED_BY_IS_NOT_ALLOWED",
messageParameters = Map.empty)
}

def clusterByWithBucketing(): Throwable = {
new AnalysisException(
errorClass = "SPECIFY_CLUSTER_BY_WITH_BUCKETING_IS_NOT_ALLOWED",
messageParameters = Map.empty)
}
}
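
The `clusterByWithPartitionedBy` and `clusterByWithBucketing` helpers above suggest that `clusterBy` cannot be combined with `partitionedBy` or `bucketBy`; a hedged sketch of a write that is expected to be rejected (table and column names are placeholders):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder().getOrCreate()
val df = spark.range(10).toDF("date")

// Combining clusterBy with partitionedBy is expected to raise
// SPECIFY_CLUSTER_BY_WITH_PARTITIONED_BY_IS_NOT_ALLOWED.
df.writeTo("events_v2")
  .using("parquet")
  .partitionedBy(col("date"))
  .clusterBy("date")
  .create()
```
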
@@ -194,6 +194,10 @@ abstract class InMemoryBaseTable(
case (v, t) =>
throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
}
case ClusterByTransform(columnNames) =>
columnNames.map { colName =>
extractor(colName.fieldNames, cleanedSchema, row)._1
}
}.toImmutableArraySeq
}
