From 81e48286806d36e7630e961168d87cbad4f10194 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 10 Feb 2018 13:00:56 +0100 Subject: [PATCH] Use separate rule and add more tests --- .../sql/catalyst/optimizer/Optimizer.scala | 14 ++- .../catalyst/plans/logical/LogicalPlan.scala | 6 +- .../plans/logical/basicLogicalOperators.scala | 12 ++- .../optimizer/EliminateSortsSuite.scala | 15 +-- .../optimizer/RemoveRedundantSortsSuite.scala | 93 +++++++++++++++++++ .../execution/columnar/InMemoryRelation.scala | 2 + .../spark/sql/execution/PlannerSuite.scala | 18 +++- 7 files changed, 133 insertions(+), 27 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 775633695ca53..93097852528c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -140,6 +140,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) operatorOptimizationBatch) :+ Batch("Join Reorder", Once, CostBasedJoinReorder) :+ + Batch("Remove Redundant Sorts", Once, + RemoveRedundantSorts) :+ Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :+ Batch("Object Expressions Optimization", fixedPoint, @@ -730,8 +732,16 @@ object EliminateSorts extends Rule[LogicalPlan] { case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) => val newOrders = orders.filterNot(_.child.foldable) if (newOrders.isEmpty) child else s.copy(order = newOrders) - case Sort(orders, true, child) if child.isSorted && child.sortedOrder.get.zip(orders).forall { - case (s1, s2) => s1.satisfies(s2) } => + } +} + +/** + * Removes Sort operations on already sorted data + */ +object RemoveRedundantSorts extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Sort(orders, true, child) if child.sortedOrder.nonEmpty + && child.sortedOrder.zip(orders).forall { case (s1, s2) => s1.satisfies(s2) } => child } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 5b28d79537168..99b3fe6588f2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -223,9 +223,7 @@ abstract class LogicalPlan /** * If the current plan contains sorted data, it contains the sorted order. */ - def sortedOrder: Option[Seq[SortOrder]] = None - - final def isSorted: Boolean = sortedOrder.isDefined + def sortedOrder: Seq[SortOrder] = Nil } /** @@ -283,5 +281,5 @@ abstract class BinaryNode extends LogicalPlan { } abstract class KeepOrderUnaryNode extends UnaryNode { - override final def sortedOrder: Option[Seq[SortOrder]] = child.sortedOrder + override final def sortedOrder: Seq[SortOrder] = child.sortedOrder } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index bc72dc4c92a88..127398d7e7389 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -470,7 +470,7 @@ case class Sort( child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = child.maxRows - override def sortedOrder: Option[Seq[SortOrder]] = Some(order) + override def sortedOrder: Seq[SortOrder] = order } /** Factory for constructing new `Range` nodes. */ @@ -524,6 +524,8 @@ case class Range( override def computeStats(): Statistics = { Statistics(sizeInBytes = LongType.defaultSize * numElements) } + + override def sortedOrder: Seq[SortOrder] = output.map(a => SortOrder(a, Descending)) } case class Aggregate( @@ -746,7 +748,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends KeepOr * * See [[Limit]] for more information. */ -case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { +case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends KeepOrderUnaryNode { override def output: Seq[Attribute] = child.output override def maxRowsPerPartition: Option[Long] = { @@ -870,9 +872,9 @@ case class RepartitionByExpression( override def maxRows: Option[Long] = child.maxRows override def shuffle: Boolean = true - override def sortedOrder: Option[Seq[SortOrder]] = partitioning match { - case RangePartitioning(sortedOrder, _) => Some(sortedOrder) - case _ => None + override def sortedOrder: Seq[SortOrder] = partitioning match { + case RangePartitioning(sortedOrder, _) => sortedOrder + case _ => Nil } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index e1f53ef6388de..e318f36d78270 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -37,8 +37,7 @@ class EliminateSortsSuite extends PlanTest { val batches = Batch("Eliminate Sorts", FixedPoint(10), FoldablePropagation, - EliminateSorts, - CollapseProject) :: Nil + EliminateSorts) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) @@ -84,16 +83,4 @@ class EliminateSortsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - - test("remove redundant order by") { - val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) - val unnecessaryReordered = orderedPlan.select('a).orderBy('a.asc, 'b.desc_nullsFirst) - val optimized = Optimize.execute(analyzer.execute(unnecessaryReordered)) - val correctAnswer = analyzer.execute(orderedPlan.select('a)) - comparePlans(Optimize.execute(optimized), correctAnswer) - val reorderedDifferently = orderedPlan.select('a).orderBy('a.asc, 'b.desc) - val nonOptimized = Optimize.execute(analyzer.execute(reorderedDifferently)) - val correctAnswerNonOptimized = analyzer.execute(reorderedDifferently) - comparePlans(nonOptimized, correctAnswerNonOptimized) - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala new file mode 100644 index 0000000000000..9bc53b756ae01 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL} + +class RemoveRedundantSortsSuite extends PlanTest { + override val conf = new SQLConf().copy(CASE_SENSITIVE -> true, ORDER_BY_ORDINAL -> false) + val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) + val analyzer = new Analyzer(catalog, conf) + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Remove Redundant Sorts", Once, + RemoveRedundantSorts) :: + Batch("Collapse Project", Once, + CollapseProject) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + test("remove redundant order by") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) + val unnecessaryReordered = orderedPlan.select('a).orderBy('a.asc, 'b.desc_nullsFirst) + val optimized = Optimize.execute(analyzer.execute(unnecessaryReordered)) + val correctAnswer = analyzer.execute(orderedPlan.select('a)) + comparePlans(Optimize.execute(optimized), correctAnswer) + } + + test("do not remove sort if the order is different") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) + val reorderedDifferently = orderedPlan.select('a).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(analyzer.execute(reorderedDifferently)) + val correctAnswer = analyzer.execute(reorderedDifferently) + comparePlans(optimized, correctAnswer) + } + + test("filters don't affect order") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) + val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(analyzer.execute(filteredAndReordered)) + val correctAnswer = analyzer.execute(orderedPlan.where('a > Literal(10))) + comparePlans(optimized, correctAnswer) + } + + test("limits don't affect order") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) + val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(analyzer.execute(filteredAndReordered)) + val correctAnswer = analyzer.execute(orderedPlan.limit(Literal(10))) + comparePlans(optimized, correctAnswer) + } + + test("range is already sorted") { + val inputPlan = Range(1L, 1000L, 1, 10) + val orderedPlan = inputPlan.orderBy('id.desc) + val optimized = Optimize.execute(analyzer.execute(orderedPlan)) + val correctAnswer = analyzer.execute(inputPlan) + comparePlans(optimized, correctAnswer) + } + + test("sort should not be removed when there is a node which doesn't guarantee any order") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc) + val groupedAndResorted = orderedPlan.groupBy('a)(sum('a)).orderBy('a.asc) + val optimized = Optimize.execute(analyzer.execute(groupedAndResorted)) + val correctAnswer = analyzer.execute(groupedAndResorted) + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 22e16913d4da9..84618cf4105f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -169,4 +169,6 @@ case class InMemoryRelation( override protected def otherCopyArgs: Seq[AnyRef] = Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache) + + override def sortedOrder: Seq[SortOrder] = child.outputOrdering } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index f8b26f5b28cc7..a158a42b3df9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -22,10 +22,11 @@ import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition, Sort} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation -import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, + ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -197,6 +198,19 @@ class PlannerSuite extends SharedSQLContext { assert(planned.child.isInstanceOf[CollectLimitExec]) } + test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") { + val query = testData.select('key, 'value).sort('key.desc).cache() + assert(query.queryExecution.optimizedPlan.isInstanceOf[InMemoryRelation]) + val resorted = query.sort('key.desc) + assert(resorted.queryExecution.optimizedPlan.collect { case s: Sort => s}.isEmpty) + assert(resorted.select('key).collect().map(_.getInt(0)).toSeq == + (1 to 100).sorted(Ordering[Int].reverse)) + // with a different order, the sort is needed + val sortedAsc = query.sort('key) + assert(sortedAsc.queryExecution.optimizedPlan.collect { case s: Sort => s}.nonEmpty) + assert(sortedAsc.select('key).collect().map(_.getInt(0)).toSeq == (1 to 100)) + } + test("PartitioningCollection") { withTempView("normal", "small", "tiny") { testData.createOrReplaceTempView("normal")