[SPARK-34026][SQL] Inject repartition and sort nodes to satisfy required distribution and ordering (apache#905)

### What changes were proposed in this pull request?

This PR adds repartition and sort nodes to satisfy the required distribution and ordering introduced in SPARK-33779.

Note: This PR contains the final part of changes discussed in PR apache#29066.
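
As context, the sketch below (not part of this PR) shows the connector-side contract these nodes satisfy: a DSv2 `Write` can mix in `RequiresDistributionAndOrdering` from SPARK-33779 to ask Spark to cluster incoming rows and sort them within each write task before they reach the writer. The class, the column names, the `Expressions.column` factory, and the inline `SortOrder` implementation are illustrative assumptions, not code from this commit.

```scala
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions.{Expression, Expressions, NullOrdering, SortDirection, SortOrder}
import org.apache.spark.sql.connector.write.{BatchWrite, RequiresDistributionAndOrdering, Write}

// Hypothetical connector write: cluster rows by `data` and sort them by `id` within each partition.
class ClusteredSortedWrite(batchWrite: BatchWrite) extends Write with RequiresDistributionAndOrdering {

  // hash-style clustering on the `data` column
  override def requiredDistribution: Distribution =
    Distributions.clustered(Array[Expression](Expressions.column("data")))

  // ascending, nulls-first sort on `id` within each write task
  override def requiredOrdering: Array[SortOrder] = Array(
    new SortOrder {
      override def expression(): Expression = Expressions.column("id")
      override def direction(): SortDirection = SortDirection.ASCENDING
      override def nullOrdering(): NullOrdering = NullOrdering.NULLS_FIRST
      override def describe(): String = "id ASC NULLS FIRST"
    })

  override def toBatch: BatchWrite = batchWrite
}
```

With this change, `V2Writes` passes such a write to `DistributionAndOrderingUtils.prepareQuery`, which puts a `RepartitionByExpression` and a non-global `Sort` on top of the query being written.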

### Why are the changes needed?

These changes are the next step as discussed in the [design doc](https://docs.google.com/document/d/1X0NsQSryvNmXBY9kcvfINeYyKC-AahZarUqg3nS1GQs/edit#) for SPARK-23889.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

This PR comes with a new test suite.

Closes apache#31083 from aokolnychyi/spark-34026.

Authored-by: Anton Okolnychyi <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
aokolnychyi authored and dongjoon-hyun committed Mar 1, 2021
1 parent 3f9a3bf commit bb235ce
Showing 5 changed files with 708 additions and 8 deletions.
@@ -30,7 +30,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow}
 import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils}
 import org.apache.spark.sql.connector.catalog._
-import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, Transform, YearsTransform}
+import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
+import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, SortOrder, Transform, YearsTransform}
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.connector.write._
 import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
@@ -46,7 +47,9 @@ class InMemoryTable(
     val name: String,
     val schema: StructType,
     override val partitioning: Array[Transform],
-    override val properties: util.Map[String, String])
+    override val properties: util.Map[String, String],
+    val distribution: Distribution = Distributions.unspecified(),
+    val ordering: Array[SortOrder] = Array.empty)
   extends Table with SupportsRead with SupportsWrite with SupportsDelete
   with SupportsMetadataColumns {
 
@@ -262,7 +265,11 @@
         this
       }
 
-      override def build(): Write = new Write {
+      override def build(): Write = new Write with RequiresDistributionAndOrdering {
+        override def requiredDistribution: Distribution = distribution
+
+        override def requiredOrdering: Array[SortOrder] = ordering
+
         override def toBatch: BatchWrite = writer
 
         override def toStreaming: StreamingWrite = streamingWriter match {
@@ -24,7 +24,8 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
 import org.apache.spark.sql.connector.catalog._
-import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
+import org.apache.spark.sql.connector.expressions.{SortOrder, Transform}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
@@ -69,13 +70,24 @@ class BasicInMemoryTableCatalog extends TableCatalog {
       schema: StructType,
       partitions: Array[Transform],
       properties: util.Map[String, String]): Table = {
+    createTable(ident, schema, partitions, properties, Distributions.unspecified(), Array.empty)
+  }
+
+  def createTable(
+      ident: Identifier,
+      schema: StructType,
+      partitions: Array[Transform],
+      properties: util.Map[String, String],
+      distribution: Distribution,
+      ordering: Array[SortOrder]): Table = {
     if (tables.containsKey(ident)) {
       throw new TableAlreadyExistsException(ident)
     }
 
     InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties)
 
-    val table = new InMemoryTable(s"$name.${ident.quoted}", schema, partitions, properties)
+    val tableName = s"$name.${ident.quoted}"
+    val table = new InMemoryTable(tableName, schema, partitions, properties, distribution, ordering)
     tables.put(ident, table)
     namespaces.putIfAbsent(ident.namespace.toList, Map())
     table
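
As a rough usage sketch (not part of this diff), the new overload lets a test register a table whose writes must be ordered, which `prepareQuery` later satisfies with a range repartition plus an in-task sort. The catalog name, namespace, columns, the `Expressions.column` factory, and the inline `SortOrder` are assumptions for illustration, and the snippet is assumed to live next to the test catalog above so that `BasicInMemoryTableCatalog` is in scope.

```scala
import java.util.Collections

import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.distributions.Distributions
import org.apache.spark.sql.connector.expressions.{Expression, Expressions, NullOrdering, SortDirection, SortOrder, Transform}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

val catalog = new BasicInMemoryTableCatalog
catalog.initialize("testcat", CaseInsensitiveStringMap.empty())

// writes to this table must be ordered by `id` (ascending, nulls first)
val ordering: Array[SortOrder] = Array(new SortOrder {
  override def expression(): Expression = Expressions.column("id")
  override def direction(): SortDirection = SortDirection.ASCENDING
  override def nullOrdering(): NullOrdering = NullOrdering.NULLS_FIRST
  override def describe(): String = "id ASC NULLS FIRST"
})

catalog.createTable(
  Identifier.of(Array("ns"), "t"),
  new StructType().add("id", "long").add("data", "string"),
  Array.empty[Transform],
  Collections.emptyMap[String, String](),
  Distributions.ordered(ordering),
  ordering)
```

Because the distribution here is an `OrderedDistribution`, the injected `RepartitionByExpression` receives catalyst `SortOrder` expressions and therefore range-partitions the data, as noted in the comment inside `DistributionAndOrderingUtils` below.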
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.Resolver
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, NamedExpression, NullOrdering, NullsFirst, NullsLast, SortDirection, SortOrder}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RepartitionByExpression, Sort}
+import org.apache.spark.sql.connector.distributions.{ClusteredDistribution, OrderedDistribution, UnspecifiedDistribution}
+import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, IdentityTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortValue}
+import org.apache.spark.sql.connector.write.{RequiresDistributionAndOrdering, Write}
+import org.apache.spark.sql.internal.SQLConf
+
+object DistributionAndOrderingUtils {
+
+  def prepareQuery(write: Write, query: LogicalPlan, conf: SQLConf): LogicalPlan = write match {
+    case write: RequiresDistributionAndOrdering =>
+      val resolver = conf.resolver
+
+      val distribution = write.requiredDistribution match {
+        case d: OrderedDistribution =>
+          d.ordering.map(e => toCatalyst(e, query, resolver))
+        case d: ClusteredDistribution =>
+          d.clustering.map(e => toCatalyst(e, query, resolver))
+        case _: UnspecifiedDistribution =>
+          Array.empty[Expression]
+      }
+
+      val queryWithDistribution = if (distribution.nonEmpty) {
+        val numShufflePartitions = conf.numShufflePartitions
+        // the conversion to catalyst expressions above produces SortOrder expressions
+        // for OrderedDistribution and generic expressions for ClusteredDistribution
+        // this allows RepartitionByExpression to pick either range or hash partitioning
+        RepartitionByExpression(distribution, query, numShufflePartitions)
+      } else {
+        query
+      }
+
+      val ordering = write.requiredOrdering.toSeq
+        .map(e => toCatalyst(e, query, resolver))
+        .asInstanceOf[Seq[SortOrder]]
+
+      val queryWithDistributionAndOrdering = if (ordering.nonEmpty) {
+        Sort(ordering, global = false, queryWithDistribution)
+      } else {
+        queryWithDistribution
+      }
+
+      queryWithDistributionAndOrdering
+
+    case _ =>
+      query
+  }
+
+  private def toCatalyst(
+      expr: V2Expression,
+      query: LogicalPlan,
+      resolver: Resolver): Expression = {
+
+    // we cannot perform the resolution in the analyzer since we need to optimize expressions
+    // in nodes like OverwriteByExpression before constructing a logical write
+    def resolve(ref: FieldReference): NamedExpression = {
+      query.resolve(ref.parts, resolver) match {
+        case Some(attr) => attr
+        case None => throw new AnalysisException(s"Cannot resolve '$ref' using ${query.output}")
+      }
+    }
+
+    expr match {
+      case SortValue(child, direction, nullOrdering) =>
+        val catalystChild = toCatalyst(child, query, resolver)
+        SortOrder(catalystChild, toCatalyst(direction), toCatalyst(nullOrdering), Seq.empty)
+      case IdentityTransform(ref) =>
+        resolve(ref)
+      case ref: FieldReference =>
+        resolve(ref)
+      case _ =>
+        throw new AnalysisException(s"$expr is not currently supported")
+    }
+  }
+
+  private def toCatalyst(direction: V2SortDirection): SortDirection = direction match {
+    case V2SortDirection.ASCENDING => Ascending
+    case V2SortDirection.DESCENDING => Descending
+  }
+
+  private def toCatalyst(nullOrdering: V2NullOrdering): NullOrdering = nullOrdering match {
+    case V2NullOrdering.NULLS_FIRST => NullsFirst
+    case V2NullOrdering.NULLS_LAST => NullsLast
+  }
+}
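
To make the rewrite concrete, here is a small sketch (not from the commit) of the plan shape `prepareQuery` builds for a write that requires clustering by `data` and an in-task sort by `id`; the `LocalRelation` stands in for the logical plan feeding the write, and `200` stands in for `spark.sql.shuffle.partitions`.

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Ascending, NullsFirst, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, RepartitionByExpression, Sort}

// stand-in for the logical plan being written
val query = LocalRelation('id.long, 'data.string)
val Seq(id, data) = query.output

// step 1: shuffle so that rows with the same `data` value land in the same partition
val repartitioned = RepartitionByExpression(Seq(data), query, 200)

// step 2: sort each partition by `id` (global = false, so no extra range shuffle)
val prepared = Sort(Seq(SortOrder(id, Ascending, NullsFirst, Seq.empty)), global = false, repartitioned)
```

For an `OrderedDistribution`, the partition expressions would themselves be catalyst `SortOrder`s, so the same `RepartitionByExpression` node would range-partition the data instead of hash-partitioning it.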
@@ -40,7 +40,8 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
     case a @ AppendData(r: DataSourceV2Relation, query, options, _, None) =>
       val writeBuilder = newWriteBuilder(r.table, query, options)
       val write = writeBuilder.build()
-      a.copy(write = Some(write))
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, conf)
+      a.copy(write = Some(write), query = newQuery)
 
     case o @ OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, options, _, None) =>
       // fail if any filter cannot be converted. correctness depends on removing all matching data.
@@ -63,7 +64,8 @@
           throw new SparkException(s"Table does not support overwrite by expression: $table")
       }
 
-      o.copy(write = Some(write))
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, conf)
+      o.copy(write = Some(write), query = newQuery)
 
     case o @ OverwritePartitionsDynamic(r: DataSourceV2Relation, query, options, _, None) =>
       val table = r.table
@@ -74,7 +76,8 @@
         case _ =>
           throw new SparkException(s"Table does not support dynamic partition overwrite: $table")
       }
-      o.copy(write = Some(write))
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, conf)
+      o.copy(write = Some(write), query = newQuery)
   }
 
   private def isTruncate(filters: Array[Filter]): Boolean = {