Commit bcbd9fa
add physical rule
allisonwang-db committed Oct 22, 2020
1 parent 2966802 commit bcbd9fa
Showing 7 changed files with 193 additions and 2 deletions.
Optimizer.scala
@@ -1052,6 +1052,8 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
* function is order irrelevant
*/
object EliminateSorts extends Rule[LogicalPlan] {
// transformUp is needed here to ensure idempotency of this rule when removing consecutive
// local sorts.
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) =>
val newOrders = orders.filterNot(_.child.foldable)
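To see why transformUp matters here, consider stacked local sorts whose keys are all foldable, which is exactly the case the rule above rewrites to its child. A hypothetical sketch using the catalyst test DSL (the toy plan is invented; EliminateSorts and the DSL helpers are real):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val relation = LocalRelation('a.int, 'b.int)
// Two stacked local sorts on a foldable key; the rule rewrites each Sort to its child.
val stacked = relation.sortBy(SortOrder(Literal(1), Ascending))
  .sortBy(SortOrder(Literal(1), Ascending))
// With transformUp the inner Sort is removed before the outer one is inspected, so
// EliminateSorts(stacked.analyze) reaches the fixed point in a single pass. A
// transformDown traversal would splice the inner Sort into the outer position and
// never revisit it, leaving one redundant Sort behind for a second run.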
SQLConf.scala
@@ -1242,6 +1242,13 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val REMOVE_REDUNDANT_SORTS_ENABLED = buildConf("spark.sql.execution.removeRedundantSorts")
.internal()
.doc("Whether to remove redundant physical sort node")
.version("3.1.0")
.booleanConf
.createWithDefault(true)
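For reference, the flag can be flipped per session like any other SQL conf even though it is internal; a hypothetical usage sketch (table t and column key are invented):

spark.conf.set("spark.sql.execution.removeRedundantSorts", "false")  // keep every SortExec
spark.sql("SELECT * FROM t ORDER BY key DESC").explain()
spark.conf.set("spark.sql.execution.removeRedundantSorts", "true")   // back to the default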

val STATE_STORE_PROVIDER_CLASS =
buildConf("spark.sql.streaming.stateStore.providerClass")
.internal()
EliminateSortsSuite.scala
@@ -128,15 +128,15 @@ class EliminateSortsSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("SPARK-33183: filters should not affect order for local sort") {
test("SPARK-33183: remove top level local sort with filter operators") {
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
val filteredAndReordered = orderedPlan.where('a > Literal(10)).sortBy('a.asc, 'b.desc)
val optimized = Optimize.execute(filteredAndReordered.analyze)
val correctAnswer = orderedPlan.where('a > Literal(10)).analyze
comparePlans(optimized, correctAnswer)
}

test("SPARK-33183: should not remove global sort with filter operators") {
test("SPARK-33183: keep top level global sort with filter operators") {
val projectPlan = testRelation.select('a, 'b)
val orderedPlan = projectPlan.orderBy('a.asc, 'b.desc)
val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc)
QueryExecution.scala
@@ -343,6 +343,7 @@ object QueryExecution {
PlanDynamicPruningFilters(sparkSession),
PlanSubqueries(sparkSession),
RemoveRedundantProjects(sparkSession.sessionState.conf),
RemoveRedundantSorts(sparkSession.sessionState.conf),
EnsureRequirements(sparkSession.sessionState.conf),
DisableUnnecessaryBucketedScan(sparkSession.sessionState.conf),
ApplyColumnarRulesAndInsertTransitions(sparkSession.sessionState.conf,
RemoveRedundantSorts.scala (new file)
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf

/**
* Remove redundant SortExec nodes from the Spark plan. A sort node is redundant when
* its child already satisfies both the sort node's sort orders and its required child
* distribution.
*/
case class RemoveRedundantSorts(conf: SQLConf) extends Rule[SparkPlan] {
def apply(plan: SparkPlan): SparkPlan = {
if (!conf.getConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED)) {
plan
} else {
removeSorts(plan)
}
}

private def removeSorts(plan: SparkPlan): SparkPlan = plan transform {
case s @ SortExec(orders, _, child, _)
if SortOrder.orderingSatisfies(child.outputOrdering, orders) &&
child.outputPartitioning.satisfies(s.requiredChildDistribution.head) =>
child
}
}
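To make the guard concrete, here is a hedged, self-contained sketch of the ordering half of the check (the attributes are invented; SortOrder.orderingSatisfies is the same helper used above, and the partitioning half works analogously through Partitioning.satisfies):

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.types.LongType

val key = AttributeReference("key", LongType)()
val value = AttributeReference("value", LongType)()
// What the child already guarantees, e.g. from an earlier sort or a cached sorted scan.
val childOrdering = Seq(SortOrder(key, Ascending), SortOrder(value, Ascending))
// What a parent SortExec would ask for.
val requiredOrdering = Seq(SortOrder(key, Ascending))
// true: the required orders are a semantic prefix of the child's output ordering,
// so the parent sort would do no useful work and can be elided.
SortOrder.orderingSatisfies(childOrdering, requiredOrdering)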
AdaptiveSparkPlanExec.scala
@@ -83,13 +83,15 @@ case class AdaptiveSparkPlanExec(
@transient private val optimizer = new AQEOptimizer(conf)

@transient private val removeRedundantProjects = RemoveRedundantProjects(conf)
@transient private val removeRedundantSorts = RemoveRedundantSorts(conf)
@transient private val ensureRequirements = EnsureRequirements(conf)

// A list of physical plan rules to be applied before creation of query stages. The
// physical plan should reach its final shape in terms of query stages (i.e., no more
// addition or removal of Exchange nodes) after running these rules.
private def queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq(
removeRedundantProjects,
removeRedundantSorts,
ensureRequirements
) ++ context.session.sessionState.queryStagePrepRules

RemoveRedundantSortsSuite.scala (new file)
@@ -0,0 +1,136 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession


abstract class RemoveRedundantSortsSuiteBase
extends QueryTest
with SharedSparkSession
with AdaptiveSparkPlanHelper {
import testImplicits._

private def checkNumSorts(df: DataFrame, count: Int): Unit = {
val plan = df.queryExecution.executedPlan
assert(collectWithSubqueries(plan) { case s: SortExec => s }.length == count)
}

private def checkSorts(query: String, enabledCount: Int, disabledCount: Int): Unit = {
withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "true") {
val df = sql(query)
checkNumSorts(df, enabledCount)
val result = df.collect()
withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "false") {
val df = sql(query)
checkNumSorts(df, disabledCount)
checkAnswer(df, result)
}
}
}

test("remove redundant sorts with limit") {
withTempView("t") {
spark.range(100).select('id as "key").createOrReplaceTempView("t")
val query =
"""
|SELECT key FROM
| (SELECT key FROM t WHERE key > 10 ORDER BY key DESC LIMIT 10)
|ORDER BY key DESC
|""".stripMargin
checkSorts(query, 0, 1)
}
}

test("remove redundant sorts with broadcast hash join") {
withTempView("t1", "t2") {
spark.range(1000).select('id as "key").createOrReplaceTempView("t1")
spark.range(1000).select('id as "key").createOrReplaceTempView("t2")
val queryTemplate = """
|SELECT t1.key FROM
| (SELECT key FROM t1 WHERE key > 10 ORDER BY key DESC LIMIT 10) t1
|%s
| (SELECT key FROM t2 WHERE key > 50 ORDER BY key DESC LIMIT 100) t2
|ON t1.key = t2.key
|ORDER BY %s
""".stripMargin

val innerJoinAsc = queryTemplate.format("JOIN", "t2.key ASC")
checkSorts(innerJoinAsc, 1, 1)

val innerJoinDesc = queryTemplate.format("JOIN", "t2.key DESC")
checkSorts(innerJoinDesc, 0, 1)

val innerJoinDesc1 = queryTemplate.format("JOIN", "t1.key DESC")
checkSorts(innerJoinDesc1, 1, 1)

val leftOuterJoinDesc = queryTemplate.format("LEFT JOIN", "t1.key DESC")
checkSorts(leftOuterJoinDesc, 0, 1)
}
}

test("remove redundant sorts with sort merge join") {
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withTempView("t1", "t2") {
spark.range(1000).select('id as "key").createOrReplaceTempView("t1")
spark.range(1000).select('id as "key").createOrReplaceTempView("t2")
val query = """
|SELECT t1.key FROM
| (SELECT key FROM t1 WHERE key > 10 ORDER BY key DESC LIMIT 10) t1
|JOIN
| (SELECT key FROM t2 WHERE key > 50 ORDER BY key DESC LIMIT 100) t2
|ON t1.key = t2.key
|ORDER BY t1.key
""".stripMargin

val queryAsc = query + " ASC"
checkSorts(queryAsc, 2, 3)

// With sort merge join, the top-level sort is only eliminated for the ascending query:
// SMJ outputs rows sorted ascending on the join key, so a descending top-level sort is
// not satisfied by its child and all three sorts remain.
val queryDesc = query + " DESC"
checkSorts(queryDesc, 3, 3)
}
}
}

test("cached sorted data doesn't need to be re-sorted") {
withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "true") {
val df = spark.range(1000).select('id as "key").sort('key.desc).cache()
val resorted = df.sort('key.desc)
val sortedAsc = df.sort('key.asc)
checkNumSorts(df, 0)
checkNumSorts(resorted, 0)
checkNumSorts(sortedAsc, 1)
val result = resorted.collect()
withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "false") {
val resorted = df.sort('key.desc)
checkNumSorts(resorted, 1)
checkAnswer(resorted, result)
}
}
}
}

class RemoveRedundantSortsSuite extends RemoveRedundantSortsSuiteBase
with DisableAdaptiveExecutionSuite

class RemoveRedundantSortsSuiteAE extends RemoveRedundantSortsSuiteBase
with EnableAdaptiveExecutionSuite
