From fd6d226528ed0dd2dd0593b70b13cfd67ef6da06 Mon Sep 17 00:00:00 2001 From: Xinyi Yu Date: Wed, 21 Dec 2022 10:53:54 +0800 Subject: [PATCH] [SPARK-41631][SQL] Support implicit lateral column alias resolution on Aggregate ### What changes were proposed in this pull request? This PR implements the implicit lateral column alias on `Aggregate` case. For example, ```sql -- LCA in Aggregate. The avg_salary references an attribute defined by a previous alias SELECT dept, average(salary) AS avg_salary, avg_salary + average(bonus) FROM employee GROUP BY dept ``` The high level implementation idea is to insert the `Project` node above, and falling back to the resolution of lateral alias of Project code path in the last PR. * Phase 1: recognize resolved lateral alias, wrap the attributes referencing them with `LateralColumnAliasReference` * Phase 2: when the `Aggregate` operator is resolved, it goes through the whole aggregation list, extracts the aggregation expressions and grouping expressions to keep them in this `Aggregate` node, and add a `Project` above with the original output. It doesn't do anything on `LateralColumnAliasReference`, but completely leave it to the Project in the future turns of this rule. Example: ``` // Before rewrite: Aggregate [dept#14] [dept#14 AS a#12, 'a + 1, avg(salary#16) AS b#13, 'b + avg(bonus#17)] +- Child [dept#14,name#15,salary#16,bonus#17] // After phase 1: Aggregate [dept#14] [dept#14 AS a#12, lca(a) + 1, avg(salary#16) AS b#13, lca(b) + avg(bonus#17)] +- Child [dept#14,name#15,salary#16,bonus#17] // After phase 2: Project [dept#14 AS a#12, lca(a) + 1, avg(salary)#26 AS b#13, lca(b) + avg(bonus)#27] +- Aggregate [dept#14] [avg(salary#16) AS avg(salary)#26, avg(bonus#17) AS avg(bonus)#27, dept#14] +- Child [dept#14,name#15,salary#16,bonus#17] // Now the problem falls back to the lateral alias resolution in Project. // After future rounds of this rule: Project [a#12, a#12 + 1, b#13, b#13 + avg(bonus)#27] +- Project [dept#14 AS a#12, avg(salary)#26 AS b#13] +- Aggregate [dept#14] [avg(salary#16) AS avg(salary)#26, avg(bonus#17) AS avg(bonus)#27, dept#14] +- Child [dept#14,name#15,salary#16,bonus#17] ``` Similar as the last PR (https://github.com/apache/spark/pull/38776), because lateral column alias has higher resolution priority than outer reference, it will try to resolve an `OuterReference` using lateral column alias, similar as an `UnresolvedAttribute`. If success, it strips `OuterReference` and also wraps it with `LateralColumnAliasReference`. ### Why are the changes needed? Similar as stated in https://github.com/apache/spark/pull/38776. ### Does this PR introduce _any_ user-facing change? Yes, as shown in the above example, it will be able to resolve lateral column alias in Aggregate. ### How was this patch tested? Existing tests and newly added tests. Closes #39040 from anchovYu/SPARK-27561-agg. Authored-by: Xinyi Yu Signed-off-by: Wenchen Fan --- .../main/resources/error/error-classes.json | 5 + .../sql/catalyst/analysis/Analyzer.scala | 32 +- .../ResolveLateralColumnAliasReference.scala | 107 ++- .../sql/errors/QueryCompilationErrors.scala | 15 +- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../spark/sql/LateralColumnAliasSuite.scala | 613 +++++++++++++++--- .../org/apache/spark/sql/QueryTest.scala | 2 +- 7 files changed, 674 insertions(+), 102 deletions(-) diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index f176726d0ce54..e6ae567899326 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -1339,6 +1339,11 @@ "The target JDBC server does not support transactions and can only support ALTER TABLE with a single action." ] }, + "LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC" : { + "message" : [ + "Referencing a lateral column alias in the aggregate function ." + ] + }, "LATERAL_JOIN_USING" : { "message" : [ "JOIN USING with LATERAL correlation." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7eed76c11222f..e959e7208a42f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1818,7 +1818,7 @@ class Analyzer(override val catalogManager: CatalogManager) val aliases = aliasMap.get(u.nameParts.head).get aliases.size match { case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) + throw QueryCompilationErrors.ambiguousLateralColumnAliasError(u.name, n) case n if n == 1 && aliases.head.alias.resolved => // Only resolved alias can be the lateral column alias // The lateral alias can be a struct and have nested field, need to construct @@ -1838,7 +1838,7 @@ class Analyzer(override val catalogManager: CatalogManager) val aliases = aliasMap.get(nameParts.head).get aliases.size match { case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n) + throw QueryCompilationErrors.ambiguousLateralColumnAliasError(nameParts, n) case n if n == 1 && aliases.head.alias.resolved => resolveByLateralAlias(nameParts, aliases.head.alias).getOrElse(o) case _ => o @@ -1853,8 +1853,8 @@ class Analyzer(override val catalogManager: CatalogManager) plan.resolveOperatorsUpWithPruning( _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE), ruleId) { case p @ Project(projectList, _) if p.childrenResolved - && !ResolveReferences.containsStar(projectList) - && projectList.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) => + && !ResolveReferences.containsStar(projectList) + && projectList.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) => var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) val newProjectList = projectList.zipWithIndex.map { case (a: Alias, idx) => @@ -1869,6 +1869,30 @@ class Analyzer(override val catalogManager: CatalogManager) wrapLCARef(e, p, aliasMap) } p.copy(projectList = newProjectList) + + // Implementation notes: + // In Aggregate, introducing and wrapping this resolved leaf expression + // LateralColumnAliasReference is especially needed because it needs an accurate condition + // to trigger adding a Project above and extracting and pushing down aggregate functions + // or grouping expressions. Such operation can only be done once. With this + // LateralColumnAliasReference, that condition can simply be when the whole Aggregate is + // resolved. Otherwise, it can't tell if all aggregate functions are created and + // resolved so that it can start the extraction, because the lateral alias reference is + // unresolved and can be the argument to functions, blocking the resolution of functions. + case agg @ Aggregate(_, aggExprs, _) if agg.childrenResolved + && !ResolveReferences.containsStar(aggExprs) + && aggExprs.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) => + + var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) + val newAggExprs = aggExprs.zipWithIndex.map { + case (a: Alias, idx) => + val lcaWrapped = wrapLCARef(a, agg, aliasMap).asInstanceOf[Alias] + aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap) + lcaWrapped + case (e, _) => + wrapLCARef(e, agg, aliasMap) + } + agg.copy(aggregateExpressions = newAggExprs) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala index 2ca187b95ffda..ec8bdb97fbc67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, LateralColumnAliasReference, NamedExpression} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, Expression, LateralColumnAliasReference, LeafExpression, Literal, NamedExpression, ScalarSubquery} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.trees.TreePattern.LATERAL_COLUMN_ALIAS_REFERENCE +import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf /** @@ -31,22 +34,26 @@ import org.apache.spark.sql.internal.SQLConf * Plan-wise, it handles two types of operators: Project and Aggregate. * - in Project, pushing down the referenced lateral alias into a newly created Project, resolve * the attributes referencing these aliases - * - in Aggregate TODO. + * - in Aggregate, inserting the Project node above and falling back to the resolution of Project. * * The whole process is generally divided into two phases: * 1) recognize resolved lateral alias, wrap the attributes referencing them with * [[LateralColumnAliasReference]] - * 2) when the whole operator is resolved, unwrap [[LateralColumnAliasReference]]. - * For Project, it further resolves the attributes and push down the referenced lateral aliases. - * For Aggregate, TODO + * 2) when the whole operator is resolved, + * For Project, it unwrap [[LateralColumnAliasReference]], further resolves the attributes and + * push down the referenced lateral aliases. + * For Aggregate, it goes through the whole aggregation list, extracts the aggregation + * expressions and grouping expressions to keep them in this Aggregate node, and add a Project + * above with the original output. It doesn't do anything on [[LateralColumnAliasReference]], but + * completely leave it to the Project in the future turns of this rule. * - * Example for Project: + * ** Example for Project: * Before rewrite: * Project [age AS a, 'a + 1] * +- Child * * After phase 1: - * Project [age AS a, lateralalias(a) + 1] + * Project [age AS a, lca(a) + 1] * +- Child * * After phase 2: @@ -54,7 +61,27 @@ import org.apache.spark.sql.internal.SQLConf * +- Project [child output, age AS a] * +- Child * - * Example for Aggregate TODO + * ** Example for Aggregate: + * Before rewrite: + * Aggregate [dept#14] [dept#14 AS a#12, 'a + 1, avg(salary#16) AS b#13, 'b + avg(bonus#17)] + * +- Child [dept#14,name#15,salary#16,bonus#17] + * + * After phase 1: + * Aggregate [dept#14] [dept#14 AS a#12, lca(a) + 1, avg(salary#16) AS b#13, lca(b) + avg(bonus#17)] + * +- Child [dept#14,name#15,salary#16,bonus#17] + * + * After phase 2: + * Project [dept#14 AS a#12, lca(a) + 1, avg(salary)#26 AS b#13, lca(b) + avg(bonus)#27] + * +- Aggregate [dept#14] [avg(salary#16) AS avg(salary)#26, avg(bonus#17) AS avg(bonus)#27,dept#14] + * +- Child [dept#14,name#15,salary#16,bonus#17] + * + * Now the problem falls back to the lateral alias resolution in Project. + * After future rounds of this rule: + * Project [a#12, a#12 + 1, b#13, b#13 + avg(bonus)#27] + * +- Project [dept#14 AS a#12, avg(salary)#26 AS b#13] + * +- Aggregate [dept#14] [avg(salary#16) AS avg(salary)#26, avg(bonus#17) AS avg(bonus)#27, + * dept#14] + * +- Child [dept#14,name#15,salary#16,bonus#17] * * * The name resolution priority: @@ -75,6 +102,13 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { */ val NAME_PARTS_FROM_UNRESOLVED_ATTR = TreeNodeTag[Seq[String]]("name_parts_from_unresolved_attr") + private def assignAlias(expr: Expression): NamedExpression = { + expr match { + case ne: NamedExpression => ne + case e => Alias(e, toPrettySQL(e))() + } + } + override def apply(plan: LogicalPlan): LogicalPlan = { if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { plan @@ -129,6 +163,61 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { child = Project(innerProjectList.toSeq, child) ) } + + case agg @ Aggregate(groupingExpressions, aggregateExpressions, _) if agg.resolved + && aggregateExpressions.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => + + // Check if current Aggregate is eligible to lift up with Project: the aggregate + // expression only contains: 1) aggregate functions, 2) grouping expressions, 3) lateral + // column alias reference or 4) literals. + // This check is to prevent unnecessary transformation on invalid plan, to guarantee it + // throws the same exception. For example, cases like non-aggregate expressions not + // in group by, once transformed, will throw a different exception: missing input. + def eligibleToLiftUp(exp: Expression): Boolean = { + exp match { + case e if AggregateExpression.isAggregate(e) => true + case e if groupingExpressions.exists(_.semanticEquals(e)) => true + case _: Literal | _: LateralColumnAliasReference => true + case s: ScalarSubquery if s.children.nonEmpty + && !groupingExpressions.exists(_.semanticEquals(s)) => false + case _: LeafExpression => false + case e => e.children.forall(eligibleToLiftUp) + } + } + if (!aggregateExpressions.forall(eligibleToLiftUp)) { + return agg + } + + val newAggExprs = collection.mutable.Set.empty[NamedExpression] + val expressionMap = collection.mutable.LinkedHashMap.empty[Expression, NamedExpression] + val projectExprs = aggregateExpressions.map { exp => + exp.transformDown { + case aggExpr: AggregateExpression => + // Doesn't support referencing a lateral alias in aggregate function + if (aggExpr.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { + aggExpr.collectFirst { + case lcaRef: LateralColumnAliasReference => + throw QueryCompilationErrors.lateralColumnAliasInAggFuncUnsupportedError( + lcaRef.nameParts, aggExpr) + } + } + val ne = expressionMap.getOrElseUpdate(aggExpr.canonicalized, assignAlias(aggExpr)) + newAggExprs += ne + ne.toAttribute + case e if groupingExpressions.exists(_.semanticEquals(e)) => + val ne = expressionMap.getOrElseUpdate(e.canonicalized, assignAlias(e)) + newAggExprs += ne + ne.toAttribute + }.asInstanceOf[NamedExpression] + } + if (newAggExprs.isEmpty) { + agg + } else { + Project( + projectList = projectExprs, + child = agg.copy(aggregateExpressions = newAggExprs.toSeq) + ) + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index b0cf8f6876ccf..d537c1685d644 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -3395,7 +3395,7 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { } } - def ambiguousLateralColumnAlias(name: String, numOfMatches: Int): Throwable = { + def ambiguousLateralColumnAliasError(name: String, numOfMatches: Int): Throwable = { new AnalysisException( errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS", messageParameters = Map( @@ -3404,7 +3404,7 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { ) ) } - def ambiguousLateralColumnAlias(nameParts: Seq[String], numOfMatches: Int): Throwable = { + def ambiguousLateralColumnAliasError(nameParts: Seq[String], numOfMatches: Int): Throwable = { new AnalysisException( errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS", messageParameters = Map( @@ -3413,4 +3413,15 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { ) ) } + + def lateralColumnAliasInAggFuncUnsupportedError( + lcaNameParts: Seq[String], aggExpr: Expression): Throwable = { + new AnalysisException( + errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC", + messageParameters = Map( + "lca" -> toSQLId(lcaNameParts), + "aggFunc" -> toSQLExpr(aggExpr) + ) + ) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5565926afddfa..19c302641b16d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4045,7 +4045,7 @@ object SQLConf { "higher resolution priority than the lateral column alias.") .version("3.4.0") .booleanConf - .createWithDefault(false) + .createWithDefault(true) /** * Holds information about keys that have been deprecated. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index abeb3bb784124..624d5f98642a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -20,11 +20,28 @@ package org.apache.spark.sql import org.scalactic.source.Position import org.scalatest.Tag +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, ExpressionSet} +import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { +/** + * Lateral column alias base suite with LCA off, extended by LateralColumnAliasSuite with LCA on. + * Should test behaviors remaining the same no matter LCA conf is on or off. + */ +class LateralColumnAliasSuiteBase extends QueryTest with SharedSparkSession { + // by default the tests in this suites run with LCA off + val lcaEnabled: Boolean = false + override protected def test(testName: String, testTags: Tag*)(testFun: => Any) + (implicit pos: Position): Unit = { + super.test(testName, testTags: _*) { + withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> lcaEnabled.toString) { + testFun + } + } + } + protected val testTable: String = "employee" override def beforeAll(): Unit = { @@ -58,76 +75,299 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { } } - val lcaEnabled: Boolean = true - // by default the tests in this suites run with LCA on - override protected def test(testName: String, testTags: Tag*)(testFun: => Any) - (implicit pos: Position): Unit = { - super.test(testName, testTags: _*) { - withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> lcaEnabled.toString) { - testFun - } + protected def withLCAOff(f: => Unit): Unit = { + withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> "false") { + f + } + } + protected def withLCAOn(f: => Unit): Unit = { + withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> "true") { + f } } + + test("Lateral alias conflicts with table column - Project") { + checkAnswer( + sql( + "select salary * 2 as salary, salary * 2 + bonus as " + + s"new_income from $testTable where name = 'amy'"), + Row(20000, 21000)) + + checkAnswer( + sql( + "select salary * 2 as salary, (salary + bonus) * 3 - (salary + bonus) as " + + s"new_income from $testTable where name = 'amy'"), + Row(20000, 22000)) + + checkAnswer( + sql(s"SELECT named_struct('joinYear', 2022) AS properties, properties.joinYear " + + s"FROM $testTable WHERE name = 'amy'"), + Row(Row(2022), 2019)) + + checkAnswer( + sql(s"SELECT named_struct('name', 'someone') AS $testTable, $testTable.name " + + s"FROM $testTable WHERE name = 'amy'"), + Row(Row("someone"), "amy")) + + // CTE table + checkAnswer( + sql( + s""" + |WITH temp_table(x, y) AS (SELECT 1, 2) + |SELECT 100 AS x, x + 1 + |FROM temp_table + |""".stripMargin + ), + Row(100, 2)) + } + + test("Lateral alias conflicts with table column - Aggregate") { + checkAnswer( + sql( + s""" + |SELECT + | sum(salary) AS salary, + | sum(bonus) AS bonus, + | avg(salary) AS avg_s, + | avg(salary + bonus) AS avg_t + |FROM $testTable GROUP BY dept ORDER BY dept + |""".stripMargin), + Row(19000, 2200, 9500.0, 10600.0) :: + Row(22000, 2500, 11000.0, 12250.0) :: + Row(12000, 1200, 12000.0, 13200.0) :: + Nil) + + // TODO: how does it correctly resolve to the right dept in SORT? + checkAnswer( + sql(s"SELECT avg(bonus) AS dept, dept, avg(salary) " + + s"FROM $testTable GROUP BY dept ORDER BY dept"), + Row(1100, 1, 9500.0) :: Row(1250, 2, 11000) :: Row(1200, 6, 12000) :: Nil + ) + + checkAnswer( + sql("SELECT named_struct('joinYear', 2022) AS properties, min(properties.joinYear) " + + s"FROM $testTable GROUP BY dept ORDER BY dept"), + Row(Row(2022), 2019) :: Row(Row(2022), 2017) :: Row(Row(2022), 2018) :: Nil) + + checkAnswer( + sql(s"SELECT named_struct('salary', 20000) AS $testTable, avg($testTable.salary) " + + s"FROM $testTable GROUP BY dept ORDER BY dept"), + Row(Row(20000), 9500) :: Row(Row(20000), 11000) :: Row(Row(20000), 12000) :: Nil) + + // CTE table + checkAnswer( + sql( + s""" + |WITH temp_table(x, y) AS (SELECT 1, 2) + |SELECT 100 AS x, x + 1 + |FROM temp_table + |GROUP BY x + |""".stripMargin), + Row(100, 2)) + } +} + +/** + * Lateral column alias base with LCA on. + */ +class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { + // by default the tests in this suites run with LCA on + override val lcaEnabled: Boolean = true + // mark special testcases test both LCA on and off protected def testOnAndOff(testName: String, testTags: Tag*)(testFun: => Any) - (implicit pos: Position): Unit = { + (implicit pos: Position): Unit = { super.test(testName, testTags: _*)(testFun) } - private def withLCAOff(f: => Unit): Unit = { - withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> "false") { - f - } + private def checkDuplicatedAliasErrorHelper( + query: String, parameters: Map[String, String]): Unit = { + checkError( + exception = intercept[AnalysisException] {sql(query)}, + errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS", + sqlState = "42000", + parameters = parameters + ) } - private def withLCAOn(f: => Unit): Unit = { - withSQLConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED.key -> "true") { - f + + private def checkAnswerWhenOnAndExceptionWhenOff( + query: String, expectedAnswerLCAOn: Seq[Row]): Unit = { + withLCAOn { checkAnswer(sql(query), expectedAnswerLCAOn) } + withLCAOff { + assert(intercept[AnalysisException]{ sql(query) } + .getErrorClass == "UNRESOLVED_COLUMN.WITH_SUGGESTION") } } testOnAndOff("Lateral alias basics - Project") { - def checkAnswerWhenOnAndExceptionWhenOff(query: String, expectedAnswerLCAOn: Row): Unit = { - withLCAOn { checkAnswer(sql(query), expectedAnswerLCAOn) } - withLCAOff { - assert(intercept[AnalysisException]{ sql(query) } - .getErrorClass == "UNRESOLVED_COLUMN.WITH_SUGGESTION") - } - } - checkAnswerWhenOnAndExceptionWhenOff( s"select dept as d, d + 1 as e from $testTable where name = 'amy'", - Row(1, 2)) + Row(1, 2) :: Nil) checkAnswerWhenOnAndExceptionWhenOff( s"select salary * 2 as new_salary, new_salary + bonus from $testTable where name = 'amy'", - Row(20000, 21000)) + Row(20000, 21000) :: Nil) checkAnswerWhenOnAndExceptionWhenOff( s"select salary * 2 as new_salary, new_salary + bonus * 2 as new_income from $testTable" + s" where name = 'amy'", - Row(20000, 22000)) + Row(20000, 22000) :: Nil) checkAnswerWhenOnAndExceptionWhenOff( "select salary * 2 as new_salary, (new_salary + bonus) * 3 - new_salary * 2 as " + s"new_income from $testTable where name = 'amy'", - Row(20000, 23000)) + Row(20000, 23000) :: Nil) // should referring to the previously defined LCA checkAnswerWhenOnAndExceptionWhenOff( s"SELECT salary * 1.5 AS d, d, 10000 AS d FROM $testTable WHERE name = 'jen'", - Row(18000, 18000, 10000) - ) + Row(18000, 18000, 10000) :: Nil) + + // LCA and conflicted table column mixed + checkAnswerWhenOnAndExceptionWhenOff( + "select salary * 2 as salary, (salary + bonus) * 2 as bonus, " + + s"salary + bonus as prev_income, prev_income + bonus + salary from $testTable" + + " where name = 'amy'", + Row(20000, 22000, 11000, 22000) :: Nil) } - test("Duplicated lateral alias names - Project") { - def checkDuplicatedAliasErrorHelper(query: String, parameters: Map[String, String]): Unit = { + testOnAndOff("Lateral alias basics - Aggregate") { + // doesn't support lca used in aggregation functions + withLCAOn( checkError( - exception = intercept[AnalysisException] {sql(query)}, - errorClass = "AMBIGUOUS_LATERAL_COLUMN_ALIAS", - sqlState = "42000", - parameters = parameters - ) - } + exception = intercept[AnalysisException] { + sql(s"SELECT 10000 AS lca, count(lca) FROM $testTable GROUP BY dept") + }, + errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC", + sqlState = "0A000", + parameters = Map( + "lca" -> "`lca`", + "aggFunc" -> "\"count(lateralAliasReference(lca))\"" + ))) + withLCAOn( + checkError( + exception = intercept[AnalysisException] { + sql(s"SELECT dept AS lca, avg(lca) FROM $testTable GROUP BY dept") + }, + errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC", + sqlState = "0A000", + parameters = Map( + "lca" -> "`lca`", + "aggFunc" -> "\"avg(lateralAliasReference(lca))\"" + ))) + // doesn't support nested aggregate expressions + withLCAOn( + checkError( + exception = intercept[AnalysisException] { + sql(s"SELECT sum(salary) AS a, avg(a) FROM $testTable") + }, + errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC", + sqlState = "0A000", + parameters = Map( + "lca" -> "`a`", + "aggFunc" -> "\"avg(lateralAliasReference(a))\"" + ))) + + // literal as LCA, used in various cases of expressions + checkAnswerWhenOnAndExceptionWhenOff( + s""" + |SELECT + | 10000 AS baseline_salary, + | baseline_salary * 1.5, + | baseline_salary + dept * 10000, + | baseline_salary + avg(bonus) + |FROM $testTable + |GROUP BY dept + |ORDER BY dept + |""".stripMargin, + Row(10000, 15000.0, 20000, 11100.0) :: + Row(10000, 15000.0, 30000, 11250.0) :: + Row(10000, 15000.0, 70000, 11200.0) :: Nil + ) + + // grouping attribute as LCA, used in various cases of expressions + checkAnswerWhenOnAndExceptionWhenOff( + s""" + |SELECT + | salary + 1000 AS new_salary, + | new_salary - 1000 AS prev_salary, + | new_salary - salary, + | new_salary - avg(salary) + |FROM $testTable + |GROUP BY salary + |ORDER BY salary + |""".stripMargin, + Row(10000, 9000, 1000, 1000.0) :: + Row(11000, 10000, 1000, 1000.0) :: + Row(13000, 12000, 1000, 1000.0) :: Nil + ) + + // aggregate expression as LCA, used in various cases of expressions + checkAnswerWhenOnAndExceptionWhenOff( + s""" + |SELECT + | sum(salary) AS dept_salary_sum, + | sum(bonus) AS dept_bonus_sum, + | dept_salary_sum * 1.5, + | concat(string(dept_salary_sum), ': dept', string(dept)), + | dept_salary_sum + sum(bonus), + | dept_salary_sum + dept_bonus_sum, + | avg(salary * 1.5 + 10000 + bonus * 1.0) AS avg_total, + | avg_total + |FROM $testTable + |GROUP BY dept + |ORDER BY dept + |""".stripMargin, + Row(19000, 2200, 28500.0, "19000: dept1", 21200, 21200, 25350, 25350) :: + Row(22000, 2500, 33000.0, "22000: dept2", 24500, 24500, 27750, 27750) :: + Row(12000, 1200, 18000.0, "12000: dept6", 13200, 13200, 29200, 29200) :: + Nil + ) + checkAnswerWhenOnAndExceptionWhenOff( + s"SELECT sum(salary) AS s, s + sum(bonus) AS total FROM $testTable", + Row(53000, 58900) :: Nil + ) + + // grouping expression are correctly recognized and pushed down + checkAnswer( + sql( + s""" + |SELECT dept AS a, dept + 10 AS b, avg(salary) + dept, avg(salary) AS c, + | c + dept, avg(salary + dept), count(dept) + |FROM $testTable GROUP BY dept ORDER BY dept + |""".stripMargin), + Row(1, 11, 9501, 9500, 9501, 9501, 2) :: + Row(2, 12, 11002, 11000, 11002, 11002, 2) :: + Row(6, 16, 12006, 12000, 12006, 12006, 1) :: Nil) + + // two grouping expressions + checkAnswer( + sql( + s""" + |SELECT dept + salary, avg(salary) + dept, avg(bonus) AS c, c + salary + dept, + | avg(bonus) + salary + |FROM $testTable GROUP BY dept, salary HAVING dept = 2 ORDER BY dept, salary + |""".stripMargin + ), + Row(10002, 10002, 1300, 11302, 11300) :: Row(12002, 12002, 1200, 13202, 13200) :: Nil + ) + // LCA and conflicted table column mixed + checkAnswerWhenOnAndExceptionWhenOff( + s""" + |SELECT + | sum(salary) AS salary, + | sum(bonus) AS bonus, + | avg(salary) AS avg_s, + | avg(salary + bonus) AS avg_t, + | avg_s + avg_t + |FROM $testTable GROUP BY dept ORDER BY dept + |""".stripMargin, + Row(19000, 2200, 9500.0, 10600.0, 20100.0) :: + Row(22000, 2500, 11000.0, 12250.0, 23250.0) :: + Row(12000, 1200, 12000.0, 13200.0, 25200.0) :: Nil) + } + + test("Duplicated lateral alias names - Project") { // Has duplicated names but not referenced is fine checkAnswer( sql(s"SELECT salary AS d, bonus AS d FROM $testTable WHERE name = 'jen'"), @@ -175,35 +415,58 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { ) } - test("Lateral alias conflicts with table column - Project") { + test("Duplicated lateral alias names - Aggregate") { + // Has duplicated names but not referenced is fine checkAnswer( - sql( - "select salary * 2 as salary, salary * 2 + bonus as " + - s"new_income from $testTable where name = 'amy'"), - Row(20000, 21000)) - + sql(s"SELECT dept AS d, name AS d FROM $testTable GROUP BY dept, name ORDER BY dept, name"), + Row(1, "amy") :: Row(1, "cathy") :: Row(2, "alex") :: Row(2, "david") :: Row(6, "jen") :: Nil + ) checkAnswer( - sql( - "select salary * 2 as salary, (salary + bonus) * 3 - (salary + bonus) as " + - s"new_income from $testTable where name = 'amy'"), - Row(20000, 22000)) - + sql(s"SELECT dept AS d, d, 10 AS d FROM $testTable GROUP BY dept ORDER BY dept"), + Row(1, 1, 10) :: Row(2, 2, 10) :: Row(6, 6, 10) :: Nil + ) + checkAnswer( + sql(s"SELECT sum(salary * 1.5) AS d, d, 10 AS d FROM $testTable GROUP BY dept ORDER BY dept"), + Row(28500, 28500, 10) :: Row(33000, 33000, 10) :: Row(18000, 18000, 10) :: Nil + ) checkAnswer( sql( - "select salary * 2 as salary, (salary + bonus) * 2 as bonus, " + - s"salary + bonus as prev_income, prev_income + bonus + salary from $testTable" + - " where name = 'amy'"), - Row(20000, 22000, 11000, 22000)) + s""" + |SELECT sum(salary * 1.5) AS d, d, d + sum(bonus) AS d + |FROM $testTable + |GROUP BY dept + |ORDER BY dept + |""".stripMargin), + Row(28500, 28500, 30700) :: Row(33000, 33000, 35500) :: Row(18000, 18000, 19200) :: Nil + ) - checkAnswer( - sql(s"SELECT named_struct('joinYear', 2022) AS properties, properties.joinYear " + - s"FROM $testTable WHERE name = 'amy'"), - Row(Row(2022), 2019)) + // Referencing duplicated names raises error + checkDuplicatedAliasErrorHelper( + s"SELECT dept * 2.0 AS d, d, 10000 AS d, d + 1 FROM $testTable GROUP BY dept", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT 10000 AS d, d * 1.0, dept * 2.0 AS d, d FROM $testTable GROUP BY dept", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT avg(salary) AS d, d * 1.0, avg(bonus * 1.5) AS d, d FROM $testTable GROUP BY dept", + parameters = Map("name" -> "`d`", "n" -> "2") + ) + checkDuplicatedAliasErrorHelper( + s"SELECT dept AS d, d + 1 AS d, d + 1 AS d FROM $testTable GROUP BY dept", + parameters = Map("name" -> "`d`", "n" -> "2") + ) checkAnswer( - sql(s"SELECT named_struct('name', 'someone') AS $testTable, $testTable.name " + - s"FROM $testTable WHERE name = 'amy'"), - Row(Row("someone"), "amy")) + sql(s""" + |SELECT avg(salary * 1.5) AS salary, sum(salary), dept AS salary, avg(salary) + |FROM $testTable + |GROUP BY dept + |HAVING dept = 6 + |""".stripMargin), + Row(18000, 12000, 6, 12000) + ) } testOnAndOff("Lateral alias conflicts with OuterReference - Project") { @@ -258,7 +521,7 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { | FROM (SELECT dept * 2.0 AS id, id + 1 AS id2 FROM $testTable)) > 5 |ORDER BY id |""".stripMargin - withLCAOff { intercept[AnalysisException] { sql(query4) } } // surprisingly can't run .. + withLCAOff { intercept[AnalysisException] { sql(query4) } } withLCAOn { val analyzedPlan = sql(query4).queryExecution.analyzed assert(!analyzedPlan.containsPattern(OUTER_REFERENCE)) @@ -268,46 +531,83 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { } // TODO: more tests on LCA in subquery - test("Lateral alias of a complex type - Project") { - checkAnswer( - sql("SELECT named_struct('a', 1) AS foo, foo.a + 1 AS bar, bar + 1"), - Row(Row(1), 2, 3)) + test("Lateral alias conflicts with OuterReference - Aggregate") { + // test if lca rule strips the OuterReference and resolves to lateral alias + val query = + s""" + |SELECT * + |FROM range(1, 7) + |WHERE ( + | SELECT id2 + | FROM (SELECT avg(salary * 1.0) AS id, id + 1 AS id2 FROM $testTable GROUP BY dept)) > 5 + |""".stripMargin + val analyzedPlan = sql(query).queryExecution.analyzed + assert(!analyzedPlan.containsPattern(OUTER_REFERENCE)) + } + test("Lateral alias of a complex type") { + // test both Project and Aggregate + val querySuffixes = Seq("", s"FROM $testTable GROUP BY dept HAVING dept = 6") + querySuffixes.foreach { querySuffix => + checkAnswer( + sql(s"SELECT named_struct('a', 1) AS foo, foo.a + 1 AS bar, bar + 1 $querySuffix"), + Row(Row(1), 2, 3)) + checkAnswer( + sql("SELECT named_struct('a', named_struct('b', 1)) AS foo, foo.a.b + 1 AS bar " + + s"$querySuffix"), + Row(Row(Row(1)), 2)) + + checkAnswer( + sql(s"SELECT array(1, 2, 3) AS foo, foo[1] AS bar, bar + 1 $querySuffix"), + Row(Seq(1, 2, 3), 2, 3)) checkAnswer( - sql("SELECT named_struct('a', named_struct('b', 1)) AS foo, foo.a.b + 1 AS bar"), - Row(Row(Row(1)), 2) - ) - + sql("SELECT array(array(1, 2), array(1, 2, 3), array(100)) AS foo, foo[2][0] + 1 AS bar " + + s"$querySuffix"), + Row(Seq(Seq(1, 2), Seq(1, 2, 3), Seq(100)), 101)) checkAnswer( - sql("SELECT array(1, 2, 3) AS foo, foo[1] AS bar, bar + 1"), - Row(Seq(1, 2, 3), 2, 3) - ) + sql("SELECT array(named_struct('a', 1), named_struct('a', 2)) AS foo, foo[0].a + 1 AS bar" + + s" $querySuffix"), + Row(Seq(Row(1), Row(2)), 2)) + + checkAnswer( + sql(s"SELECT map('a', 1, 'b', 2) AS foo, foo['b'] AS bar, bar + 1 $querySuffix"), + Row(Map("a" -> 1, "b" -> 2), 2, 3)) + } + checkAnswer( - sql("SELECT array(array(1, 2), array(1, 2, 3), array(100)) AS foo, foo[2][0] + 1 AS bar"), - Row(Seq(Seq(1, 2), Seq(1, 2, 3), Seq(100)), 101) - ) + sql("SELECT named_struct('s', salary * 1.0) AS foo, foo.s + 1 AS bar, bar + 1 " + + s"FROM $testTable WHERE dept = 1 ORDER BY name"), + Row(Row(10000), 10001, 10002) :: Row(Row(9000), 9001, 9002) :: Nil) + checkAnswer( - sql("SELECT array(named_struct('a', 1), named_struct('a', 2)) AS foo, foo[0].a + 1 AS bar"), - Row(Seq(Row(1), Row(2)), 2) - ) + sql(s"SELECT properties AS foo, foo.joinYear AS bar, bar + 1 " + + s"FROM $testTable GROUP BY properties HAVING properties.mostRecentEmployer = 'B'"), + Row(Row(2020, "B"), 2020, 2021)) checkAnswer( - sql("SELECT map('a', 1, 'b', 2) AS foo, foo['b'] AS bar, bar + 1"), - Row(Map("a" -> 1, "b" -> 2), 2, 3) + sql(s"SELECT named_struct('avg_salary', avg(salary)) AS foo, foo.avg_salary + 1 AS bar " + + s"FROM $testTable GROUP BY dept ORDER BY dept"), + Row(Row(9500), 9501) :: Row(Row(11000), 11001) :: Row(Row(12000), 12001) :: Nil ) } - test("Lateral alias reference attribute further be used by upper plan - Project") { - // this is out of the scope of lateral alias project functionality requirements, but naturally - // supported by the current design + test("Lateral alias reference attribute further be used by upper plan") { + // underlying this is not in the scope of lateral alias project but things already supported checkAnswer( sql(s"SELECT properties AS new_properties, new_properties.joinYear AS new_join_year " + s"FROM $testTable WHERE dept = 1 ORDER BY new_join_year DESC"), Row(Row(2020, "B"), 2020) :: Row(Row(2019, "A"), 2019) :: Nil ) + + checkAnswer( + sql(s"SELECT avg(bonus) AS avg_bonus, avg_bonus * 1.0 AS new_avg_bonus, avg(salary) " + + s"FROM $testTable GROUP BY dept ORDER BY new_avg_bonus"), + Row(1100, 1100, 9500.0) :: Row(1200, 1200, 12000) :: Row(1250, 1250, 11000) :: Nil + ) } - test("Lateral alias chaining - Project") { + test("Lateral alias chaining") { + // Project checkAnswer( sql( s""" @@ -323,5 +623,148 @@ class LateralColumnAliasSuite extends QueryTest with SharedSparkSession { sql("SELECT 1 AS a, a + 1 AS b, b - 1, b + 1 AS c, c + 1 AS d, d - a AS e, e + 1"), Row(1, 2, 1, 3, 4, 3, 4) ) + + // Aggregate + checkAnswer( + sql( + s""" + |SELECT + | dept, + | sum(salary) AS salary_sum, + | salary_sum + sum(bonus) AS salary_total, + | salary_total * 1.5 AS new_total, + | new_total - salary_sum + |FROM $testTable + |GROUP BY dept + |ORDER BY dept + |""".stripMargin), + Row(1, 19000, 21200, 31800.0, 12800.0) :: + Row(2, 22000, 24500, 36750.0, 14750.0) :: + Row(6, 12000, 13200, 19800.0, 7800.0) :: Nil + ) + } + + test("non-deterministic expression as LCA is evaluated only once") { + val querySuffixes = Seq(s"FROM $testTable", s"FROM $testTable GROUP BY dept") + querySuffixes.foreach { querySuffix => + sql(s"SELECT dept, rand(0) AS r, r $querySuffix").collect().toSeq.foreach { row => + assert(QueryTest.compare(row(1), row(2))) + } + sql(s"SELECT dept + rand(0) AS r, r $querySuffix").collect().toSeq.foreach { row => + assert(QueryTest.compare(row(0), row(1))) + } + } + sql(s"SELECT avg(salary) + rand(0) AS r, r ${querySuffixes(1)}").collect().toSeq.foreach { + row => assert(QueryTest.compare(row(0), row(1))) + } + } + + test("Case insensitive lateral column alias") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkAnswer( + sql(s"SELECT salary AS new_salary, New_Salary + 1 FROM $testTable WHERE name = 'jen'"), + Row(12000, 12001)) + checkAnswer( + sql( + s""" + |SELECT avg(salary) AS AVG_SALARY, avg_salary + avg(bonus) + |FROM $testTable + |GROUP BY dept + |HAVING dept = 1 + |""".stripMargin), + Row(9500, 10600)) + } + } + + test("Attribute cannot be resolved by LCA remain unresolved") { + assert(intercept[AnalysisException] { + sql(s"SELECT dept AS d, d AS new_dept, new_dep + 1 AS newer_dept FROM $testTable") + }.getErrorClass == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + + assert(intercept[AnalysisException] { + sql(s"SELECT count(name) AS cnt, cnt + 1, count(unresovled) FROM $testTable GROUP BY dept") + }.getErrorClass == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + + assert(intercept[AnalysisException] { + sql(s"SELECT * FROM range(1, 7) WHERE (" + + s"SELECT id2 FROM (SELECT 1 AS id, other_id + 1 AS id2)) > 5") + }.getErrorClass == "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION") + } + + test("Pushed-down aggregateExpressions should have no duplicates") { + val query = s""" + |SELECT dept, avg(salary) AS a, a + avg(bonus), dept + 1, + | concat(string(dept), string(avg(bonus))), avg(salary) + |FROM $testTable + |GROUP BY dept + |HAVING dept = 2 + |""".stripMargin + val analyzedPlan = sql(query).queryExecution.analyzed + analyzedPlan.collect { + case Aggregate(_, aggregateExpressions, _) => + val extracted = aggregateExpressions.collect { + case Alias(child, _) => child + case a: Attribute => a + } + val expressionSet = ExpressionSet(extracted) + assert( + extracted.size == expressionSet.size, + "The pushed-down aggregateExpressions in Aggregate should have no duplicates " + + s"after extracted from Alias. Current aggregateExpressions: $aggregateExpressions") + } + } + + test("Aggregate expressions not eligible to lift up, throws same error as inline") { + def checkSameMissingAggregationError(q1: String, q2: String, expressionParam: String): Unit = { + Seq(q1, q2).foreach { query => + val e = intercept[AnalysisException] { sql(query) } + assert(e.getErrorClass == "MISSING_AGGREGATION") + assert(e.messageParameters.get("expression").exists(_ == expressionParam)) + } + } + + val suffix = s"FROM $testTable GROUP BY dept" + checkSameMissingAggregationError( + s"SELECT dept AS a, dept, salary $suffix", + s"SELECT dept AS a, a, salary $suffix", + "\"salary\"") + checkSameMissingAggregationError( + s"SELECT dept AS a, dept + salary $suffix", + s"SELECT dept AS a, a + salary $suffix", + "\"salary\"") + checkSameMissingAggregationError( + s"SELECT avg(salary) AS a, avg(salary) + bonus $suffix", + s"SELECT avg(salary) AS a, a + bonus $suffix", + "\"bonus\"") + checkSameMissingAggregationError( + s"SELECT dept AS a, dept, avg(salary) + bonus + 10 $suffix", + s"SELECT dept AS a, a, avg(salary) + bonus + 10 $suffix", + "\"bonus\"") + checkSameMissingAggregationError( + s"SELECT avg(salary) AS a, avg(salary), dept FROM $testTable GROUP BY dept + 10", + s"SELECT avg(salary) AS a, a, dept FROM $testTable GROUP BY dept + 10", + "\"dept\"") + checkSameMissingAggregationError( + s"SELECT avg(salary) AS a, avg(salary) + dept + 10 FROM $testTable GROUP BY dept + 10", + s"SELECT avg(salary) AS a, a + dept + 10 FROM $testTable GROUP BY dept + 10", + "\"dept\"") + Seq( + s"SELECT dept AS a, dept, " + + s"(SELECT count(col) FROM VALUES (1), (2) AS data(col) WHERE col = dept) $suffix", + s"SELECT dept AS a, a, " + + s"(SELECT count(col) FROM VALUES (1), (2) AS data(col) WHERE col = dept) $suffix" + ).foreach { query => + val e = intercept[AnalysisException] { sql(query) } + assert(e.getErrorClass == "_LEGACY_ERROR_TEMP_2423") } + + // one exception: no longer throws NESTED_AGGREGATE_FUNCTION but UNSUPPORTED_FEATURE + checkError( + exception = intercept[AnalysisException] { + sql(s"SELECT avg(salary) AS a, avg(a) FROM $testTable GROUP BY dept") + }, + errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_FUNC", + sqlState = "0A000", + parameters = Map("lca" -> "`a`", "aggFunc" -> "\"avg(lateralAliasReference(a))\"") + ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 0bb5e5230c188..22cc4fd46cbd8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -360,7 +360,7 @@ object QueryTest extends Assertions { None } - private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match { + def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match { case (null, null) => true case (null, _) => false case (_, null) => false