diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 059b9bbf57dc1..82a0ca6a5795f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -736,7 +736,7 @@ object EliminateSorts extends Rule[LogicalPlan] { } /** - * Removes Sort operations on already sorted data + * Removes Sort operation if the child is already sorted */ object RemoveRedundantSorts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 4b2b1b7c2f9c4..ff1409a1fe450 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -525,7 +525,14 @@ case class Range( Statistics(sizeInBytes = LongType.defaultSize * numElements) } - override def outputOrdering: Seq[SortOrder] = output.map(a => SortOrder(a, Descending)) + override def outputOrdering: Seq[SortOrder] = { + val order = if (step > 0) { + Ascending + } else { + Descending + } + output.map(a => SortOrder(a, order)) + } } case class Aggregate( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala index 9bc53b756ae01..2319ab8046e56 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala @@ -29,9 +29,6 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL} class RemoveRedundantSortsSuite extends PlanTest { - override val conf = new SQLConf().copy(CASE_SENSITIVE -> true, ORDER_BY_ORDINAL -> false) - val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) - val analyzer = new Analyzer(catalog, conf) object Optimize extends RuleExecutor[LogicalPlan] { val batches = @@ -46,48 +43,59 @@ class RemoveRedundantSortsSuite extends PlanTest { test("remove redundant order by") { val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) val unnecessaryReordered = orderedPlan.select('a).orderBy('a.asc, 'b.desc_nullsFirst) - val optimized = Optimize.execute(analyzer.execute(unnecessaryReordered)) - val correctAnswer = analyzer.execute(orderedPlan.select('a)) + val optimized = Optimize.execute(unnecessaryReordered.analyze) + val correctAnswer = orderedPlan.select('a).analyze comparePlans(Optimize.execute(optimized), correctAnswer) } test("do not remove sort if the order is different") { val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) val reorderedDifferently = orderedPlan.select('a).orderBy('a.asc, 'b.desc) - val optimized = Optimize.execute(analyzer.execute(reorderedDifferently)) - val correctAnswer = analyzer.execute(reorderedDifferently) + val optimized = Optimize.execute(reorderedDifferently.analyze) + val correctAnswer = reorderedDifferently.analyze comparePlans(optimized, correctAnswer) } test("filters don't affect order") { val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc) - val optimized = Optimize.execute(analyzer.execute(filteredAndReordered)) - val correctAnswer = analyzer.execute(orderedPlan.where('a > Literal(10))) + val optimized = Optimize.execute(filteredAndReordered.analyze) + val correctAnswer = orderedPlan.where('a > Literal(10)).analyze comparePlans(optimized, correctAnswer) } test("limits don't affect order") { val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc) - val optimized = Optimize.execute(analyzer.execute(filteredAndReordered)) - val correctAnswer = analyzer.execute(orderedPlan.limit(Literal(10))) + val optimized = Optimize.execute(filteredAndReordered.analyze) + val correctAnswer = orderedPlan.limit(Literal(10)).analyze comparePlans(optimized, correctAnswer) } test("range is already sorted") { val inputPlan = Range(1L, 1000L, 1, 10) - val orderedPlan = inputPlan.orderBy('id.desc) - val optimized = Optimize.execute(analyzer.execute(orderedPlan)) - val correctAnswer = analyzer.execute(inputPlan) + val orderedPlan = inputPlan.orderBy('id.asc) + val optimized = Optimize.execute(orderedPlan.analyze) + val correctAnswer = inputPlan.analyze comparePlans(optimized, correctAnswer) + + val reversedPlan = inputPlan.orderBy('id.desc) + val reversedOptimized = Optimize.execute(reversedPlan.analyze) + val reversedCorrectAnswer = reversedPlan.analyze + comparePlans(reversedOptimized, reversedCorrectAnswer) + + val negativeStepInputPlan = Range(10L, 1L, -1, 10) + val negativeStepOrderedPlan = negativeStepInputPlan.orderBy('id.desc) + val negativeStepOptimized = Optimize.execute(negativeStepOrderedPlan.analyze) + val negativeStepCorrectAnswer = negativeStepInputPlan.analyze + comparePlans(negativeStepOptimized, negativeStepCorrectAnswer) } test("sort should not be removed when there is a node which doesn't guarantee any order") { val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc) val groupedAndResorted = orderedPlan.groupBy('a)(sum('a)).orderBy('a.asc) - val optimized = Optimize.execute(analyzer.execute(groupedAndResorted)) - val correctAnswer = analyzer.execute(groupedAndResorted) + val optimized = Optimize.execute(groupedAndResorted.analyze) + val correctAnswer = groupedAndResorted.analyze comparePlans(optimized, correctAnswer) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index a158a42b3df9c..4c454917bdfab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -204,10 +204,10 @@ class PlannerSuite extends SharedSQLContext { val resorted = query.sort('key.desc) assert(resorted.queryExecution.optimizedPlan.collect { case s: Sort => s}.isEmpty) assert(resorted.select('key).collect().map(_.getInt(0)).toSeq == - (1 to 100).sorted(Ordering[Int].reverse)) + (1 to 100).reverse) // with a different order, the sort is needed val sortedAsc = query.sort('key) - assert(sortedAsc.queryExecution.optimizedPlan.collect { case s: Sort => s}.nonEmpty) + assert(sortedAsc.queryExecution.optimizedPlan.collect { case s: Sort => s}.size == 1) assert(sortedAsc.select('key).collect().map(_.getInt(0)).toSeq == (1 to 100)) }