diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 862f78702c4e6..89e30b83784cf 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -143,6 +143,25 @@ class SqlParser extends StandardTokenParsers with PackratParsers { case (e, i) => Alias(e, s"c$i")() } } + + /** Creates the aliases to the grouping expressions */ + protected def assignAliasesForGroups( + grpExprs: Seq[Expression], + projExprs: Seq[Expression]): Seq[NamedExpression] = { + grpExprs.zipWithIndex.map { + case (e, i) => + var aliasForGrp:NamedExpression = null + projExprs.foreach { + case Alias(pe,pi) if pe.fastEquals(e) => aliasForGrp = Alias(e, pi)() + case _ => + } + if (aliasForGrp == null) { + Alias(e, s"c$i")() + } else { + aliasForGrp + } + } + } protected lazy val query: Parser[LogicalPlan] = ( select * ( @@ -166,7 +185,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { val withFilter = f.map(f => Filter(f, base)).getOrElse(base) val withProjection = g.map {g => - Aggregate(assignAliases(g), assignAliases(p), withFilter) + Aggregate(assignAliasesForGroups(g,p), assignAliases(p), withFilter) }.getOrElse(Project(assignAliases(p), withFilter)) val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection) val withHaving = h.map(h => Filter(h, withDistinct)).getOrElse(withDistinct) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 08376eb5e5c4e..32ad9bd35f6f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -672,4 +672,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"), ("true", "false") :: Nil) } + + test("SPARK-3371 Renaming a function expression with group by gives error") { + registerFunction("len", (s: String) => s.length) + checkAnswer( + sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), + Seq(Seq("1"))) + } }