diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala new file mode 100644 index 0000000000000..0eaa0f53fdacd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.execution.{ExplainUtils, UnaryExecNode} + +/** + * Holds common logic for aggregate operators + */ +trait BaseAggregateExec extends UnaryExecNode { + def groupingExpressions: Seq[NamedExpression] + def aggregateExpressions: Seq[AggregateExpression] + def aggregateAttributes: Seq[Attribute] + def resultExpressions: Seq[NamedExpression] + + override def verboseStringWithOperatorId(): String = { + val inputString = child.output.mkString("[", ", ", "]") + val keyString = groupingExpressions.mkString("[", ", ", "]") + val functionString = aggregateExpressions.mkString("[", ", ", "]") + val aggregateAttributeString = aggregateAttributes.mkString("[", ", ", "]") + val resultString = resultExpressions.mkString("[", ", ", "]") + s""" + |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)} + |Input: $inputString + |Keys: $keyString + |Functions: $functionString + |Aggregate Attributes: $aggregateAttributeString + |Results: $resultString + """.stripMargin + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index f73e214a6b41f..7a26fd7a8541a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -53,7 +53,7 @@ case class HashAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning { + extends BaseAggregateExec with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 4376f6b6edd57..3fb58eb2cc8ba 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -67,7 +67,7 @@ case class ObjectHashAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode with AliasAwareOutputPartitioning { + extends BaseAggregateExec with AliasAwareOutputPartitioning { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index b6e684e62ea5c..77ed469016fa3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.util.truncatedString -import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -38,7 +38,7 @@ case class SortAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode with AliasAwareOutputPartitioning { + extends BaseAggregateExec with AliasAwareOutputPartitioning { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain.sql b/sql/core/src/test/resources/sql-tests/inputs/explain.sql index d5253e3daddb0..497b61c6134a2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/explain.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/explain.sql @@ -5,6 +5,7 @@ CREATE table explain_temp1 (key int, val int) USING PARQUET; CREATE table explain_temp2 (key int, val int) USING PARQUET; CREATE table explain_temp3 (key int, val int) USING PARQUET; +CREATE table explain_temp4 (key int, val string) USING PARQUET; SET spark.sql.codegen.wholeStage = true; @@ -61,7 +62,7 @@ EXPLAIN FORMATTED FROM explain_temp2 WHERE val > 0) OR - key = (SELECT max(key) + key = (SELECT avg(key) FROM explain_temp3 WHERE val > 0); @@ -93,6 +94,25 @@ EXPLAIN FORMATTED CREATE VIEW explain_view AS SELECT key, val FROM explain_temp1; +-- HashAggregate +EXPLAIN FORMATTED + SELECT + COUNT(val) + SUM(key) as TOTAL, + COUNT(key) FILTER (WHERE val > 1) + FROM explain_temp1; + +-- ObjectHashAggregate +EXPLAIN FORMATTED + SELECT key, sort_array(collect_set(val))[0] + FROM explain_temp4 + GROUP BY key; + +-- SortAggregate +EXPLAIN FORMATTED + SELECT key, MIN(val) + FROM explain_temp4 + GROUP BY key; + -- cleanup DROP TABLE explain_temp1; DROP TABLE explain_temp2; diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index 756c14f28a657..bc28d7f87bf00 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of 
queries: 18 +-- Number of queries: 22 -- !query @@ -26,6 +26,14 @@ struct<> +-- !query +CREATE table explain_temp4 (key int, val string) USING PARQUET +-- !query schema +struct<> +-- !query output + + + -- !query SET spark.sql.codegen.wholeStage = true -- !query schema @@ -76,12 +84,20 @@ Input : [key#x, val#x] (5) HashAggregate [codegen id : 1] Input: [key#x, val#x] +Keys: [key#x] +Functions: [partial_max(val#x)] +Aggregate Attributes: [max#x] +Results: [key#x, max#x] (6) Exchange Input: [key#x, max#x] (7) HashAggregate [codegen id : 2] Input: [key#x, max#x] +Keys: [key#x] +Functions: [max(val#x)] +Aggregate Attributes: [max(val#x)#x] +Results: [key#x, max(val#x)#x AS max(val)#x] (8) Exchange Input: [key#x, max(val)#x] @@ -132,12 +148,20 @@ Input : [key#x, val#x] (5) HashAggregate [codegen id : 1] Input: [key#x, val#x] +Keys: [key#x] +Functions: [partial_max(val#x)] +Aggregate Attributes: [max#x] +Results: [key#x, max#x] (6) Exchange Input: [key#x, max#x] (7) HashAggregate [codegen id : 2] Input: [key#x, max#x] +Keys: [key#x] +Functions: [max(val#x)] +Aggregate Attributes: [max(val#x)#x] +Results: [key#x, max(val#x)#x AS max(val)#x, max(val#x)#x AS max(val#x)#x] (8) Filter [codegen id : 2] Input : [key#x, max(val)#x, max(val#x)#x] @@ -211,12 +235,20 @@ Input : [key#x, val#x] (10) HashAggregate [codegen id : 3] Input: [key#x, val#x] +Keys: [key#x, val#x] +Functions: [] +Aggregate Attributes: [] +Results: [key#x, val#x] (11) Exchange Input: [key#x, val#x] (12) HashAggregate [codegen id : 4] Input: [key#x, val#x] +Keys: [key#x, val#x] +Functions: [] +Aggregate Attributes: [] +Results: [key#x, val#x] -- !query @@ -413,12 +445,20 @@ Input : [key#x, val#x] (9) HashAggregate [codegen id : 1] Input: [key#x] +Keys: [] +Functions: [partial_max(key#x)] +Aggregate Attributes: [max#x] +Results: [max#x] (10) Exchange Input: [max#x] (11) HashAggregate [codegen id : 2] Input: [max#x] +Keys: [] +Functions: [max(key#x)] +Aggregate Attributes: [max(key#x)#x] +Results: [max(key#x)#x AS max(key)#x] Subquery:2 Hosting operator id = 7 Hosting Expression = Subquery scalar-subquery#x, [id=#x] * HashAggregate (18) @@ -450,12 +490,20 @@ Input : [key#x, val#x] (16) HashAggregate [codegen id : 1] Input: [key#x] +Keys: [] +Functions: [partial_max(key#x)] +Aggregate Attributes: [max#x] +Results: [max#x] (17) Exchange Input: [max#x] (18) HashAggregate [codegen id : 2] Input: [max#x] +Keys: [] +Functions: [max(key#x)] +Aggregate Attributes: [max(key#x)#x] +Results: [max(key#x)#x AS max(key)#x] -- !query @@ -466,7 +514,7 @@ EXPLAIN FORMATTED FROM explain_temp2 WHERE val > 0) OR - key = (SELECT max(key) + key = (SELECT avg(key) FROM explain_temp3 WHERE val > 0) -- !query schema @@ -489,7 +537,7 @@ Input: [key#x, val#x] (3) Filter [codegen id : 1] Input : [key#x, val#x] -Condition : ((key#x = Subquery scalar-subquery#x, [id=#x]) OR (key#x = Subquery scalar-subquery#x, [id=#x])) +Condition : ((key#x = Subquery scalar-subquery#x, [id=#x]) OR (cast(key#x as double) = Subquery scalar-subquery#x, [id=#x])) ===== Subqueries ===== @@ -523,12 +571,20 @@ Input : [key#x, val#x] (8) HashAggregate [codegen id : 1] Input: [key#x] +Keys: [] +Functions: [partial_max(key#x)] +Aggregate Attributes: [max#x] +Results: [max#x] (9) Exchange Input: [max#x] (10) HashAggregate [codegen id : 2] Input: [max#x] +Keys: [] +Functions: [max(key#x)] +Aggregate Attributes: [max(key#x)#x] +Results: [max(key#x)#x AS max(key)#x] Subquery:2 Hosting operator id = 3 Hosting Expression = Subquery scalar-subquery#x, [id=#x] * HashAggregate (17) @@ -560,12 +616,20 
@@ Input : [key#x, val#x]

(15) HashAggregate [codegen id : 1]
Input: [key#x]
+Keys: []
+Functions: [partial_avg(cast(key#x as bigint))]
+Aggregate Attributes: [sum#x, count#xL]
+Results: [sum#x, count#xL]

(16) Exchange
-Input: [max#x]
+Input: [sum#x, count#xL]

(17) HashAggregate [codegen id : 2]
-Input: [max#x]
+Input: [sum#x, count#xL]
+Keys: []
+Functions: [avg(cast(key#x as bigint))]
+Aggregate Attributes: [avg(cast(key#x as bigint))#x]
+Results: [avg(cast(key#x as bigint))#x AS avg(key)#x]


-- !query
@@ -615,12 +679,20 @@ Input: [key#x]

(6) HashAggregate [codegen id : 1]
Input: [key#x]
+Keys: []
+Functions: [partial_avg(cast(key#x as bigint))]
+Aggregate Attributes: [sum#x, count#xL]
+Results: [sum#x, count#xL]

(7) Exchange
Input: [sum#x, count#xL]

(8) HashAggregate [codegen id : 2]
Input: [sum#x, count#xL]
+Keys: []
+Functions: [avg(cast(key#x as bigint))]
+Aggregate Attributes: [avg(cast(key#x as bigint))#x]
+Results: [avg(cast(key#x as bigint))#x AS avg(key)#x]

Subquery:2 Hosting operator id = 3 Hosting Expression = ReusedSubquery Subquery scalar-subquery#x, [id=#x]
@@ -740,18 +812,30 @@ Input : [key#x, val#x]

(5) HashAggregate [codegen id : 1]
Input: [key#x, val#x]
+Keys: [key#x]
+Functions: [partial_max(val#x)]
+Aggregate Attributes: [max#x]
+Results: [key#x, max#x]

(6) Exchange
Input: [key#x, max#x]

(7) HashAggregate [codegen id : 4]
Input: [key#x, max#x]
+Keys: [key#x]
+Functions: [max(val#x)]
+Aggregate Attributes: [max(val#x)#x]
+Results: [key#x, max(val#x)#x AS max(val)#x]

(8) ReusedExchange [Reuses operator id: 6]
Output : ArrayBuffer(key#x, max#x)

(9) HashAggregate [codegen id : 3]
Input: [key#x, max#x]
+Keys: [key#x]
+Functions: [max(val#x)]
+Aggregate Attributes: [max(val#x)#x]
+Results: [key#x, max(val#x)#x AS max(val)#x]

(10) BroadcastExchange
Input: [key#x, max(val)#x]
@@ -786,6 +870,144 @@ Output: []

(4) Project

+-- !query
+EXPLAIN FORMATTED
+  SELECT
+    COUNT(val) + SUM(key) as TOTAL,
+    COUNT(key) FILTER (WHERE val > 1)
+  FROM explain_temp1
+-- !query schema
+struct<plan:string>
+-- !query output
+== Physical Plan ==
+* HashAggregate (5)
++- Exchange (4)
+   +- HashAggregate (3)
+      +- * ColumnarToRow (2)
+         +- Scan parquet default.explain_temp1 (1)
+
+
+(1) Scan parquet default.explain_temp1
+Output: [key#x, val#x]
+Batched: true
+Location [not included in comparison]/{warehouse_dir}/explain_temp1]
+ReadSchema: struct<key:int,val:int>
+
+(2) ColumnarToRow [codegen id : 1]
+Input: [key#x, val#x]
+
+(3) HashAggregate
+Input: [key#x, val#x]
+Keys: []
+Functions: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))]
+Aggregate Attributes: [count#xL, sum#xL, count#xL]
+Results: [count#xL, sum#xL, count#xL]
+
+(4) Exchange
+Input: [count#xL, sum#xL, count#xL]
+
+(5) HashAggregate [codegen id : 2]
+Input: [count#xL, sum#xL, count#xL]
+Keys: []
+Functions: [count(val#x), sum(cast(key#x as bigint)), count(key#x)]
+Aggregate Attributes: [count(val#x)#xL, sum(cast(key#x as bigint))#xL, count(key#x)#xL]
+Results: [(count(val#x)#xL + sum(cast(key#x as bigint))#xL) AS TOTAL#xL, count(key#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL]
+
+
+-- !query
+EXPLAIN FORMATTED
+  SELECT key, sort_array(collect_set(val))[0]
+  FROM explain_temp4
+  GROUP BY key
+-- !query schema
+struct<plan:string>
+-- !query output
+== Physical Plan ==
+ObjectHashAggregate (5)
++- Exchange (4)
+   +- ObjectHashAggregate (3)
+      +- * ColumnarToRow (2)
+         +- Scan parquet default.explain_temp4 (1)
+
+
+(1) Scan parquet default.explain_temp4
+Output: [key#x, val#x]
+Batched: true
+Location [not included in comparison]/{warehouse_dir}/explain_temp4]
+ReadSchema: struct<key:int,val:string>
+
+(2) ColumnarToRow [codegen id : 1]
+Input: [key#x, val#x]
+
+(3) ObjectHashAggregate
+Input: [key#x, val#x]
+Keys: [key#x]
+Functions: [partial_collect_set(val#x, 0, 0)]
+Aggregate Attributes: [buf#x]
+Results: [key#x, buf#x]
+
+(4) Exchange
+Input: [key#x, buf#x]
+
+(5) ObjectHashAggregate
+Input: [key#x, buf#x]
+Keys: [key#x]
+Functions: [collect_set(val#x, 0, 0)]
+Aggregate Attributes: [collect_set(val#x, 0, 0)#x]
+Results: [key#x, sort_array(collect_set(val#x, 0, 0)#x, true)[0] AS sort_array(collect_set(val), true)[0]#x]
+
+
+-- !query
+EXPLAIN FORMATTED
+  SELECT key, MIN(val)
+  FROM explain_temp4
+  GROUP BY key
+-- !query schema
+struct<plan:string>
+-- !query output
+== Physical Plan ==
+SortAggregate (7)
++- * Sort (6)
+   +- Exchange (5)
+      +- SortAggregate (4)
+         +- * Sort (3)
+            +- * ColumnarToRow (2)
+               +- Scan parquet default.explain_temp4 (1)
+
+
+(1) Scan parquet default.explain_temp4
+Output: [key#x, val#x]
+Batched: true
+Location [not included in comparison]/{warehouse_dir}/explain_temp4]
+ReadSchema: struct<key:int,val:string>
+
+(2) ColumnarToRow [codegen id : 1]
+Input: [key#x, val#x]
+
+(3) Sort [codegen id : 1]
+Input: [key#x, val#x]
+
+(4) SortAggregate
+Input: [key#x, val#x]
+Keys: [key#x]
+Functions: [partial_min(val#x)]
+Aggregate Attributes: [min#x]
+Results: [key#x, min#x]
+
+(5) Exchange
+Input: [key#x, min#x]
+
+(6) Sort [codegen id : 2]
+Input: [key#x, min#x]
+
+(7) SortAggregate
+Input: [key#x, min#x]
+Keys: [key#x]
+Functions: [min(val#x)]
+Aggregate Attributes: [min(val#x)#x]
+Results: [key#x, min(val#x)#x AS min(val)#x]
+
+
-- !query
DROP TABLE explain_temp1
-- !query schema
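Note on the shared formatting above: every `Keys:` / `Functions:` / `Aggregate Attributes:` / `Results:` line in these golden files is produced by the single `verboseStringWithOperatorId` override in the new `BaseAggregateExec` trait, so `HashAggregateExec`, `ObjectHashAggregateExec`, and `SortAggregateExec` all render identically in `EXPLAIN FORMATTED`. The following is a minimal, self-contained Scala sketch of that formatting pattern only; `AggregateExplainDemo`, `FakeAttr`, `verboseString`, and the demo values are hypothetical stand-ins for illustration, not Spark's Catalyst or execution classes.

```scala
// Hypothetical, standalone demo of the mkString-based rendering that
// BaseAggregateExec.verboseStringWithOperatorId uses; not Spark code.
object AggregateExplainDemo {

  // Stand-in for a Catalyst attribute such as key#x or max#x.
  final case class FakeAttr(name: String) {
    override def toString: String = name
  }

  // Mirrors the trait's pattern: each field is rendered with
  // mkString("[", ", ", "]"), which yields e.g. "Keys: [key#x]".
  def verboseString(
      opId: Int,
      nodeName: String,
      input: Seq[FakeAttr],
      keys: Seq[FakeAttr],
      functions: Seq[String],
      aggAttrs: Seq[FakeAttr],
      results: Seq[String]): String = {
    s"""
       |($opId) $nodeName
       |Input: ${input.mkString("[", ", ", "]")}
       |Keys: ${keys.mkString("[", ", ", "]")}
       |Functions: ${functions.mkString("[", ", ", "]")}
       |Aggregate Attributes: ${aggAttrs.mkString("[", ", ", "]")}
       |Results: ${results.mkString("[", ", ", "]")}
       |""".stripMargin
  }

  def main(args: Array[String]): Unit = {
    // Prints a block shaped like the golden output above, e.g.:
    //   (5) HashAggregate
    //   Input: [key#x, val#x]
    //   Keys: [key#x]
    //   Functions: [partial_max(val#x)]
    //   Aggregate Attributes: [max#x]
    //   Results: [key#x, max#x]
    println(verboseString(
      opId = 5,
      nodeName = "HashAggregate",
      input = Seq(FakeAttr("key#x"), FakeAttr("val#x")),
      keys = Seq(FakeAttr("key#x")),
      functions = Seq("partial_max(val#x)"),
      aggAttrs = Seq(FakeAttr("max#x")),
      results = Seq("key#x", "max#x")))
  }
}
```

Centralizing the rendering in one trait is what keeps the three aggregate operators' output layouts identical, which the updated golden files above lock in.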