From 2b48f4b75b78aa89b8b92ac667c9a90c43f09c41 Mon Sep 17 00:00:00 2001 From: Eric Wu <492960551@qq.com> Date: Mon, 27 Jan 2020 22:42:48 +0800 Subject: [PATCH 1/9] Add detailed information for Aggregate operators in EXPLAIN FORMATTED --- .../aggregate/HashAggregateExec.scala | 16 ++ .../aggregate/ObjectHashAggregateExec.scala | 16 ++ .../aggregate/SortAggregateExec.scala | 18 +- .../resources/sql-tests/inputs/explain.sql | 20 ++ .../sql-tests/results/explain.sql.out | 199 +++++++++++++++++- 5 files changed, 267 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index f73e214a6b41f..072ff1bb7bed3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -1110,6 +1110,22 @@ case class HashAggregateExec( override def simpleString(maxFields: Int): String = toString(verbose = false, maxFields) + override def verboseStringWithOperatorId(): String = { + val allAggregateExpressions = aggregateExpressions + + val keyString = groupingExpressions.mkString("[", ", ", "]") + val functionString = allAggregateExpressions.mkString("[", ", ", "]") + val inputString = child.output.mkString("[", ", ", "]") + val outputString = output.mkString("[", ", ", "]") + s""" + |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)} + |Input: $inputString + |Output: $outputString + |Keys: $keyString + |Functions: $functionString + """.stripMargin + } + private def toString(verbose: Boolean, maxFields: Int): String = { val allAggregateExpressions = aggregateExpressions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 4376f6b6edd57..b7631c53dfd9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -143,6 +143,22 @@ case class ObjectHashAggregateExec( override def simpleString(maxFields: Int): String = toString(verbose = false, maxFields) + override def verboseStringWithOperatorId(): String = { + val allAggregateExpressions = aggregateExpressions + + val keyString = groupingExpressions.mkString("[", ", ", "]") + val functionString = allAggregateExpressions.mkString("[", ", ", "]") + val inputString = child.output.mkString("[", ", ", "]") + val outputString = output.mkString("[", ", ", "]") + s""" + |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)} + |Input: $inputString + |Output: $outputString + |Keys: $keyString + |Functions: $functionString + """.stripMargin + } + private def toString(verbose: Boolean, maxFields: Int): String = { val allAggregateExpressions = aggregateExpressions val keyString = truncatedString(groupingExpressions, "[", ", ", "]", maxFields) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index b6e684e62ea5c..1cc59ce2d8a20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.util.truncatedString -import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, ExplainUtils, SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -111,6 +111,22 @@ case class SortAggregateExec( override def verboseString(maxFields: Int): String = toString(verbose = true, maxFields) + override def verboseStringWithOperatorId(): String = { + val allAggregateExpressions = aggregateExpressions + + val keyString = groupingExpressions.mkString("[", ", ", "]") + val functionString = allAggregateExpressions.mkString("[", ", ", "]") + val inputString = child.output.mkString("[", ", ", "]") + val outputString = output.mkString("[", ", ", "]") + s""" + |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)} + |Input: $inputString + |Output: $outputString + |Keys: $keyString + |Functions: $functionString + """.stripMargin + } + private def toString(verbose: Boolean, maxFields: Int): String = { val allAggregateExpressions = aggregateExpressions diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain.sql b/sql/core/src/test/resources/sql-tests/inputs/explain.sql index d5253e3daddb0..a1aadf0dc34fb 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/explain.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/explain.sql @@ -5,6 +5,7 @@ CREATE table explain_temp1 (key int, val int) USING PARQUET; CREATE table explain_temp2 (key int, val int) USING PARQUET; CREATE table explain_temp3 (key int, val int) USING PARQUET; +CREATE table explain_temp4 (key int, val string) USING PARQUET; SET spark.sql.codegen.wholeStage = true; @@ -93,6 +94,25 @@ EXPLAIN FORMATTED CREATE VIEW explain_view AS SELECT key, val FROM explain_temp1; +-- HashAggregate +EXPLAIN FORMATTED + SELECT + COUNT(val) FILTER (WHERE val = 1), + COUNT(key) FILTER (WHERE val > 1) + FROM explain_temp1; + +-- ObjectHashAggregate +EXPLAIN FORMATTED + SELECT key, sort_array(collect_set(val))[0] + FROM explain_temp4 + GROUP BY key; + +-- SortAggregate +EXPLAIN FORMATTED + SELECT key, MIN(val) + FROM explain_temp4 + GROUP BY key; + -- cleanup DROP TABLE explain_temp1; DROP TABLE explain_temp2; diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index 756c14f28a657..07399c3912afa 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 18 +-- Number of queries: 22 -- !query @@ -26,6 +26,14 @@ struct<> +-- !query +CREATE table explain_temp4 (key int, val string) USING PARQUET +-- !query schema +struct<> +-- !query output + + + -- !query SET spark.sql.codegen.wholeStage = true -- !query schema @@ -76,12 +84,18 @@ Input : [key#x, val#x] (5) HashAggregate [codegen id : 1] Input: [key#x, val#x] +Output: [key#x, max#x] +Keys: [key#x] +Functions: [partial_max(val#x)] (6) Exchange Input: [key#x, max#x] (7) HashAggregate [codegen id : 2] Input: [key#x, max#x] +Output: [key#x, max(val)#x] +Keys: [key#x] +Functions: [max(val#x)] (8) Exchange Input: [key#x, max(val)#x] @@ -132,12 +146,18 @@ Input : [key#x, val#x] (5) HashAggregate [codegen id : 1] Input: [key#x, val#x] +Output: [key#x, max#x] +Keys: [key#x] +Functions: [partial_max(val#x)] (6) Exchange Input: [key#x, max#x] (7) HashAggregate [codegen id : 2] Input: [key#x, max#x] +Output: [key#x, max(val)#x, max(val#x)#x] +Keys: [key#x] +Functions: [max(val#x)] (8) Filter [codegen id : 2] Input : [key#x, max(val)#x, max(val#x)#x] @@ -211,12 +231,18 @@ Input : [key#x, val#x] (10) HashAggregate [codegen id : 3] Input: [key#x, val#x] +Output: [key#x, val#x] +Keys: [key#x, val#x] +Functions: [] (11) Exchange Input: [key#x, val#x] (12) HashAggregate [codegen id : 4] Input: [key#x, val#x] +Output: [key#x, val#x] +Keys: [key#x, val#x] +Functions: [] -- !query @@ -413,12 +439,18 @@ Input : [key#x, val#x] (9) HashAggregate [codegen id : 1] Input: [key#x] +Output: [max#x] +Keys: [] +Functions: [partial_max(key#x)] (10) Exchange Input: [max#x] (11) HashAggregate [codegen id : 2] Input: [max#x] +Output: [max(key)#x] +Keys: [] +Functions: [max(key#x)] Subquery:2 Hosting operator id = 7 Hosting Expression = Subquery scalar-subquery#x, [id=#x] * HashAggregate (18) @@ -450,12 +482,18 @@ Input : [key#x, val#x] (16) HashAggregate [codegen id : 1] Input: [key#x] +Output: [max#x] +Keys: [] +Functions: [partial_max(key#x)] (17) Exchange Input: [max#x] (18) HashAggregate [codegen id : 2] Input: [max#x] +Output: [max(key)#x] +Keys: [] +Functions: [max(key#x)] -- !query @@ -523,12 +561,18 @@ Input : [key#x, val#x] (8) HashAggregate [codegen id : 1] Input: [key#x] +Output: [max#x] +Keys: [] +Functions: [partial_max(key#x)] (9) Exchange Input: [max#x] (10) HashAggregate [codegen id : 2] Input: [max#x] +Output: [max(key)#x] +Keys: [] +Functions: [max(key#x)] Subquery:2 Hosting operator id = 3 Hosting Expression = Subquery scalar-subquery#x, [id=#x] * HashAggregate (17) @@ -560,12 +604,18 @@ Input : [key#x, val#x] (15) HashAggregate [codegen id : 1] Input: [key#x] +Output: [max#x] +Keys: [] +Functions: [partial_max(key#x)] (16) Exchange Input: [max#x] (17) HashAggregate [codegen id : 2] Input: [max#x] +Output: [max(key)#x] +Keys: [] +Functions: [max(key#x)] -- !query @@ -615,12 +665,18 @@ Input: [key#x] (6) HashAggregate [codegen id : 1] Input: [key#x] +Output: [sum#x, count#xL] +Keys: [] +Functions: [partial_avg(cast(key#x as bigint))] (7) Exchange Input: [sum#x, count#xL] (8) HashAggregate [codegen id : 2] Input: [sum#x, count#xL] +Output: [avg(key)#x] +Keys: [] +Functions: [avg(cast(key#x as bigint))] Subquery:2 Hosting operator id = 3 Hosting Expression = ReusedSubquery Subquery scalar-subquery#x, [id=#x] @@ -740,18 +796,27 @@ Input : [key#x, val#x] (5) HashAggregate [codegen id : 1] Input: [key#x, val#x] +Output: [key#x, max#x] +Keys: [key#x] +Functions: [partial_max(val#x)] (6) Exchange Input: [key#x, max#x] (7) HashAggregate [codegen id : 4] Input: [key#x, max#x] +Output: [key#x, max(val)#x] +Keys: [key#x] +Functions: [max(val#x)] (8) ReusedExchange [Reuses operator id: 6] Output : ArrayBuffer(key#x, max#x) (9) HashAggregate [codegen id : 3] Input: [key#x, max#x] +Output: [key#x, max(val)#x] +Keys: [key#x] +Functions: [max(val#x)] (10) BroadcastExchange Input: [key#x, max(val)#x] @@ -786,6 +851,138 @@ Output: [] (4) Project +-- !query +EXPLAIN FORMATTED + SELECT + COUNT(val) FILTER (WHERE val = 1), + COUNT(key) FILTER (WHERE val > 1) + FROM explain_temp1 +-- !query schema +struct +-- !query output +== Physical Plan == +* HashAggregate (5) ++- Exchange (4) + +- HashAggregate (3) + +- * ColumnarToRow (2) + +- Scan parquet default.explain_temp1 (1) + + +(1) Scan parquet default.explain_temp1 +Output: [key#x, val#x] +Batched: true +Location [not included in comparison]/{warehouse_dir}/explain_temp1] +ReadSchema: struct + +(2) ColumnarToRow [codegen id : 1] +Input: [key#x, val#x] + +(3) HashAggregate +Input: [key#x, val#x] +Output: [count#xL, count#xL] +Keys: [] +Functions: [partial_count(val#x) FILTER (WHERE (val#x = 1)), partial_count(key#x) FILTER (WHERE (val#x > 1))] + +(4) Exchange +Input: [count#xL, count#xL] + +(5) HashAggregate [codegen id : 2] +Input: [count#xL, count#xL] +Output: [count(val) FILTER (WHERE (val = 1))#xL, count(key) FILTER (WHERE (val > 1))#xL] +Keys: [] +Functions: [count(val#x), count(key#x)] + + +-- !query +EXPLAIN FORMATTED + SELECT key, sort_array(collect_set(val))[0] + FROM explain_temp4 + GROUP BY key +-- !query schema +struct +-- !query output +== Physical Plan == +ObjectHashAggregate (5) ++- Exchange (4) + +- ObjectHashAggregate (3) + +- * ColumnarToRow (2) + +- Scan parquet default.explain_temp4 (1) + + +(1) Scan parquet default.explain_temp4 +Output: [key#x, val#x] +Batched: true +Location [not included in comparison]/{warehouse_dir}/explain_temp4] +ReadSchema: struct + +(2) ColumnarToRow [codegen id : 1] +Input: [key#x, val#x] + +(3) ObjectHashAggregate +Input: [key#x, val#x] +Output: [key#x, buf#x] +Keys: [key#x] +Functions: [partial_collect_set(val#x, 0, 0)] + +(4) Exchange +Input: [key#x, buf#x] + +(5) ObjectHashAggregate +Input: [key#x, buf#x] +Output: [key#x, sort_array(collect_set(val), true)[0]#x] +Keys: [key#x] +Functions: [collect_set(val#x, 0, 0)] + + +-- !query +EXPLAIN FORMATTED + SELECT key, MIN(val) + FROM explain_temp4 + GROUP BY key +-- !query schema +struct +-- !query output +== Physical Plan == +SortAggregate (7) ++- * Sort (6) + +- Exchange (5) + +- SortAggregate (4) + +- * Sort (3) + +- * ColumnarToRow (2) + +- Scan parquet default.explain_temp4 (1) + + +(1) Scan parquet default.explain_temp4 +Output: [key#x, val#x] +Batched: true +Location [not included in comparison]/{warehouse_dir}/explain_temp4] +ReadSchema: struct + +(2) ColumnarToRow [codegen id : 1] +Input: [key#x, val#x] + +(3) Sort [codegen id : 1] +Input: [key#x, val#x] + +(4) SortAggregate +Input: [key#x, val#x] +Output: [key#x, min#x] +Keys: [key#x] +Functions: [partial_min(val#x)] + +(5) Exchange +Input: [key#x, min#x] + +(6) Sort [codegen id : 2] +Input: [key#x, min#x] + +(7) SortAggregate +Input: [key#x, min#x] +Output: [key#x, min(val)#x] +Keys: [key#x] +Functions: [min(val#x)] + + -- !query DROP TABLE explain_temp1 -- !query schema From 4ff35f992a9c50d84da8499f1973574bdce352f0 Mon Sep 17 00:00:00 2001 From: Eric Wu <492960551@qq.com> Date: Tue, 28 Jan 2020 15:43:51 +0800 Subject: [PATCH 2/9] Abstract common logic of aggregate operators --- .../execution/aggregate/AggregateExec.scala | 47 +++++++++++++++++++ .../aggregate/HashAggregateExec.scala | 19 +------- .../aggregate/ObjectHashAggregateExec.scala | 19 +------- .../aggregate/SortAggregateExec.scala | 21 ++------- 4 files changed, 54 insertions(+), 52 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala new file mode 100644 index 0000000000000..94475ea5e630a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.execution.{ExplainUtils, UnaryExecNode} + +/** + * Holds common logic for aggregate operators + */ +abstract class AggregateExec( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression]) + extends UnaryExecNode { + + override def verboseStringWithOperatorId(): String = { + val allAggregateExpressions = aggregateExpressions + + val keyString = groupingExpressions.mkString("[", ", ", "]") + val functionString = allAggregateExpressions.mkString("[", ", ", "]") + val inputString = child.output.mkString("[", ", ", "]") + val outputString = output.mkString("[", ", ", "]") + s""" + |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)} + |Input: $inputString + |Output: $outputString + |Keys: $keyString + |Functions: $functionString + """.stripMargin + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 072ff1bb7bed3..8f0fa924949e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -53,7 +53,8 @@ case class HashAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning { + extends AggregateExec(groupingExpressions, aggregateExpressions) + with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) @@ -1110,22 +1111,6 @@ case class HashAggregateExec( override def simpleString(maxFields: Int): String = toString(verbose = false, maxFields) - override def verboseStringWithOperatorId(): String = { - val allAggregateExpressions = aggregateExpressions - - val keyString = groupingExpressions.mkString("[", ", ", "]") - val functionString = allAggregateExpressions.mkString("[", ", ", "]") - val inputString = child.output.mkString("[", ", ", "]") - val outputString = output.mkString("[", ", ", "]") - s""" - |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)} - |Input: $inputString - |Output: $outputString - |Keys: $keyString - |Functions: $functionString - """.stripMargin - } - private def toString(verbose: Boolean, maxFields: Int): String = { val allAggregateExpressions = aggregateExpressions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index b7631c53dfd9f..07667352b1c92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -67,7 +67,8 @@ case class ObjectHashAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode with AliasAwareOutputPartitioning { + extends AggregateExec(groupingExpressions, aggregateExpressions) + with AliasAwareOutputPartitioning { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) @@ -143,22 +144,6 @@ case class ObjectHashAggregateExec( override def simpleString(maxFields: Int): String = toString(verbose = false, maxFields) - override def verboseStringWithOperatorId(): String = { - val allAggregateExpressions = aggregateExpressions - - val keyString = groupingExpressions.mkString("[", ", ", "]") - val functionString = allAggregateExpressions.mkString("[", ", ", "]") - val inputString = child.output.mkString("[", ", ", "]") - val outputString = output.mkString("[", ", ", "]") - s""" - |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)} - |Input: $inputString - |Output: $outputString - |Keys: $keyString - |Functions: $functionString - """.stripMargin - } - private def toString(verbose: Boolean, maxFields: Int): String = { val allAggregateExpressions = aggregateExpressions val keyString = truncatedString(groupingExpressions, "[", ", ", "]", maxFields) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index 1cc59ce2d8a20..746d0c82b4db2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.util.truncatedString -import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, ExplainUtils, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -38,7 +38,8 @@ case class SortAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode with AliasAwareOutputPartitioning { + extends AggregateExec(groupingExpressions, aggregateExpressions) + with AliasAwareOutputPartitioning { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) @@ -111,22 +112,6 @@ case class SortAggregateExec( override def verboseString(maxFields: Int): String = toString(verbose = true, maxFields) - override def verboseStringWithOperatorId(): String = { - val allAggregateExpressions = aggregateExpressions - - val keyString = groupingExpressions.mkString("[", ", ", "]") - val functionString = allAggregateExpressions.mkString("[", ", ", "]") - val inputString = child.output.mkString("[", ", ", "]") - val outputString = output.mkString("[", ", ", "]") - s""" - |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)} - |Input: $inputString - |Output: $outputString - |Keys: $keyString - |Functions: $functionString - """.stripMargin - } - private def toString(verbose: Boolean, maxFields: Int): String = { val allAggregateExpressions = aggregateExpressions From 426b953ed4d3270c79843568a1f7abc915b727c8 Mon Sep 17 00:00:00 2001 From: Eric Wu <492960551@qq.com> Date: Tue, 4 Feb 2020 18:12:42 +0800 Subject: [PATCH 3/9] Adjust field order and test case --- .../execution/aggregate/AggregateExec.scala | 8 +-- .../resources/sql-tests/inputs/explain.sql | 2 +- .../sql-tests/results/explain.sql.out | 60 +++++++++---------- 3 files changed, 34 insertions(+), 36 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala index 94475ea5e630a..79e34ae990a08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala @@ -30,18 +30,16 @@ abstract class AggregateExec( extends UnaryExecNode { override def verboseStringWithOperatorId(): String = { - val allAggregateExpressions = aggregateExpressions - - val keyString = groupingExpressions.mkString("[", ", ", "]") - val functionString = allAggregateExpressions.mkString("[", ", ", "]") val inputString = child.output.mkString("[", ", ", "]") + val keyString = groupingExpressions.mkString("[", ", ", "]") + val functionString = aggregateExpressions.mkString("[", ", ", "]") val outputString = output.mkString("[", ", ", "]") s""" |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)} |Input: $inputString - |Output: $outputString |Keys: $keyString |Functions: $functionString + |Output: $outputString """.stripMargin } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain.sql b/sql/core/src/test/resources/sql-tests/inputs/explain.sql index a1aadf0dc34fb..0e411877bde78 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/explain.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/explain.sql @@ -97,7 +97,7 @@ EXPLAIN FORMATTED -- HashAggregate EXPLAIN FORMATTED SELECT - COUNT(val) FILTER (WHERE val = 1), + COUNT(val) + SUM(key), COUNT(key) FILTER (WHERE val > 1) FROM explain_temp1; diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index 07399c3912afa..6e342744e0229 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -84,18 +84,18 @@ Input : [key#x, val#x] (5) HashAggregate [codegen id : 1] Input: [key#x, val#x] -Output: [key#x, max#x] Keys: [key#x] Functions: [partial_max(val#x)] +Output: [key#x, max#x] (6) Exchange Input: [key#x, max#x] (7) HashAggregate [codegen id : 2] Input: [key#x, max#x] -Output: [key#x, max(val)#x] Keys: [key#x] Functions: [max(val#x)] +Output: [key#x, max(val)#x] (8) Exchange Input: [key#x, max(val)#x] @@ -146,18 +146,18 @@ Input : [key#x, val#x] (5) HashAggregate [codegen id : 1] Input: [key#x, val#x] -Output: [key#x, max#x] Keys: [key#x] Functions: [partial_max(val#x)] +Output: [key#x, max#x] (6) Exchange Input: [key#x, max#x] (7) HashAggregate [codegen id : 2] Input: [key#x, max#x] -Output: [key#x, max(val)#x, max(val#x)#x] Keys: [key#x] Functions: [max(val#x)] +Output: [key#x, max(val)#x, max(val#x)#x] (8) Filter [codegen id : 2] Input : [key#x, max(val)#x, max(val#x)#x] @@ -231,18 +231,18 @@ Input : [key#x, val#x] (10) HashAggregate [codegen id : 3] Input: [key#x, val#x] -Output: [key#x, val#x] Keys: [key#x, val#x] Functions: [] +Output: [key#x, val#x] (11) Exchange Input: [key#x, val#x] (12) HashAggregate [codegen id : 4] Input: [key#x, val#x] -Output: [key#x, val#x] Keys: [key#x, val#x] Functions: [] +Output: [key#x, val#x] -- !query @@ -439,18 +439,18 @@ Input : [key#x, val#x] (9) HashAggregate [codegen id : 1] Input: [key#x] -Output: [max#x] Keys: [] Functions: [partial_max(key#x)] +Output: [max#x] (10) Exchange Input: [max#x] (11) HashAggregate [codegen id : 2] Input: [max#x] -Output: [max(key)#x] Keys: [] Functions: [max(key#x)] +Output: [max(key)#x] Subquery:2 Hosting operator id = 7 Hosting Expression = Subquery scalar-subquery#x, [id=#x] * HashAggregate (18) @@ -482,18 +482,18 @@ Input : [key#x, val#x] (16) HashAggregate [codegen id : 1] Input: [key#x] -Output: [max#x] Keys: [] Functions: [partial_max(key#x)] +Output: [max#x] (17) Exchange Input: [max#x] (18) HashAggregate [codegen id : 2] Input: [max#x] -Output: [max(key)#x] Keys: [] Functions: [max(key#x)] +Output: [max(key)#x] -- !query @@ -561,18 +561,18 @@ Input : [key#x, val#x] (8) HashAggregate [codegen id : 1] Input: [key#x] -Output: [max#x] Keys: [] Functions: [partial_max(key#x)] +Output: [max#x] (9) Exchange Input: [max#x] (10) HashAggregate [codegen id : 2] Input: [max#x] -Output: [max(key)#x] Keys: [] Functions: [max(key#x)] +Output: [max(key)#x] Subquery:2 Hosting operator id = 3 Hosting Expression = Subquery scalar-subquery#x, [id=#x] * HashAggregate (17) @@ -604,18 +604,18 @@ Input : [key#x, val#x] (15) HashAggregate [codegen id : 1] Input: [key#x] -Output: [max#x] Keys: [] Functions: [partial_max(key#x)] +Output: [max#x] (16) Exchange Input: [max#x] (17) HashAggregate [codegen id : 2] Input: [max#x] -Output: [max(key)#x] Keys: [] Functions: [max(key#x)] +Output: [max(key)#x] -- !query @@ -665,18 +665,18 @@ Input: [key#x] (6) HashAggregate [codegen id : 1] Input: [key#x] -Output: [sum#x, count#xL] Keys: [] Functions: [partial_avg(cast(key#x as bigint))] +Output: [sum#x, count#xL] (7) Exchange Input: [sum#x, count#xL] (8) HashAggregate [codegen id : 2] Input: [sum#x, count#xL] -Output: [avg(key)#x] Keys: [] Functions: [avg(cast(key#x as bigint))] +Output: [avg(key)#x] Subquery:2 Hosting operator id = 3 Hosting Expression = ReusedSubquery Subquery scalar-subquery#x, [id=#x] @@ -796,27 +796,27 @@ Input : [key#x, val#x] (5) HashAggregate [codegen id : 1] Input: [key#x, val#x] -Output: [key#x, max#x] Keys: [key#x] Functions: [partial_max(val#x)] +Output: [key#x, max#x] (6) Exchange Input: [key#x, max#x] (7) HashAggregate [codegen id : 4] Input: [key#x, max#x] -Output: [key#x, max(val)#x] Keys: [key#x] Functions: [max(val#x)] +Output: [key#x, max(val)#x] (8) ReusedExchange [Reuses operator id: 6] Output : ArrayBuffer(key#x, max#x) (9) HashAggregate [codegen id : 3] Input: [key#x, max#x] -Output: [key#x, max(val)#x] Keys: [key#x] Functions: [max(val#x)] +Output: [key#x, max(val)#x] (10) BroadcastExchange Input: [key#x, max(val)#x] @@ -854,7 +854,7 @@ Output: [] -- !query EXPLAIN FORMATTED SELECT - COUNT(val) FILTER (WHERE val = 1), + COUNT(val) + SUM(key), COUNT(key) FILTER (WHERE val > 1) FROM explain_temp1 -- !query schema @@ -879,18 +879,18 @@ Input: [key#x, val#x] (3) HashAggregate Input: [key#x, val#x] -Output: [count#xL, count#xL] Keys: [] -Functions: [partial_count(val#x) FILTER (WHERE (val#x = 1)), partial_count(key#x) FILTER (WHERE (val#x > 1))] +Functions: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))] +Output: [count#xL, sum#xL, count#xL] (4) Exchange -Input: [count#xL, count#xL] +Input: [count#xL, sum#xL, count#xL] (5) HashAggregate [codegen id : 2] -Input: [count#xL, count#xL] -Output: [count(val) FILTER (WHERE (val = 1))#xL, count(key) FILTER (WHERE (val > 1))#xL] +Input: [count#xL, sum#xL, count#xL] Keys: [] -Functions: [count(val#x), count(key#x)] +Functions: [count(val#x), sum(cast(key#x as bigint)), count(key#x)] +Output: [(count(val) + sum(key))#xL, count(key) FILTER (WHERE (val > 1))#xL] -- !query @@ -920,18 +920,18 @@ Input: [key#x, val#x] (3) ObjectHashAggregate Input: [key#x, val#x] -Output: [key#x, buf#x] Keys: [key#x] Functions: [partial_collect_set(val#x, 0, 0)] +Output: [key#x, buf#x] (4) Exchange Input: [key#x, buf#x] (5) ObjectHashAggregate Input: [key#x, buf#x] -Output: [key#x, sort_array(collect_set(val), true)[0]#x] Keys: [key#x] Functions: [collect_set(val#x, 0, 0)] +Output: [key#x, sort_array(collect_set(val), true)[0]#x] -- !query @@ -966,9 +966,9 @@ Input: [key#x, val#x] (4) SortAggregate Input: [key#x, val#x] -Output: [key#x, min#x] Keys: [key#x] Functions: [partial_min(val#x)] +Output: [key#x, min#x] (5) Exchange Input: [key#x, min#x] @@ -978,9 +978,9 @@ Input: [key#x, min#x] (7) SortAggregate Input: [key#x, min#x] -Output: [key#x, min(val)#x] Keys: [key#x] Functions: [min(val#x)] +Output: [key#x, min(val)#x] -- !query From 65aea2ae05c01cb5f134fa998382f29fb01131eb Mon Sep 17 00:00:00 2001 From: Eric Wu <492960551@qq.com> Date: Tue, 4 Feb 2020 21:51:40 +0800 Subject: [PATCH 4/9] Add full aggregation field --- .../execution/aggregate/AggregateExec.scala | 5 +++- .../aggregate/HashAggregateExec.scala | 2 +- .../aggregate/ObjectHashAggregateExec.scala | 2 +- .../aggregate/SortAggregateExec.scala | 2 +- .../resources/sql-tests/inputs/explain.sql | 2 +- .../sql-tests/results/explain.sql.out | 29 +++++++++++++++++-- 6 files changed, 35 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala index 79e34ae990a08..5a174832980ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala @@ -26,19 +26,22 @@ import org.apache.spark.sql.execution.{ExplainUtils, UnaryExecNode} */ abstract class AggregateExec( groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression]) + aggregateExpressions: Seq[AggregateExpression], + resultExpressions: Seq[NamedExpression]) extends UnaryExecNode { override def verboseStringWithOperatorId(): String = { val inputString = child.output.mkString("[", ", ", "]") val keyString = groupingExpressions.mkString("[", ", ", "]") val functionString = aggregateExpressions.mkString("[", ", ", "]") + val resultString = resultExpressions.mkString("[", ", ", "]") val outputString = output.mkString("[", ", ", "]") s""" |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)} |Input: $inputString |Keys: $keyString |Functions: $functionString + |Results: $resultString |Output: $outputString """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 8f0fa924949e5..bbb8634facc91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -53,7 +53,7 @@ case class HashAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends AggregateExec(groupingExpressions, aggregateExpressions) + extends AggregateExec(groupingExpressions, aggregateExpressions, resultExpressions) with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning { private[this] val aggregateBufferAttributes = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 07667352b1c92..73a5de1655ed9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -67,7 +67,7 @@ case class ObjectHashAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends AggregateExec(groupingExpressions, aggregateExpressions) + extends AggregateExec(groupingExpressions, aggregateExpressions, resultExpressions) with AliasAwareOutputPartitioning { private[this] val aggregateBufferAttributes = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index 746d0c82b4db2..13cc71db48c79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -38,7 +38,7 @@ case class SortAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends AggregateExec(groupingExpressions, aggregateExpressions) + extends AggregateExec(groupingExpressions, aggregateExpressions, resultExpressions) with AliasAwareOutputPartitioning { private[this] val aggregateBufferAttributes = { diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain.sql b/sql/core/src/test/resources/sql-tests/inputs/explain.sql index 0e411877bde78..66826fc3e83af 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/explain.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/explain.sql @@ -97,7 +97,7 @@ EXPLAIN FORMATTED -- HashAggregate EXPLAIN FORMATTED SELECT - COUNT(val) + SUM(key), + COUNT(val) + SUM(key) as TOTAL, COUNT(key) FILTER (WHERE val > 1) FROM explain_temp1; diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index 6e342744e0229..d8b74ed0e8f66 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -86,6 +86,7 @@ Input : [key#x, val#x] Input: [key#x, val#x] Keys: [key#x] Functions: [partial_max(val#x)] +Results: [key#x, max#x] Output: [key#x, max#x] (6) Exchange @@ -95,6 +96,7 @@ Input: [key#x, max#x] Input: [key#x, max#x] Keys: [key#x] Functions: [max(val#x)] +Results: [key#x, max(val#x)#x AS max(val)#x] Output: [key#x, max(val)#x] (8) Exchange @@ -148,6 +150,7 @@ Input : [key#x, val#x] Input: [key#x, val#x] Keys: [key#x] Functions: [partial_max(val#x)] +Results: [key#x, max#x] Output: [key#x, max#x] (6) Exchange @@ -157,6 +160,7 @@ Input: [key#x, max#x] Input: [key#x, max#x] Keys: [key#x] Functions: [max(val#x)] +Results: [key#x, max(val#x)#x AS max(val)#x, max(val#x)#x AS max(val#x)#x] Output: [key#x, max(val)#x, max(val#x)#x] (8) Filter [codegen id : 2] @@ -233,6 +237,7 @@ Input : [key#x, val#x] Input: [key#x, val#x] Keys: [key#x, val#x] Functions: [] +Results: [key#x, val#x] Output: [key#x, val#x] (11) Exchange @@ -242,6 +247,7 @@ Input: [key#x, val#x] Input: [key#x, val#x] Keys: [key#x, val#x] Functions: [] +Results: [key#x, val#x] Output: [key#x, val#x] @@ -441,6 +447,7 @@ Input : [key#x, val#x] Input: [key#x] Keys: [] Functions: [partial_max(key#x)] +Results: [max#x] Output: [max#x] (10) Exchange @@ -450,6 +457,7 @@ Input: [max#x] Input: [max#x] Keys: [] Functions: [max(key#x)] +Results: [max(key#x)#x AS max(key)#x] Output: [max(key)#x] Subquery:2 Hosting operator id = 7 Hosting Expression = Subquery scalar-subquery#x, [id=#x] @@ -484,6 +492,7 @@ Input : [key#x, val#x] Input: [key#x] Keys: [] Functions: [partial_max(key#x)] +Results: [max#x] Output: [max#x] (17) Exchange @@ -493,6 +502,7 @@ Input: [max#x] Input: [max#x] Keys: [] Functions: [max(key#x)] +Results: [max(key#x)#x AS max(key)#x] Output: [max(key)#x] @@ -563,6 +573,7 @@ Input : [key#x, val#x] Input: [key#x] Keys: [] Functions: [partial_max(key#x)] +Results: [max#x] Output: [max#x] (9) Exchange @@ -572,6 +583,7 @@ Input: [max#x] Input: [max#x] Keys: [] Functions: [max(key#x)] +Results: [max(key#x)#x AS max(key)#x] Output: [max(key)#x] Subquery:2 Hosting operator id = 3 Hosting Expression = Subquery scalar-subquery#x, [id=#x] @@ -606,6 +618,7 @@ Input : [key#x, val#x] Input: [key#x] Keys: [] Functions: [partial_max(key#x)] +Results: [max#x] Output: [max#x] (16) Exchange @@ -615,6 +628,7 @@ Input: [max#x] Input: [max#x] Keys: [] Functions: [max(key#x)] +Results: [max(key#x)#x AS max(key)#x] Output: [max(key)#x] @@ -667,6 +681,7 @@ Input: [key#x] Input: [key#x] Keys: [] Functions: [partial_avg(cast(key#x as bigint))] +Results: [sum#x, count#xL] Output: [sum#x, count#xL] (7) Exchange @@ -676,6 +691,7 @@ Input: [sum#x, count#xL] Input: [sum#x, count#xL] Keys: [] Functions: [avg(cast(key#x as bigint))] +Results: [avg(cast(key#x as bigint))#x AS avg(key)#x] Output: [avg(key)#x] Subquery:2 Hosting operator id = 3 Hosting Expression = ReusedSubquery Subquery scalar-subquery#x, [id=#x] @@ -798,6 +814,7 @@ Input : [key#x, val#x] Input: [key#x, val#x] Keys: [key#x] Functions: [partial_max(val#x)] +Results: [key#x, max#x] Output: [key#x, max#x] (6) Exchange @@ -807,6 +824,7 @@ Input: [key#x, max#x] Input: [key#x, max#x] Keys: [key#x] Functions: [max(val#x)] +Results: [key#x, max(val#x)#x AS max(val)#x] Output: [key#x, max(val)#x] (8) ReusedExchange [Reuses operator id: 6] @@ -816,6 +834,7 @@ Output : ArrayBuffer(key#x, max#x) Input: [key#x, max#x] Keys: [key#x] Functions: [max(val#x)] +Results: [key#x, max(val#x)#x AS max(val)#x] Output: [key#x, max(val)#x] (10) BroadcastExchange @@ -854,7 +873,7 @@ Output: [] -- !query EXPLAIN FORMATTED SELECT - COUNT(val) + SUM(key), + COUNT(val) + SUM(key) as TOTAL, COUNT(key) FILTER (WHERE val > 1) FROM explain_temp1 -- !query schema @@ -881,6 +900,7 @@ Input: [key#x, val#x] Input: [key#x, val#x] Keys: [] Functions: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))] +Results: [count#xL, sum#xL, count#xL] Output: [count#xL, sum#xL, count#xL] (4) Exchange @@ -890,7 +910,8 @@ Input: [count#xL, sum#xL, count#xL] Input: [count#xL, sum#xL, count#xL] Keys: [] Functions: [count(val#x), sum(cast(key#x as bigint)), count(key#x)] -Output: [(count(val) + sum(key))#xL, count(key) FILTER (WHERE (val > 1))#xL] +Results: [(count(val#x)#xL + sum(cast(key#x as bigint))#xL) AS TOTAL#xL, count(key#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL] +Output: [TOTAL#xL, count(key) FILTER (WHERE (val > 1))#xL] -- !query @@ -922,6 +943,7 @@ Input: [key#x, val#x] Input: [key#x, val#x] Keys: [key#x] Functions: [partial_collect_set(val#x, 0, 0)] +Results: [key#x, buf#x] Output: [key#x, buf#x] (4) Exchange @@ -931,6 +953,7 @@ Input: [key#x, buf#x] Input: [key#x, buf#x] Keys: [key#x] Functions: [collect_set(val#x, 0, 0)] +Results: [key#x, sort_array(collect_set(val#x, 0, 0)#x, true)[0] AS sort_array(collect_set(val), true)[0]#x] Output: [key#x, sort_array(collect_set(val), true)[0]#x] @@ -968,6 +991,7 @@ Input: [key#x, val#x] Input: [key#x, val#x] Keys: [key#x] Functions: [partial_min(val#x)] +Results: [key#x, min#x] Output: [key#x, min#x] (5) Exchange @@ -980,6 +1004,7 @@ Input: [key#x, min#x] Input: [key#x, min#x] Keys: [key#x] Functions: [min(val#x)] +Results: [key#x, min(val#x)#x AS min(val)#x] Output: [key#x, min(val)#x] From 2cc5daf5cce362d1ff3e7dd564d578e92191cae7 Mon Sep 17 00:00:00 2001 From: Eric Wu <492960551@qq.com> Date: Wed, 5 Feb 2020 14:26:07 +0800 Subject: [PATCH 5/9] Improve abstract class style --- .../spark/sql/execution/aggregate/AggregateExec.scala | 9 ++++----- .../sql/execution/aggregate/HashAggregateExec.scala | 3 +-- .../execution/aggregate/ObjectHashAggregateExec.scala | 3 +-- .../sql/execution/aggregate/SortAggregateExec.scala | 3 +-- 4 files changed, 7 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala index 5a174832980ce..53a88675efe25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala @@ -24,11 +24,10 @@ import org.apache.spark.sql.execution.{ExplainUtils, UnaryExecNode} /** * Holds common logic for aggregate operators */ -abstract class AggregateExec( - groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression], - resultExpressions: Seq[NamedExpression]) - extends UnaryExecNode { +abstract class BaseAggregateExec extends UnaryExecNode { + val groupingExpressions: Seq[NamedExpression] + val aggregateExpressions: Seq[AggregateExpression] + val resultExpressions: Seq[NamedExpression] override def verboseStringWithOperatorId(): String = { val inputString = child.output.mkString("[", ", ", "]") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index bbb8634facc91..7a26fd7a8541a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -53,8 +53,7 @@ case class HashAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends AggregateExec(groupingExpressions, aggregateExpressions, resultExpressions) - with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning { + extends BaseAggregateExec with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 73a5de1655ed9..3fb58eb2cc8ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -67,8 +67,7 @@ case class ObjectHashAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends AggregateExec(groupingExpressions, aggregateExpressions, resultExpressions) - with AliasAwareOutputPartitioning { + extends BaseAggregateExec with AliasAwareOutputPartitioning { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index 13cc71db48c79..77ed469016fa3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -38,8 +38,7 @@ case class SortAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends AggregateExec(groupingExpressions, aggregateExpressions, resultExpressions) - with AliasAwareOutputPartitioning { + extends BaseAggregateExec with AliasAwareOutputPartitioning { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) From 5e5d4819e1500435f01f652ff763b31af69d65b7 Mon Sep 17 00:00:00 2001 From: Eric Wu <492960551@qq.com> Date: Thu, 6 Feb 2020 21:16:35 +0800 Subject: [PATCH 6/9] Add function buffer attributes --- ...gateExec.scala => BaseAggregateExec.scala} | 8 ++- .../aggregate/HashAggregateExec.scala | 4 -- .../aggregate/ObjectHashAggregateExec.scala | 4 -- .../aggregate/SortAggregateExec.scala | 4 -- .../resources/sql-tests/inputs/explain.sql | 2 +- .../sql-tests/results/explain.sql.out | 66 +++++++++---------- 6 files changed, 40 insertions(+), 48 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/{AggregateExec.scala => BaseAggregateExec.scala} (88%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala similarity index 88% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala index 53a88675efe25..7528ff6fa18fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala @@ -29,19 +29,23 @@ abstract class BaseAggregateExec extends UnaryExecNode { val aggregateExpressions: Seq[AggregateExpression] val resultExpressions: Seq[NamedExpression] + protected val aggregateBufferAttributes = { + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + } + override def verboseStringWithOperatorId(): String = { val inputString = child.output.mkString("[", ", ", "]") val keyString = groupingExpressions.mkString("[", ", ", "]") val functionString = aggregateExpressions.mkString("[", ", ", "]") + val funcBufferAttrString = aggregateBufferAttributes.mkString("[", ", ", "]") val resultString = resultExpressions.mkString("[", ", ", "]") - val outputString = output.mkString("[", ", ", "]") s""" |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)} |Input: $inputString |Keys: $keyString |Functions: $functionString + |FuncBufferAttrs: $funcBufferAttrString |Results: $resultString - |Output: $outputString """.stripMargin } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 7a26fd7a8541a..06d3d366e1289 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -55,10 +55,6 @@ case class HashAggregateExec( child: SparkPlan) extends BaseAggregateExec with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning { - private[this] val aggregateBufferAttributes = { - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - } - require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes)) override lazy val allAttributes: AttributeSeq = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 3fb58eb2cc8ba..fadadbf581f5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -69,10 +69,6 @@ case class ObjectHashAggregateExec( child: SparkPlan) extends BaseAggregateExec with AliasAwareOutputPartitioning { - private[this] val aggregateBufferAttributes = { - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - } - override lazy val allAttributes: AttributeSeq = child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index 77ed469016fa3..ffeb77bf212dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -40,10 +40,6 @@ case class SortAggregateExec( child: SparkPlan) extends BaseAggregateExec with AliasAwareOutputPartitioning { - private[this] val aggregateBufferAttributes = { - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - } - override def producedAttributes: AttributeSet = AttributeSet(aggregateAttributes) ++ AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain.sql b/sql/core/src/test/resources/sql-tests/inputs/explain.sql index 66826fc3e83af..497b61c6134a2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/explain.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/explain.sql @@ -62,7 +62,7 @@ EXPLAIN FORMATTED FROM explain_temp2 WHERE val > 0) OR - key = (SELECT max(key) + key = (SELECT avg(key) FROM explain_temp3 WHERE val > 0); diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index d8b74ed0e8f66..ae1ee126054d6 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -86,8 +86,8 @@ Input : [key#x, val#x] Input: [key#x, val#x] Keys: [key#x] Functions: [partial_max(val#x)] +FuncBufferAttrs: [max#x] Results: [key#x, max#x] -Output: [key#x, max#x] (6) Exchange Input: [key#x, max#x] @@ -96,8 +96,8 @@ Input: [key#x, max#x] Input: [key#x, max#x] Keys: [key#x] Functions: [max(val#x)] +FuncBufferAttrs: [max#x] Results: [key#x, max(val#x)#x AS max(val)#x] -Output: [key#x, max(val)#x] (8) Exchange Input: [key#x, max(val)#x] @@ -150,8 +150,8 @@ Input : [key#x, val#x] Input: [key#x, val#x] Keys: [key#x] Functions: [partial_max(val#x)] +FuncBufferAttrs: [max#x] Results: [key#x, max#x] -Output: [key#x, max#x] (6) Exchange Input: [key#x, max#x] @@ -160,8 +160,8 @@ Input: [key#x, max#x] Input: [key#x, max#x] Keys: [key#x] Functions: [max(val#x)] +FuncBufferAttrs: [max#x] Results: [key#x, max(val#x)#x AS max(val)#x, max(val#x)#x AS max(val#x)#x] -Output: [key#x, max(val)#x, max(val#x)#x] (8) Filter [codegen id : 2] Input : [key#x, max(val)#x, max(val#x)#x] @@ -237,8 +237,8 @@ Input : [key#x, val#x] Input: [key#x, val#x] Keys: [key#x, val#x] Functions: [] +FuncBufferAttrs: [] Results: [key#x, val#x] -Output: [key#x, val#x] (11) Exchange Input: [key#x, val#x] @@ -247,8 +247,8 @@ Input: [key#x, val#x] Input: [key#x, val#x] Keys: [key#x, val#x] Functions: [] +FuncBufferAttrs: [] Results: [key#x, val#x] -Output: [key#x, val#x] -- !query @@ -447,8 +447,8 @@ Input : [key#x, val#x] Input: [key#x] Keys: [] Functions: [partial_max(key#x)] +FuncBufferAttrs: [max#x] Results: [max#x] -Output: [max#x] (10) Exchange Input: [max#x] @@ -457,8 +457,8 @@ Input: [max#x] Input: [max#x] Keys: [] Functions: [max(key#x)] +FuncBufferAttrs: [max#x] Results: [max(key#x)#x AS max(key)#x] -Output: [max(key)#x] Subquery:2 Hosting operator id = 7 Hosting Expression = Subquery scalar-subquery#x, [id=#x] * HashAggregate (18) @@ -492,8 +492,8 @@ Input : [key#x, val#x] Input: [key#x] Keys: [] Functions: [partial_max(key#x)] +FuncBufferAttrs: [max#x] Results: [max#x] -Output: [max#x] (17) Exchange Input: [max#x] @@ -502,8 +502,8 @@ Input: [max#x] Input: [max#x] Keys: [] Functions: [max(key#x)] +FuncBufferAttrs: [max#x] Results: [max(key#x)#x AS max(key)#x] -Output: [max(key)#x] -- !query @@ -514,7 +514,7 @@ EXPLAIN FORMATTED FROM explain_temp2 WHERE val > 0) OR - key = (SELECT max(key) + key = (SELECT avg(key) FROM explain_temp3 WHERE val > 0) -- !query schema @@ -537,7 +537,7 @@ Input: [key#x, val#x] (3) Filter [codegen id : 1] Input : [key#x, val#x] -Condition : ((key#x = Subquery scalar-subquery#x, [id=#x]) OR (key#x = Subquery scalar-subquery#x, [id=#x])) +Condition : ((key#x = Subquery scalar-subquery#x, [id=#x]) OR (cast(key#x as double) = Subquery scalar-subquery#x, [id=#x])) ===== Subqueries ===== @@ -573,8 +573,8 @@ Input : [key#x, val#x] Input: [key#x] Keys: [] Functions: [partial_max(key#x)] +FuncBufferAttrs: [max#x] Results: [max#x] -Output: [max#x] (9) Exchange Input: [max#x] @@ -583,8 +583,8 @@ Input: [max#x] Input: [max#x] Keys: [] Functions: [max(key#x)] +FuncBufferAttrs: [max#x] Results: [max(key#x)#x AS max(key)#x] -Output: [max(key)#x] Subquery:2 Hosting operator id = 3 Hosting Expression = Subquery scalar-subquery#x, [id=#x] * HashAggregate (17) @@ -617,19 +617,19 @@ Input : [key#x, val#x] (15) HashAggregate [codegen id : 1] Input: [key#x] Keys: [] -Functions: [partial_max(key#x)] -Results: [max#x] -Output: [max#x] +Functions: [partial_avg(cast(key#x as bigint))] +FuncBufferAttrs: [sum#x, count#xL] +Results: [sum#x, count#xL] (16) Exchange -Input: [max#x] +Input: [sum#x, count#xL] (17) HashAggregate [codegen id : 2] -Input: [max#x] +Input: [sum#x, count#xL] Keys: [] -Functions: [max(key#x)] -Results: [max(key#x)#x AS max(key)#x] -Output: [max(key)#x] +Functions: [avg(cast(key#x as bigint))] +FuncBufferAttrs: [sum#x, count#xL] +Results: [avg(cast(key#x as bigint))#x AS avg(key)#x] -- !query @@ -681,8 +681,8 @@ Input: [key#x] Input: [key#x] Keys: [] Functions: [partial_avg(cast(key#x as bigint))] +FuncBufferAttrs: [sum#x, count#xL] Results: [sum#x, count#xL] -Output: [sum#x, count#xL] (7) Exchange Input: [sum#x, count#xL] @@ -691,8 +691,8 @@ Input: [sum#x, count#xL] Input: [sum#x, count#xL] Keys: [] Functions: [avg(cast(key#x as bigint))] +FuncBufferAttrs: [sum#x, count#xL] Results: [avg(cast(key#x as bigint))#x AS avg(key)#x] -Output: [avg(key)#x] Subquery:2 Hosting operator id = 3 Hosting Expression = ReusedSubquery Subquery scalar-subquery#x, [id=#x] @@ -814,8 +814,8 @@ Input : [key#x, val#x] Input: [key#x, val#x] Keys: [key#x] Functions: [partial_max(val#x)] +FuncBufferAttrs: [max#x] Results: [key#x, max#x] -Output: [key#x, max#x] (6) Exchange Input: [key#x, max#x] @@ -824,8 +824,8 @@ Input: [key#x, max#x] Input: [key#x, max#x] Keys: [key#x] Functions: [max(val#x)] +FuncBufferAttrs: [max#x] Results: [key#x, max(val#x)#x AS max(val)#x] -Output: [key#x, max(val)#x] (8) ReusedExchange [Reuses operator id: 6] Output : ArrayBuffer(key#x, max#x) @@ -834,8 +834,8 @@ Output : ArrayBuffer(key#x, max#x) Input: [key#x, max#x] Keys: [key#x] Functions: [max(val#x)] +FuncBufferAttrs: [max#x] Results: [key#x, max(val#x)#x AS max(val)#x] -Output: [key#x, max(val)#x] (10) BroadcastExchange Input: [key#x, max(val)#x] @@ -900,8 +900,8 @@ Input: [key#x, val#x] Input: [key#x, val#x] Keys: [] Functions: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))] +FuncBufferAttrs: [count#xL, sum#xL, count#xL] Results: [count#xL, sum#xL, count#xL] -Output: [count#xL, sum#xL, count#xL] (4) Exchange Input: [count#xL, sum#xL, count#xL] @@ -910,8 +910,8 @@ Input: [count#xL, sum#xL, count#xL] Input: [count#xL, sum#xL, count#xL] Keys: [] Functions: [count(val#x), sum(cast(key#x as bigint)), count(key#x)] +FuncBufferAttrs: [count#xL, sum#xL, count#xL] Results: [(count(val#x)#xL + sum(cast(key#x as bigint))#xL) AS TOTAL#xL, count(key#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL] -Output: [TOTAL#xL, count(key) FILTER (WHERE (val > 1))#xL] -- !query @@ -943,8 +943,8 @@ Input: [key#x, val#x] Input: [key#x, val#x] Keys: [key#x] Functions: [partial_collect_set(val#x, 0, 0)] +FuncBufferAttrs: [buf#x] Results: [key#x, buf#x] -Output: [key#x, buf#x] (4) Exchange Input: [key#x, buf#x] @@ -953,8 +953,8 @@ Input: [key#x, buf#x] Input: [key#x, buf#x] Keys: [key#x] Functions: [collect_set(val#x, 0, 0)] +FuncBufferAttrs: [buf#x] Results: [key#x, sort_array(collect_set(val#x, 0, 0)#x, true)[0] AS sort_array(collect_set(val), true)[0]#x] -Output: [key#x, sort_array(collect_set(val), true)[0]#x] -- !query @@ -991,8 +991,8 @@ Input: [key#x, val#x] Input: [key#x, val#x] Keys: [key#x] Functions: [partial_min(val#x)] +FuncBufferAttrs: [min#x] Results: [key#x, min#x] -Output: [key#x, min#x] (5) Exchange Input: [key#x, min#x] @@ -1004,8 +1004,8 @@ Input: [key#x, min#x] Input: [key#x, min#x] Keys: [key#x] Functions: [min(val#x)] +FuncBufferAttrs: [min#x] Results: [key#x, min(val#x)#x AS min(val)#x] -Output: [key#x, min(val)#x] -- !query From 70cb2df407d788c8dd730d5aa669039948d71192 Mon Sep 17 00:00:00 2001 From: Eric Wu <492960551@qq.com> Date: Thu, 6 Feb 2020 22:09:38 +0800 Subject: [PATCH 7/9] Shwo aggregation attributes --- .../aggregate/BaseAggregateExec.scala | 11 ++-- .../aggregate/HashAggregateExec.scala | 4 ++ .../aggregate/ObjectHashAggregateExec.scala | 4 ++ .../aggregate/SortAggregateExec.scala | 4 ++ .../sql-tests/results/explain.sql.out | 50 +++++++++---------- 5 files changed, 41 insertions(+), 32 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala index 7528ff6fa18fc..904ffc39c9524 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.aggregate -import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.execution.{ExplainUtils, UnaryExecNode} @@ -27,24 +27,21 @@ import org.apache.spark.sql.execution.{ExplainUtils, UnaryExecNode} abstract class BaseAggregateExec extends UnaryExecNode { val groupingExpressions: Seq[NamedExpression] val aggregateExpressions: Seq[AggregateExpression] + val aggregateAttributes: Seq[Attribute] val resultExpressions: Seq[NamedExpression] - protected val aggregateBufferAttributes = { - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - } - override def verboseStringWithOperatorId(): String = { val inputString = child.output.mkString("[", ", ", "]") val keyString = groupingExpressions.mkString("[", ", ", "]") val functionString = aggregateExpressions.mkString("[", ", ", "]") - val funcBufferAttrString = aggregateBufferAttributes.mkString("[", ", ", "]") + val aggregateAttributeString = aggregateAttributes.mkString("[", ", ", "]") val resultString = resultExpressions.mkString("[", ", ", "]") s""" |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)} |Input: $inputString |Keys: $keyString |Functions: $functionString - |FuncBufferAttrs: $funcBufferAttrString + |Aggregate Attributes: $aggregateAttributeString |Results: $resultString """.stripMargin } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 06d3d366e1289..7a26fd7a8541a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -55,6 +55,10 @@ case class HashAggregateExec( child: SparkPlan) extends BaseAggregateExec with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning { + private[this] val aggregateBufferAttributes = { + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + } + require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes)) override lazy val allAttributes: AttributeSeq = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index fadadbf581f5e..3fb58eb2cc8ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -69,6 +69,10 @@ case class ObjectHashAggregateExec( child: SparkPlan) extends BaseAggregateExec with AliasAwareOutputPartitioning { + private[this] val aggregateBufferAttributes = { + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + } + override lazy val allAttributes: AttributeSeq = child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index ffeb77bf212dd..77ed469016fa3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -40,6 +40,10 @@ case class SortAggregateExec( child: SparkPlan) extends BaseAggregateExec with AliasAwareOutputPartitioning { + private[this] val aggregateBufferAttributes = { + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + } + override def producedAttributes: AttributeSet = AttributeSet(aggregateAttributes) ++ AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index ae1ee126054d6..bc28d7f87bf00 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -86,7 +86,7 @@ Input : [key#x, val#x] Input: [key#x, val#x] Keys: [key#x] Functions: [partial_max(val#x)] -FuncBufferAttrs: [max#x] +Aggregate Attributes: [max#x] Results: [key#x, max#x] (6) Exchange @@ -96,7 +96,7 @@ Input: [key#x, max#x] Input: [key#x, max#x] Keys: [key#x] Functions: [max(val#x)] -FuncBufferAttrs: [max#x] +Aggregate Attributes: [max(val#x)#x] Results: [key#x, max(val#x)#x AS max(val)#x] (8) Exchange @@ -150,7 +150,7 @@ Input : [key#x, val#x] Input: [key#x, val#x] Keys: [key#x] Functions: [partial_max(val#x)] -FuncBufferAttrs: [max#x] +Aggregate Attributes: [max#x] Results: [key#x, max#x] (6) Exchange @@ -160,7 +160,7 @@ Input: [key#x, max#x] Input: [key#x, max#x] Keys: [key#x] Functions: [max(val#x)] -FuncBufferAttrs: [max#x] +Aggregate Attributes: [max(val#x)#x] Results: [key#x, max(val#x)#x AS max(val)#x, max(val#x)#x AS max(val#x)#x] (8) Filter [codegen id : 2] @@ -237,7 +237,7 @@ Input : [key#x, val#x] Input: [key#x, val#x] Keys: [key#x, val#x] Functions: [] -FuncBufferAttrs: [] +Aggregate Attributes: [] Results: [key#x, val#x] (11) Exchange @@ -247,7 +247,7 @@ Input: [key#x, val#x] Input: [key#x, val#x] Keys: [key#x, val#x] Functions: [] -FuncBufferAttrs: [] +Aggregate Attributes: [] Results: [key#x, val#x] @@ -447,7 +447,7 @@ Input : [key#x, val#x] Input: [key#x] Keys: [] Functions: [partial_max(key#x)] -FuncBufferAttrs: [max#x] +Aggregate Attributes: [max#x] Results: [max#x] (10) Exchange @@ -457,7 +457,7 @@ Input: [max#x] Input: [max#x] Keys: [] Functions: [max(key#x)] -FuncBufferAttrs: [max#x] +Aggregate Attributes: [max(key#x)#x] Results: [max(key#x)#x AS max(key)#x] Subquery:2 Hosting operator id = 7 Hosting Expression = Subquery scalar-subquery#x, [id=#x] @@ -492,7 +492,7 @@ Input : [key#x, val#x] Input: [key#x] Keys: [] Functions: [partial_max(key#x)] -FuncBufferAttrs: [max#x] +Aggregate Attributes: [max#x] Results: [max#x] (17) Exchange @@ -502,7 +502,7 @@ Input: [max#x] Input: [max#x] Keys: [] Functions: [max(key#x)] -FuncBufferAttrs: [max#x] +Aggregate Attributes: [max(key#x)#x] Results: [max(key#x)#x AS max(key)#x] @@ -573,7 +573,7 @@ Input : [key#x, val#x] Input: [key#x] Keys: [] Functions: [partial_max(key#x)] -FuncBufferAttrs: [max#x] +Aggregate Attributes: [max#x] Results: [max#x] (9) Exchange @@ -583,7 +583,7 @@ Input: [max#x] Input: [max#x] Keys: [] Functions: [max(key#x)] -FuncBufferAttrs: [max#x] +Aggregate Attributes: [max(key#x)#x] Results: [max(key#x)#x AS max(key)#x] Subquery:2 Hosting operator id = 3 Hosting Expression = Subquery scalar-subquery#x, [id=#x] @@ -618,7 +618,7 @@ Input : [key#x, val#x] Input: [key#x] Keys: [] Functions: [partial_avg(cast(key#x as bigint))] -FuncBufferAttrs: [sum#x, count#xL] +Aggregate Attributes: [sum#x, count#xL] Results: [sum#x, count#xL] (16) Exchange @@ -628,7 +628,7 @@ Input: [sum#x, count#xL] Input: [sum#x, count#xL] Keys: [] Functions: [avg(cast(key#x as bigint))] -FuncBufferAttrs: [sum#x, count#xL] +Aggregate Attributes: [avg(cast(key#x as bigint))#x] Results: [avg(cast(key#x as bigint))#x AS avg(key)#x] @@ -681,7 +681,7 @@ Input: [key#x] Input: [key#x] Keys: [] Functions: [partial_avg(cast(key#x as bigint))] -FuncBufferAttrs: [sum#x, count#xL] +Aggregate Attributes: [sum#x, count#xL] Results: [sum#x, count#xL] (7) Exchange @@ -691,7 +691,7 @@ Input: [sum#x, count#xL] Input: [sum#x, count#xL] Keys: [] Functions: [avg(cast(key#x as bigint))] -FuncBufferAttrs: [sum#x, count#xL] +Aggregate Attributes: [avg(cast(key#x as bigint))#x] Results: [avg(cast(key#x as bigint))#x AS avg(key)#x] Subquery:2 Hosting operator id = 3 Hosting Expression = ReusedSubquery Subquery scalar-subquery#x, [id=#x] @@ -814,7 +814,7 @@ Input : [key#x, val#x] Input: [key#x, val#x] Keys: [key#x] Functions: [partial_max(val#x)] -FuncBufferAttrs: [max#x] +Aggregate Attributes: [max#x] Results: [key#x, max#x] (6) Exchange @@ -824,7 +824,7 @@ Input: [key#x, max#x] Input: [key#x, max#x] Keys: [key#x] Functions: [max(val#x)] -FuncBufferAttrs: [max#x] +Aggregate Attributes: [max(val#x)#x] Results: [key#x, max(val#x)#x AS max(val)#x] (8) ReusedExchange [Reuses operator id: 6] @@ -834,7 +834,7 @@ Output : ArrayBuffer(key#x, max#x) Input: [key#x, max#x] Keys: [key#x] Functions: [max(val#x)] -FuncBufferAttrs: [max#x] +Aggregate Attributes: [max(val#x)#x] Results: [key#x, max(val#x)#x AS max(val)#x] (10) BroadcastExchange @@ -900,7 +900,7 @@ Input: [key#x, val#x] Input: [key#x, val#x] Keys: [] Functions: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))] -FuncBufferAttrs: [count#xL, sum#xL, count#xL] +Aggregate Attributes: [count#xL, sum#xL, count#xL] Results: [count#xL, sum#xL, count#xL] (4) Exchange @@ -910,7 +910,7 @@ Input: [count#xL, sum#xL, count#xL] Input: [count#xL, sum#xL, count#xL] Keys: [] Functions: [count(val#x), sum(cast(key#x as bigint)), count(key#x)] -FuncBufferAttrs: [count#xL, sum#xL, count#xL] +Aggregate Attributes: [count(val#x)#xL, sum(cast(key#x as bigint))#xL, count(key#x)#xL] Results: [(count(val#x)#xL + sum(cast(key#x as bigint))#xL) AS TOTAL#xL, count(key#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL] @@ -943,7 +943,7 @@ Input: [key#x, val#x] Input: [key#x, val#x] Keys: [key#x] Functions: [partial_collect_set(val#x, 0, 0)] -FuncBufferAttrs: [buf#x] +Aggregate Attributes: [buf#x] Results: [key#x, buf#x] (4) Exchange @@ -953,7 +953,7 @@ Input: [key#x, buf#x] Input: [key#x, buf#x] Keys: [key#x] Functions: [collect_set(val#x, 0, 0)] -FuncBufferAttrs: [buf#x] +Aggregate Attributes: [collect_set(val#x, 0, 0)#x] Results: [key#x, sort_array(collect_set(val#x, 0, 0)#x, true)[0] AS sort_array(collect_set(val), true)[0]#x] @@ -991,7 +991,7 @@ Input: [key#x, val#x] Input: [key#x, val#x] Keys: [key#x] Functions: [partial_min(val#x)] -FuncBufferAttrs: [min#x] +Aggregate Attributes: [min#x] Results: [key#x, min#x] (5) Exchange @@ -1004,7 +1004,7 @@ Input: [key#x, min#x] Input: [key#x, min#x] Keys: [key#x] Functions: [min(val#x)] -FuncBufferAttrs: [min#x] +Aggregate Attributes: [min(val#x)#x] Results: [key#x, min(val#x)#x AS min(val)#x] From 5b91b19e4a518f537c9fd385c209dade477eead4 Mon Sep 17 00:00:00 2001 From: Eric Wu <492960551@qq.com> Date: Sat, 8 Feb 2020 17:35:14 +0800 Subject: [PATCH 8/9] Add override to improve readability --- .../spark/sql/execution/aggregate/HashAggregateExec.scala | 8 ++++---- .../sql/execution/aggregate/ObjectHashAggregateExec.scala | 8 ++++---- .../spark/sql/execution/aggregate/SortAggregateExec.scala | 8 ++++---- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 7a26fd7a8541a..a5123e70832cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -47,11 +47,11 @@ import org.apache.spark.util.Utils */ case class HashAggregateExec( requiredChildDistributionExpressions: Option[Seq[Expression]], - groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression], - aggregateAttributes: Seq[Attribute], + override val groupingExpressions: Seq[NamedExpression], + override val aggregateExpressions: Seq[AggregateExpression], + override val aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], + override val resultExpressions: Seq[NamedExpression], child: SparkPlan) extends BaseAggregateExec with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 3fb58eb2cc8ba..019ee43ff2e46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -61,11 +61,11 @@ import org.apache.spark.sql.execution.metric.SQLMetrics */ case class ObjectHashAggregateExec( requiredChildDistributionExpressions: Option[Seq[Expression]], - groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression], - aggregateAttributes: Seq[Attribute], + override val groupingExpressions: Seq[NamedExpression], + override val aggregateExpressions: Seq[AggregateExpression], + override val aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], + override val resultExpressions: Seq[NamedExpression], child: SparkPlan) extends BaseAggregateExec with AliasAwareOutputPartitioning { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index 77ed469016fa3..2ee8ace1de1c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -32,11 +32,11 @@ import org.apache.spark.sql.execution.metric.SQLMetrics */ case class SortAggregateExec( requiredChildDistributionExpressions: Option[Seq[Expression]], - groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression], - aggregateAttributes: Seq[Attribute], + override val groupingExpressions: Seq[NamedExpression], + override val aggregateExpressions: Seq[AggregateExpression], + override val aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, - resultExpressions: Seq[NamedExpression], + override val resultExpressions: Seq[NamedExpression], child: SparkPlan) extends BaseAggregateExec with AliasAwareOutputPartitioning { From dd0988adfaf6b9fc09283ec8cfa1c14bc3f71f8e Mon Sep 17 00:00:00 2001 From: Eric Wu <492960551@qq.com> Date: Mon, 10 Feb 2020 19:16:50 +0800 Subject: [PATCH 9/9] Address comments of abstraction --- .../sql/execution/aggregate/BaseAggregateExec.scala | 10 +++++----- .../sql/execution/aggregate/HashAggregateExec.scala | 8 ++++---- .../execution/aggregate/ObjectHashAggregateExec.scala | 8 ++++---- .../sql/execution/aggregate/SortAggregateExec.scala | 8 ++++---- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala index 904ffc39c9524..0eaa0f53fdacd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala @@ -24,11 +24,11 @@ import org.apache.spark.sql.execution.{ExplainUtils, UnaryExecNode} /** * Holds common logic for aggregate operators */ -abstract class BaseAggregateExec extends UnaryExecNode { - val groupingExpressions: Seq[NamedExpression] - val aggregateExpressions: Seq[AggregateExpression] - val aggregateAttributes: Seq[Attribute] - val resultExpressions: Seq[NamedExpression] +trait BaseAggregateExec extends UnaryExecNode { + def groupingExpressions: Seq[NamedExpression] + def aggregateExpressions: Seq[AggregateExpression] + def aggregateAttributes: Seq[Attribute] + def resultExpressions: Seq[NamedExpression] override def verboseStringWithOperatorId(): String = { val inputString = child.output.mkString("[", ", ", "]") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index a5123e70832cc..7a26fd7a8541a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -47,11 +47,11 @@ import org.apache.spark.util.Utils */ case class HashAggregateExec( requiredChildDistributionExpressions: Option[Seq[Expression]], - override val groupingExpressions: Seq[NamedExpression], - override val aggregateExpressions: Seq[AggregateExpression], - override val aggregateAttributes: Seq[Attribute], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, - override val resultExpressions: Seq[NamedExpression], + resultExpressions: Seq[NamedExpression], child: SparkPlan) extends BaseAggregateExec with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 019ee43ff2e46..3fb58eb2cc8ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -61,11 +61,11 @@ import org.apache.spark.sql.execution.metric.SQLMetrics */ case class ObjectHashAggregateExec( requiredChildDistributionExpressions: Option[Seq[Expression]], - override val groupingExpressions: Seq[NamedExpression], - override val aggregateExpressions: Seq[AggregateExpression], - override val aggregateAttributes: Seq[Attribute], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, - override val resultExpressions: Seq[NamedExpression], + resultExpressions: Seq[NamedExpression], child: SparkPlan) extends BaseAggregateExec with AliasAwareOutputPartitioning { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index 2ee8ace1de1c0..77ed469016fa3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -32,11 +32,11 @@ import org.apache.spark.sql.execution.metric.SQLMetrics */ case class SortAggregateExec( requiredChildDistributionExpressions: Option[Seq[Expression]], - override val groupingExpressions: Seq[NamedExpression], - override val aggregateExpressions: Seq[AggregateExpression], - override val aggregateAttributes: Seq[Attribute], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, - override val resultExpressions: Seq[NamedExpression], + resultExpressions: Seq[NamedExpression], child: SparkPlan) extends BaseAggregateExec with AliasAwareOutputPartitioning {