From c0c52dd2eb06e9cd315bc5b9ff95763c4f61ca89 Mon Sep 17 00:00:00 2001
From: ulysses-you
Date: Tue, 29 Mar 2022 21:33:04 +0800
Subject: [PATCH] [SPARK-32268][SQL][FOLLOWUP] Add RewritePredicateSubquery
 below the InjectRuntimeFilter

### What changes were proposed in this pull request?

Add `RewritePredicateSubquery` below the `InjectRuntimeFilter` in `SparkOptimizer`.

### Why are the changes needed?

If the runtime filter uses an in-subquery to do the filtering, the subquery is not converted to a semi-join as the design intended. This PR fixes that issue.

### Does this PR introduce _any_ user-facing change?

No, the feature is not released yet.

### How was this patch tested?

Improved the existing test by adding a check that a semi-join exists in the plan when the runtime filter takes the in-subquery code path.

Closes #35998 from ulysses-you/SPARK-32268-FOllOWUP.

Authored-by: ulysses-you
Signed-off-by: Wenchen Fan
---
 .../apache/spark/sql/execution/SparkOptimizer.scala |  3 ++-
 .../apache/spark/sql/InjectRuntimeFilterSuite.scala | 13 ++++++++++++-
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 743cb591b306f..dcb02ab8556ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -44,7 +44,8 @@ class SparkOptimizer(
     Batch("PartitionPruning", Once,
       PartitionPruning) :+
     Batch("InjectRuntimeFilter", FixedPoint(1),
-      InjectRuntimeFilter) :+
+      InjectRuntimeFilter,
+      RewritePredicateSubquery) :+
     Batch("Pushdown Filters from PartitionPruning", fixedPoint,
       PushDownPredicates) :+
     Batch("Cleanup filters that cannot be pushed down", Once,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
index a5e27fbfda42a..0da3667382c16 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.catalyst.expressions.{Alias, BloomFilterMightContain, Literal}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, BloomFilterAggregate}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.LeftSemi
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LogicalPlan}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
 import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -213,6 +214,15 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
     super.afterAll()
   }
 
+  private def ensureLeftSemiJoinExists(plan: LogicalPlan): Unit = {
+    assert(
+      plan.find {
+        case j: Join if j.joinType == LeftSemi => true
+        case _ => false
+      }.isDefined
+    )
+  }
+
   def checkWithAndWithoutFeatureEnabled(query: String, testSemiJoin: Boolean,
       shouldReplace: Boolean): Unit = {
     var planDisabled: LogicalPlan = null
@@ -234,6 +244,7 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
     if (shouldReplace) {
       val normalizedEnabled = normalizePlan(normalizeExprIds(planEnabled))
       val normalizedDisabled = normalizePlan(normalizeExprIds(planDisabled))
+      ensureLeftSemiJoinExists(planEnabled)
       assert(normalizedEnabled != normalizedDisabled)
     } else {
       comparePlans(planDisabled, planEnabled)
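
For reviewers who want to reproduce the behavior outside the test suite, here is a minimal, hedged sketch (not part of the patch). The `hasLeftSemiJoin` helper mirrors the `ensureLeftSemiJoinExists` helper added above; the standalone object, the `fact`/`dim` table names, and the session setup are illustrative assumptions. The two config keys are the runtime-filter switches as they exist in Spark 3.3, where this change landed.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.LeftSemi
import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}

object RuntimeFilterSemiJoinCheck {
  // Same predicate the new test helper asserts on: does the optimized
  // plan contain any LeftSemi join node?
  def hasLeftSemiJoin(plan: LogicalPlan): Boolean =
    plan.find {
      case j: Join if j.joinType == LeftSemi => true
      case _ => false
    }.isDefined

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("runtime-filter-semi-join-check")
      // Steer InjectRuntimeFilter toward the in-subquery (semi-join
      // reduction) code path instead of the bloom filter path.
      .config("spark.sql.optimizer.runtimeFilter.semiJoinReduction.enabled", "true")
      .config("spark.sql.optimizer.runtime.bloomFilter.enabled", "false")
      .getOrCreate()
    import spark.implicits._

    // Illustrative tables; whether a runtime filter is injected at all
    // also depends on plan shape and size thresholds, so we print the
    // observed result rather than asserting it.
    Seq((1, 10), (2, 20)).toDF("k", "v").write.saveAsTable("fact")
    Seq((1, "a")).toDF("k", "name").write.saveAsTable("dim")

    val df = spark.sql(
      "SELECT * FROM fact JOIN dim ON fact.k = dim.k WHERE dim.name = 'a'")
    // Before this patch, an injected in-subquery runtime filter stayed a
    // subquery expression in the optimized plan; with RewritePredicateSubquery
    // in the InjectRuntimeFilter batch it is rewritten into a LeftSemi join.
    println(hasLeftSemiJoin(df.queryExecution.optimizedPlan))
    spark.stop()
  }
}
```

The design point the fix relies on: `RewritePredicateSubquery` normally runs in a later batch of the main optimizer, which the filters injected by `InjectRuntimeFilter` never reach, so re-running it inside the `InjectRuntimeFilter` batch is what lets the in-subquery form become a semi-join.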