diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 3ffd9f9d88750..f47391c049298 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -40,7 +40,11 @@ class EquivalentExpressions { * Returns true if there was already a matching expression. */ def addExpr(expr: Expression): Boolean = { - updateExprInMap(expr, equivalenceMap) + if (supportedExpression(expr)) { + updateExprInMap(expr, equivalenceMap) + } else { + false + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index b16629f59aa2d..44d8ea3a112e1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType} +import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, ObjectType} class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHelper { test("Semantic equals and hash") { @@ -449,6 +449,22 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel assert(e2.getCommonSubexpressions.size == 1) assert(e2.getCommonSubexpressions.head == add) } + + test("SPARK-42851: Handle supportExpression consistently across add and get") { + val expr = { + val function = (lambda: Expression) => Add(lambda, Literal(1)) + val elementType = IntegerType + val colClass = classOf[Array[Int]] + val inputType = ObjectType(colClass) + val inputObject = BoundReference(0, inputType, nullable = true) + objects.MapObjects(function, inputObject, elementType, true, Option(colClass)) + } + val equivalence = new EquivalentExpressions + equivalence.addExpr(expr) + val hasMatching = equivalence.addExpr(expr) + val cseState = equivalence.getExprState(expr) + assert(hasMatching == cseState.isDefined) + } } case class CodegenFallbackExpression(child: Expression) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 737d31cc6e913..2ba9039166f48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -1538,6 +1538,13 @@ class DataFrameAggregateSuite extends QueryTest ) checkAnswer(res, Row(1, 1, 1) :: Row(4, 1, 2) :: Nil) } + + test("SPARK-42851: common subexpression should consistently handle aggregate and result exprs") { + val res = sql( + "select max(transform(array(id), x -> x)), max(transform(array(id), x -> x)) from range(2)" + ) + checkAnswer(res, Row(Array(1), Array(1))) + } } case class B(c: Option[Double])