Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
allisonwang-db committed Oct 26, 2020
1 parent 2e5bc2c commit 59f9cd4
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1054,7 +1054,7 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
object EliminateSorts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally

val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
private val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) =>
val newOrders = orders.filterNot(_.child.foldable)
if (newOrders.isEmpty) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ import org.apache.spark.sql.internal.SQLConf

/**
* Remove redundant SortExec node from the spark plan. A sort node is redundant when
* its child satisfies both its sort orders and its required child distribution.
* its child satisfies both its sort orders and its required child distribution. Note
* this rule differs from the Optimizer rule EliminateSorts in that this rule also checks
* if the child satisfies the required distribution so that it is safe to remove not only a
* local sort but also a global sort when its child already satisfies required sort orders.
*/
case class RemoveRedundantSorts(conf: SQLConf) extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ abstract class RemoveRedundantSortsSuiteBase
withTempView("t1", "t2") {
spark.range(1000).select('id as "key").createOrReplaceTempView("t1")
spark.range(1000).select('id as "key").createOrReplaceTempView("t2")

val queryTemplate = """
|SELECT /*+ BROADCAST(%s) */ t1.key FROM
| (SELECT key FROM t1 WHERE key > 10 ORDER BY key DESC LIMIT 10) t1
Expand All @@ -74,17 +74,25 @@ abstract class RemoveRedundantSortsSuiteBase
|ORDER BY %s
""".stripMargin

val innerJoinAsc = queryTemplate.format("t1", "t2.key ASC")
checkSorts(innerJoinAsc, 1, 1)

val innerJoinDesc = queryTemplate.format("t1", "t2.key DESC")
checkSorts(innerJoinDesc, 0, 1)

val innerJoinDesc1 = queryTemplate.format("t1", "t1.key DESC")
checkSorts(innerJoinDesc1, 1, 1)

val leftOuterJoinDesc = queryTemplate.format("t2", "t1.key DESC")
checkSorts(leftOuterJoinDesc, 0, 1)
// No sort should be removed since the stream side (t2) order DESC
// does not satisfy the required sort order ASC.
val buildLeftOrderByRightAsc = queryTemplate.format("t1", "t2.key ASC")
checkSorts(buildLeftOrderByRightAsc, 1, 1)

// The top sort node should be removed since the stream side (t2) order DESC already
// satisfies the required sort order DESC.
val buildLeftOrderByRightDesc = queryTemplate.format("t1", "t2.key DESC")
checkSorts(buildLeftOrderByRightDesc, 0, 1)

// No sort should be removed since the sort ordering from broadcast-hash join is based
// on the stream side (t2) and the required sort order is from t1.
val buildLeftOrderByLeftDesc = queryTemplate.format("t1", "t1.key DESC")
checkSorts(buildLeftOrderByLeftDesc, 1, 1)

// The top sort node should be removed since the stream side (t1) order DESC already
// satisfies the required sort order DESC.
val buildRightOrderByLeftDesc = queryTemplate.format("t2", "t1.key DESC")
checkSorts(buildRightOrderByLeftDesc, 0, 1)
}
}

Expand All @@ -104,7 +112,8 @@ abstract class RemoveRedundantSortsSuiteBase
val queryAsc = query + " ASC"
checkSorts(queryAsc, 2, 3)

// Top level sort should only be eliminated if it's order is descending with SMJ.
// The top level sort should not be removed since the child output ordering is ASC and
// the required ordering is DESC.
val queryDesc = query + " DESC"
checkSorts(queryDesc, 3, 3)
}
Expand Down

0 comments on commit 59f9cd4

Please sign in to comment.