Skip to content

Commit

Permalink
Use separate rule and add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mgaido91 committed Feb 10, 2018
1 parent 550ff99 commit 81e4828
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
operatorOptimizationBatch) :+
Batch("Join Reorder", Once,
CostBasedJoinReorder) :+
Batch("Remove Redundant Sorts", Once,
RemoveRedundantSorts) :+
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates) :+
Batch("Object Expressions Optimization", fixedPoint,
Expand Down Expand Up @@ -730,8 +732,16 @@ object EliminateSorts extends Rule[LogicalPlan] {
case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) =>
val newOrders = orders.filterNot(_.child.foldable)
if (newOrders.isEmpty) child else s.copy(order = newOrders)
case Sort(orders, true, child) if child.isSorted && child.sortedOrder.get.zip(orders).forall {
case (s1, s2) => s1.satisfies(s2) } =>
}
}

/**
 * Removes a [[Sort]] whose child already reports an ordering that satisfies it.
 *
 * Only `Sort` nodes whose Boolean flag is `true` are matched. The sort is dropped when:
 *  - the child reports a non-empty `sortedOrder`,
 *  - the child's ordering is at least as long as the requested one, and
 *  - every requested order is satisfied by the child's order at the same position.
 *
 * The length check is required because `zip` silently stops at the shorter sequence:
 * without it, a child sorted only by `a` would be accepted for a sort by `a, b` and a
 * sort that is still needed would be removed.
 */
object RemoveRedundantSorts extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case Sort(orders, true, child)
        if child.sortedOrder.nonEmpty &&
          // The child must cover the whole requested ordering, not just a prefix of it.
          orders.length <= child.sortedOrder.length &&
          child.sortedOrder.zip(orders).forall { case (have, want) => have.satisfies(want) } =>
      child
  }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,7 @@ abstract class LogicalPlan
/**
* The ordering of the data produced by this plan; empty if the data is not known to be sorted.
*/
def sortedOrder: Option[Seq[SortOrder]] = None

final def isSorted: Boolean = sortedOrder.isDefined
def sortedOrder: Seq[SortOrder] = Nil
}

/**
Expand Down Expand Up @@ -283,5 +281,5 @@ abstract class BinaryNode extends LogicalPlan {
}

/**
 * A [[UnaryNode]] that reports its child's `sortedOrder` as its own.
 * The override is `final`, so subclasses cannot report a different ordering.
 * Presumably meant for operators that do not reorder the rows of their child.
 */
abstract class KeepOrderUnaryNode extends UnaryNode {
override final def sortedOrder: Option[Seq[SortOrder]] = child.sortedOrder
override final def sortedOrder: Seq[SortOrder] = child.sortedOrder
}
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ case class Sort(
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
override def maxRows: Option[Long] = child.maxRows
override def sortedOrder: Option[Seq[SortOrder]] = Some(order)
// The data produced by this node is reported as ordered by the node's own `order`.
// NOTE(review): `order` is returned without looking at the node's Boolean flag, so a
// Sort whose flag is not `true` is reported as sorted as well — confirm this is intended.
override def sortedOrder: Seq[SortOrder] = order
}

/** Factory for constructing new `Range` nodes. */
Expand Down Expand Up @@ -524,6 +524,8 @@ case class Range(
// Size estimate: `LongType.defaultSize` multiplied by the number of elements.
override def computeStats(): Statistics = {
Statistics(sizeInBytes = LongType.defaultSize * numElements)
}

// Reports every output attribute as sorted in `Descending` direction.
// NOTE(review): the direction is hard-coded. A range that counts upwards (such as the
// `Range(1L, 1000L, 1, 10)` used in RemoveRedundantSortsSuite) would presumably produce
// increasing values, in which case the direction should follow the sign of the step
// and this should be `Ascending` for a positive step — confirm before relying on it.
override def sortedOrder: Seq[SortOrder] = output.map(a => SortOrder(a, Descending))
}

case class Aggregate(
Expand Down Expand Up @@ -746,7 +748,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends KeepOr
*
* See [[Limit]] for more information.
*/
case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends KeepOrderUnaryNode {
override def output: Seq[Attribute] = child.output

override def maxRowsPerPartition: Option[Long] = {
Expand Down Expand Up @@ -870,9 +872,9 @@ case class RepartitionByExpression(
override def maxRows: Option[Long] = child.maxRows
override def shuffle: Boolean = true

override def sortedOrder: Option[Seq[SortOrder]] = partitioning match {
case RangePartitioning(sortedOrder, _) => Some(sortedOrder)
case _ => None
// A repartition never guarantees a row order. Range partitioning only decides which
// partition a row is routed to; it does not order the rows inside a partition.
// Reporting the partitioning's ordering here would let RemoveRedundantSorts drop a
// Sort that is still required, so no ordering is claimed for any partitioning.
override def sortedOrder: Seq[SortOrder] = {
  Nil
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ class EliminateSortsSuite extends PlanTest {
val batches =
Batch("Eliminate Sorts", FixedPoint(10),
FoldablePropagation,
EliminateSorts,
CollapseProject) :: Nil
EliminateSorts) :: Nil
}

val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
Expand Down Expand Up @@ -84,16 +83,4 @@ class EliminateSortsSuite extends PlanTest {

comparePlans(optimized, correctAnswer)
}

// Checks both directions of the redundant-sort removal in a single test.
test("remove redundant order by") {
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
// Same ordering requested again on top of the sorted plan: the optimized plan is
// expected to equal the sorted plan with only the projection on top.
val unnecessaryReordered = orderedPlan.select('a).orderBy('a.asc, 'b.desc_nullsFirst)
val optimized = Optimize.execute(analyzer.execute(unnecessaryReordered))
val correctAnswer = analyzer.execute(orderedPlan.select('a))
comparePlans(Optimize.execute(optimized), correctAnswer)
// Different ordering (`'b.desc` instead of `'b.desc_nullsFirst`): the optimized plan is
// expected to equal the analyzed input, i.e. the second sort is kept.
val reorderedDifferently = orderedPlan.select('a).orderBy('a.asc, 'b.desc)
val nonOptimized = Optimize.execute(analyzer.execute(reorderedDifferently))
val correctAnswerNonOptimized = analyzer.execute(reorderedDifferently)
comparePlans(nonOptimized, correctAnswerNonOptimized)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL}

/**
 * Unit tests for the [[RemoveRedundantSorts]] optimizer rule.
 *
 * Each case builds a plan with the catalyst DSL, analyzes it, runs it through the local
 * `Optimize` executor (the rule under test followed by `CollapseProject`) and compares
 * the outcome with an analyzed reference plan.
 */
class RemoveRedundantSortsSuite extends PlanTest {
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> true, ORDER_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  /** Applies the rule under test once, then collapses adjacent projections once. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(
      Batch("Remove Redundant Sorts", Once, RemoveRedundantSorts),
      Batch("Collapse Project", Once, CollapseProject))
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  /** Resolves `plan` with the suite's analyzer. */
  private def resolve(plan: LogicalPlan): LogicalPlan = analyzer.execute(plan)

  /** Resolves `plan` and then applies the `Optimize` batches to it. */
  private def resolveAndOptimize(plan: LogicalPlan): LogicalPlan =
    Optimize.execute(resolve(plan))

  test("remove redundant order by") {
    val sortedOnce = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
    val sortedTwice = sortedOnce.select('a).orderBy('a.asc, 'b.desc_nullsFirst)
    val actual = resolveAndOptimize(sortedTwice)
    val expected = resolve(sortedOnce.select('a))
    // The optimizer is applied a second time to the already optimized plan before comparing.
    comparePlans(Optimize.execute(actual), expected)
  }

  test("do not remove sort if the order is different") {
    val sortedOnce = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
    // `'b.desc` differs from `'b.desc_nullsFirst`, so the plan is expected to stay as analyzed.
    val sortedDifferently = sortedOnce.select('a).orderBy('a.asc, 'b.desc)
    val actual = resolveAndOptimize(sortedDifferently)
    val expected = resolve(sortedDifferently)
    comparePlans(actual, expected)
  }

  test("filters don't affect order") {
    val sortedOnce = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
    val filtered = sortedOnce.where('a > Literal(10))
    val actual = resolveAndOptimize(filtered.orderBy('a.asc, 'b.desc))
    val expected = resolve(sortedOnce.where('a > Literal(10)))
    comparePlans(actual, expected)
  }

  test("limits don't affect order") {
    val sortedOnce = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
    val limitedAndReordered = sortedOnce.limit(Literal(10)).orderBy('a.asc, 'b.desc)
    val actual = resolveAndOptimize(limitedAndReordered)
    val expected = resolve(sortedOnce.limit(Literal(10)))
    comparePlans(actual, expected)
  }

  test("range is already sorted") {
    val range = Range(1L, 1000L, 1, 10)
    val actual = resolveAndOptimize(range.orderBy('id.desc))
    val expected = resolve(range)
    comparePlans(actual, expected)
  }

  test("sort should not be removed when there is a node which doesn't guarantee any order") {
    val sortedOnce = testRelation.select('a, 'b).orderBy('a.asc)
    // The aggregate sits between the two sorts; the plan is expected to stay as analyzed.
    val groupedAndResorted = sortedOnce.groupBy('a)(sum('a)).orderBy('a.asc)
    val actual = resolveAndOptimize(groupedAndResorted)
    val expected = resolve(groupedAndResorted)
    comparePlans(actual, expected)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,6 @@ case class InMemoryRelation(

override protected def otherCopyArgs: Seq[AnyRef] =
Seq(_cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache)

// Reports the `outputOrdering` of `child` as the sorted order of this relation.
// The SPARK-23375 test in PlannerSuite exercises this for cached, already sorted data.
override def sortedOrder: Seq[SortOrder] = child.outputOrdering
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ import org.apache.spark.sql.{execution, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition, Sort}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange,
ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -197,6 +198,19 @@ class PlannerSuite extends SharedSQLContext {
assert(planned.child.isInstanceOf[CollectLimitExec])
}

test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") {
  // True when the given optimized plan still contains at least one Sort operator.
  def containsSort(plan: LogicalPlan): Boolean = plan.collect { case s: Sort => s }.nonEmpty

  val cachedDesc = testData.select('key, 'value).sort('key.desc).cache()
  assert(cachedDesc.queryExecution.optimizedPlan.isInstanceOf[InMemoryRelation])

  // Sorting again in the cached order: no Sort is expected in the optimized plan and
  // the keys are expected in descending order.
  val sameOrder = cachedDesc.sort('key.desc)
  assert(!containsSort(sameOrder.queryExecution.optimizedPlan))
  val descendingKeys = sameOrder.select('key).collect().map(_.getInt(0)).toSeq
  assert(descendingKeys == (1 to 100).reverse)

  // with a different order, the sort is needed
  val otherOrder = cachedDesc.sort('key)
  assert(containsSort(otherOrder.queryExecution.optimizedPlan))
  val ascendingKeys = otherOrder.select('key).collect().map(_.getInt(0)).toSeq
  assert(ascendingKeys == (1 to 100))
}

test("PartitioningCollection") {
withTempView("normal", "small", "tiny") {
testData.createOrReplaceTempView("normal")
Expand Down

0 comments on commit 81e4828

Please sign in to comment.