diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index dc54e732066fd..d3b4cf8e34242 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -70,8 +70,6 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool EliminateAnalysisOperators) ) - private def trimAliases(e: Expression) = e.transform { case Alias(c, _) => c } - /** * Makes sure all attributes and logical plans have been resolved. */ @@ -98,7 +96,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool object TrimGroupingAliases extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Aggregate(groups, aggs, child) => - Aggregate(groups.map(trimAliases), aggs, child) + Aggregate(groups.map(_.transform { case Alias(c, _) => c }), aggs, child) } } @@ -118,7 +116,12 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool } aggregateExprs.find { e => - !isValidAggregateExpression(trimAliases(e)) + !isValidAggregateExpression(e.transform { + // Should trim aliases around `GetField`s. These aliases are introduced while + // resolving struct field accesses, because `GetField` is not a `NamedExpression`. + // (Should we just turn `GetField` into a `NamedExpression`?) + case Alias(g: GetField, _) => g + }) }.foreach { e => throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index be38b5c7f3a91..7eb7f29626c35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -146,17 +146,23 @@ object PartialAggregation { case other => (other, Alias(other, "PartialGroup")()) }.toMap - def trimAliases(e: Expression) = e.transform { case Alias(c, _) => c } + def trimGetFieldAliases(e: Expression) = e.transform { case Alias(g: GetField, _) => g } // Replace aggregations with a new expression that computes the result from the already // computed partial evaluations and grouping values. val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) => partialEvaluations(new TreeNodeRef(e)).finalEvaluation - case e: Expression if namedGroupingExpressions.contains(e) => - namedGroupingExpressions(e).toAttribute - case e: Expression if namedGroupingExpressions.contains(trimAliases(e)) => - namedGroupingExpressions(trimAliases(e)).toAttribute + + case e: Expression => + // Should trim aliases around `GetField`s. These aliases are introduced while + // resolving struct field accesses, because `GetField` is not a `NamedExpression`. + // (Should we just turn `GetField` into a `NamedExpression`?) + namedGroupingExpressions + .get(e) + .orElse(namedGroupingExpressions.get(trimGetFieldAliases(e))) + .map(_.toAttribute) + .getOrElse(e) }).asInstanceOf[Seq[NamedExpression]] val partialComputation =