diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 743cb591b306f..dcb02ab8556ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -44,7 +44,8 @@ class SparkOptimizer( Batch("PartitionPruning", Once, PartitionPruning) :+ Batch("InjectRuntimeFilter", FixedPoint(1), - InjectRuntimeFilter) :+ + InjectRuntimeFilter, + RewritePredicateSubquery) :+ Batch("Pushdown Filters from PartitionPruning", fixedPoint, PushDownPredicates) :+ Batch("Cleanup filters that cannot be pushed down", Once, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala index a5e27fbfda42a..0da3667382c16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.expressions.{Alias, BloomFilterMightContain, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.LeftSemi +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LogicalPlan} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} import org.apache.spark.sql.types.{IntegerType, StructType} @@ -213,6 +214,15 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp super.afterAll() } + private def ensureLeftSemiJoinExists(plan: LogicalPlan): Unit = { + assert( + plan.find { + case j: Join if j.joinType == LeftSemi => true + case _ => false + }.isDefined + ) + } + def checkWithAndWithoutFeatureEnabled(query: String, testSemiJoin: Boolean, shouldReplace: Boolean): Unit = { var planDisabled: LogicalPlan = null @@ -234,6 +244,7 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp if (shouldReplace) { val normalizedEnabled = normalizePlan(normalizeExprIds(planEnabled)) val normalizedDisabled = normalizePlan(normalizeExprIds(planDisabled)) + ensureLeftSemiJoinExists(planEnabled) assert(normalizedEnabled != normalizedDisabled) } else { comparePlans(planDisabled, planEnabled)