From 5f64e80843ae47746b2999b4b277ecc622516cd2 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 9 Oct 2024 10:30:55 +0900 Subject: [PATCH] [SPARK-49895][SQL] Improve error when encountering trailing comma in SELECT clause ### What changes were proposed in this pull request? Introduced a specific error message for cases where a trailing comma appears at the end of the SELECT clause. ### Why are the changes needed? The previous error message was unclear and often pointed to an incorrect location in the query, leading to confusion. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48370 from stefankandic/fixTrailingComma. Lead-authored-by: Stefan Kandic Co-authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../resources/error/error-conditions.json | 6 ++ .../sql/catalyst/analysis/Analyzer.scala | 13 ++++- .../sql/catalyst/analysis/CheckAnalysis.scala | 37 ++++++++++++ .../sql/errors/QueryCompilationErrors.scala | 8 +++ .../errors/QueryCompilationErrorsSuite.scala | 56 +++++++++++++++++++ 5 files changed, 118 insertions(+), 2 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 8100f0580b21f..1b7f42e105077 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -4497,6 +4497,12 @@ ], "sqlState" : "428EK" }, + "TRAILING_COMMA_IN_SELECT" : { + "message" : [ + "Trailing comma detected in SELECT clause. Remove the trailing comma before the FROM clause." + ], + "sqlState" : "42601" + }, "TRANSPOSE_EXCEED_ROW_LIMIT" : { "message" : [ "Number of rows exceeds the allowed limit of for TRANSPOSE. If this was intended, set to at least the current row count." 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b2e9115dd512f..5d41c07b47842 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1591,7 +1591,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // If the projection list contains Stars, expand it. case p: Project if containsStar(p.projectList) => - p.copy(projectList = buildExpandedProjectList(p.projectList, p.child)) + val expanded = p.copy(projectList = buildExpandedProjectList(p.projectList, p.child)) + if (expanded.projectList.size < p.projectList.size) { + checkTrailingCommaInSelect(expanded, starRemoved = true) + } + expanded // If the filter list contains Stars, expand it. case p: Filter if containsStar(Seq(p.condition)) => p.copy(expandStarExpression(p.condition, p.child)) @@ -1600,7 +1604,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor if (a.groupingExpressions.exists(_.isInstanceOf[UnresolvedOrdinal])) { throw QueryCompilationErrors.starNotAllowedWhenGroupByOrdinalPositionUsedError() } else { - a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child)) + val expanded = a.copy(aggregateExpressions = + buildExpandedProjectList(a.aggregateExpressions, a.child)) + if (expanded.aggregateExpressions.size < a.aggregateExpressions.size) { + checkTrailingCommaInSelect(expanded, starRemoved = true) + } + expanded } case c: CollectMetrics if containsStar(c.metrics) => c.copy(metrics = buildExpandedProjectList(c.metrics, c.child)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 
b600f455f16ac..a4f424ba4b421 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -173,6 +173,36 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB ) } + /** + * Checks for errors in a `SELECT` clause, such as a trailing comma or an empty select list. + * + * @param plan The logical plan of the query. + * @param starRemoved Whether a '*' (wildcard) was removed from the select list. + * @throws AnalysisException if the select list is empty or ends with a trailing comma. + */ + protected def checkTrailingCommaInSelect( + plan: LogicalPlan, + starRemoved: Boolean = false): Unit = { + val exprList = plan match { + case proj: Project if proj.projectList.nonEmpty => + proj.projectList + case agg: Aggregate if agg.aggregateExpressions.nonEmpty => + agg.aggregateExpressions + case _ => + Seq.empty + } + + exprList.lastOption match { + case Some(Alias(UnresolvedAttribute(Seq(name)), _)) => + if (name.equalsIgnoreCase("FROM") && plan.exists(_.isInstanceOf[OneRowRelation])) { + if (exprList.size > 1 || starRemoved) { + throw QueryCompilationErrors.trailingCommaInSelectError(exprList.last.origin) + } + } + case _ => + } + } + def checkAnalysis(plan: LogicalPlan): Unit = { // We should inline all CTE relations to restore the original plan shape, as the analysis check // may need to match certain plan shapes. 
For dangling CTE relations, they will still be kept @@ -210,6 +240,13 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB val tblName = write.table.asInstanceOf[UnresolvedRelation].multipartIdentifier write.table.tableNotFound(tblName) + // We should check for trailing comma errors first, since we would get less obvious + // unresolved column errors if we do it bottom up + case proj: Project => + checkTrailingCommaInSelect(proj) + case agg: Aggregate => + checkTrailingCommaInSelect(agg) + case _ => } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 22cc001c0c78e..1f43b3dfa4a16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -358,6 +358,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat ) } + def trailingCommaInSelectError(origin: Origin): Throwable = { + new AnalysisException( + errorClass = "TRAILING_COMMA_IN_SELECT", + messageParameters = Map.empty, + origin = origin + ) + } + def unresolvedUsingColForJoinError( colName: String, suggestion: String, side: String): Throwable = { new AnalysisException( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala index 61b3489083a06..b4fdf50447458 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala @@ -979,6 +979,61 @@ class QueryCompilationErrorsSuite ) } + test("SPARK-49895: trailing comma in select statement") { + withTable("t1") { + sql(s"CREATE TABLE t1 (c1 INT, c2 INT) USING PARQUET") + + val queries = 
Seq( + "SELECT *? FROM t1", + "SELECT c1? FROM t1", + "SELECT c1? FROM t1 WHERE c1 = 1", + "SELECT c1? FROM t1 GROUP BY c1", + "SELECT *, RANK() OVER (ORDER BY c1)? FROM t1", + "SELECT c1? FROM t1 ORDER BY c1", + "WITH cte AS (SELECT c1? FROM t1) SELECT * FROM cte", + "WITH cte AS (SELECT c1 FROM t1) SELECT *? FROM cte", + "SELECT * FROM (SELECT c1? FROM t1)") + + queries.foreach { query => + val queryWithoutTrailingComma = query.replaceAll("\\?", "") + val queryWithTrailingComma = query.replaceAll("\\?", ",") + + sql(queryWithoutTrailingComma) + val exception = intercept[AnalysisException] { + sql(queryWithTrailingComma) + } + assert(exception.getErrorClass === "TRAILING_COMMA_IN_SELECT") + } + + val unresolvedColumnErrors = Seq( + "SELECT c3 FROM t1", + "SELECT from FROM t1", + "SELECT from FROM (SELECT 'a' as c1)", + "SELECT from AS col FROM t1", + "SELECT from AS from FROM t1", + "SELECT from from FROM t1") + unresolvedColumnErrors.foreach { query => + val exception = intercept[AnalysisException] { + sql(query) + } + assert(exception.getErrorClass === "UNRESOLVED_COLUMN.WITH_SUGGESTION") + } + + // sanity checks + withTable("from") { + sql(s"CREATE TABLE from (from INT) USING PARQUET") + + sql(s"SELECT from FROM from") + sql(s"SELECT from as from FROM from") + sql(s"SELECT from from FROM from from") + sql(s"SELECT c1, from FROM VALUES(1, 2) AS T(c1, from)") + + intercept[ParseException] { + sql("SELECT 1,") + } + } + } + } } class MyCastToString extends SparkUserDefinedFunction(