Skip to content

Commit

Permalink
[SPARK-49895][SQL] Improve error when encountering trailing comma in …
Browse files Browse the repository at this point in the history
…SELECT clause

### What changes were proposed in this pull request?
Introduced a specific error message for cases where a trailing comma appears at the end of the SELECT clause.

### Why are the changes needed?
The previous error message was unclear and often pointed to an incorrect location in the query, leading to confusion.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48370 from stefankandic/fixTrailingComma.

Lead-authored-by: Stefan Kandic <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
stefankandic and HyukjinKwon committed Oct 9, 2024
1 parent 80ae411 commit 5f64e80
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 2 deletions.
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -4497,6 +4497,12 @@
],
"sqlState" : "428EK"
},
"TRAILING_COMMA_IN_SELECT" : {
"message" : [
"Trailing comma detected in SELECT clause. Remove the trailing comma before the FROM clause."
],
"sqlState" : "42601"
},
"TRANSPOSE_EXCEED_ROW_LIMIT" : {
"message" : [
"Number of rows exceeds the allowed limit of <maxValues> for TRANSPOSE. If this was intended, set <config> to at least the current row count."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1591,7 +1591,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor

// If the projection list contains Stars, expand it.
case p: Project if containsStar(p.projectList) =>
p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
val expanded = p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
if (expanded.projectList.size < p.projectList.size) {
checkTrailingCommaInSelect(expanded, starRemoved = true)
}
expanded
// If the filter list contains Stars, expand it.
case p: Filter if containsStar(Seq(p.condition)) =>
p.copy(expandStarExpression(p.condition, p.child))
Expand All @@ -1600,7 +1604,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
if (a.groupingExpressions.exists(_.isInstanceOf[UnresolvedOrdinal])) {
throw QueryCompilationErrors.starNotAllowedWhenGroupByOrdinalPositionUsedError()
} else {
a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
val expanded = a.copy(aggregateExpressions =
buildExpandedProjectList(a.aggregateExpressions, a.child))
if (expanded.aggregateExpressions.size < a.aggregateExpressions.size) {
checkTrailingCommaInSelect(expanded, starRemoved = true)
}
expanded
}
case c: CollectMetrics if containsStar(c.metrics) =>
c.copy(metrics = buildExpandedProjectList(c.metrics, c.child))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,36 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
)
}

/**
* Checks for errors in a `SELECT` clause, such as a trailing comma or an empty select list.
*
* @param plan The logical plan of the query.
* @param starRemoved Whether a '*' (wildcard) was removed from the select list.
* @throws AnalysisException if the select list is empty or ends with a trailing comma.
*/
protected def checkTrailingCommaInSelect(
plan: LogicalPlan,
starRemoved: Boolean = false): Unit = {
val exprList = plan match {
case proj: Project if proj.projectList.nonEmpty =>
proj.projectList
case agg: Aggregate if agg.aggregateExpressions.nonEmpty =>
agg.aggregateExpressions
case _ =>
Seq.empty
}

exprList.lastOption match {
case Some(Alias(UnresolvedAttribute(Seq(name)), _)) =>
if (name.equalsIgnoreCase("FROM") && plan.exists(_.isInstanceOf[OneRowRelation])) {
if (exprList.size > 1 || starRemoved) {
throw QueryCompilationErrors.trailingCommaInSelectError(exprList.last.origin)
}
}
case _ =>
}
}

def checkAnalysis(plan: LogicalPlan): Unit = {
// We should inline all CTE relations to restore the original plan shape, as the analysis check
// may need to match certain plan shapes. For dangling CTE relations, they will still be kept
Expand Down Expand Up @@ -210,6 +240,13 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
val tblName = write.table.asInstanceOf[UnresolvedRelation].multipartIdentifier
write.table.tableNotFound(tblName)

// We should check for trailing comma errors first, since we would get less obvious
// unresolved column errors if we do it bottom up
case proj: Project =>
checkTrailingCommaInSelect(proj)
case agg: Aggregate =>
checkTrailingCommaInSelect(agg)

case _ =>
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
)
}

def trailingCommaInSelectError(origin: Origin): Throwable = {
new AnalysisException(
errorClass = "TRAILING_COMMA_IN_SELECT",
messageParameters = Map.empty,
origin = origin
)
}

def unresolvedUsingColForJoinError(
colName: String, suggestion: String, side: String): Throwable = {
new AnalysisException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,62 @@ class QueryCompilationErrorsSuite
)
}

test("SPARK-49895: trailing comma in select statement") {
withTable("t1") {
sql(s"CREATE TABLE t1 (c1 INT, c2 INT) USING PARQUET")

val queries = Seq(
"SELECT *? FROM t1",
"SELECT c1? FROM t1",
"SELECT c1? FROM t1 WHERE c1 = 1",
"SELECT c1? FROM t1 GROUP BY c1",
"SELECT *, RANK() OVER (ORDER BY c1)? FROM t1",
"SELECT c1? FROM t1 ORDER BY c1",
"WITH cte AS (SELECT c1? FROM t1) SELECT * FROM cte",
"WITH cte AS (SELECT c1 FROM t1) SELECT *? FROM cte",
"SELECT * FROM (SELECT c1? FROM t1)")

queries.foreach { query =>
val queryWithoutTrailingComma = query.replaceAll("\\?", "")
val queryWithTrailingComma = query.replaceAll("\\?", ",")

sql(queryWithoutTrailingComma)
print(queryWithTrailingComma)
val exception = intercept[AnalysisException] {
sql(queryWithTrailingComma)
}
assert(exception.getErrorClass === "TRAILING_COMMA_IN_SELECT")
}

val unresolvedColumnErrors = Seq(
"SELECT c3 FROM t1",
"SELECT from FROM t1",
"SELECT from FROM (SELECT 'a' as c1)",
"SELECT from AS col FROM t1",
"SELECT from AS from FROM t1",
"SELECT from from FROM t1")
unresolvedColumnErrors.foreach { query =>
val exception = intercept[AnalysisException] {
sql(query)
}
assert(exception.getErrorClass === "UNRESOLVED_COLUMN.WITH_SUGGESTION")
}

// sanity checks
withTable("from") {
sql(s"CREATE TABLE from (from INT) USING PARQUET")

sql(s"SELECT from FROM from")
sql(s"SELECT from as from FROM from")
sql(s"SELECT from from FROM from from")
sql(s"SELECT c1, from FROM VALUES(1, 2) AS T(c1, from)")

intercept[ParseException] {
sql("SELECT 1,")
}
}
}
}
}

class MyCastToString extends SparkUserDefinedFunction(
Expand Down

0 comments on commit 5f64e80

Please sign in to comment.