[SPARK-32268][SQL][FOLLOWUP] Add ColumnPruning in injectBloomFilter
### What changes were proposed in this pull request?
Add `ColumnPruning` in `InjectRuntimeFilter.injectBloomFilter` to optimize the BloomFilter creation query.

### Why are the changes needed?
The BloomFilter subqueries injected by `InjectRuntimeFilter` read every column of `filterCreationSidePlan`, not just the join key. This contradicts the design goal of "only scan the required columns". We can reproduce this with a simple case in `InjectRuntimeFilterSuite`:
```scala
withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_ENABLED.key -> "true",
  SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
  SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
  val query = "select * from bf1 join bf2 on bf1.c1 = bf2.c2 where bf2.a2 = 62"
  sql(query).explain()
}
```
The root cause is that the injected subqueries are never optimized by `ColumnPruning`; this PR fixes that.
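The effect of column pruning can be illustrated with a toy model outside of Spark. The `Scan`, `Agg`, and `pruneColumns` names below are hypothetical stand-ins for illustration only, not Catalyst classes:

```scala
// Toy sketch (hypothetical types, not Catalyst classes): a scan node that
// reads columns, and an aggregate that only needs one of them.
case class Scan(columns: Seq[String])
case class Agg(needed: Seq[String], child: Scan)

// A ColumnPruning-like rewrite: narrow the scan to the columns the parent
// actually references, mimicking "only scan the required columns".
def pruneColumns(plan: Agg): Agg =
  plan.copy(child = plan.child.copy(columns = plan.child.columns.filter(plan.needed.contains)))

// Before the fix, the injected subquery behaved like `before`: it scanned
// every column of the creation-side plan even though only c1 is aggregated.
val before = Agg(needed = Seq("c1"), child = Scan(Seq("a1", "b1", "c1", "d1")))
val after = pruneColumns(before)
println(after.child.columns) // List(c1)
```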

### Does this PR introduce _any_ user-facing change?
No, the feature has not been released yet.

### How was this patch tested?
Improved the existing tests by adding a `columnPruningTakesEffect` helper that checks the optimized plan of the bloom filter join.
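The helper's core idea is that a rule "takes effect" on a plan only if re-applying it still changes the tree. This can be sketched generically (toy code; plain `==` stands in for Catalyst's `fastEquals`, and `dedupe` is an invented example rule):

```scala
// Generic sketch of the check behind columnPruningTakesEffect: apply a
// rewrite rule and see whether it changed anything.
def takesEffect[T](rule: T => T)(tree: T): Boolean = rule(tree) != tree

val dedupe = (xs: List[Int]) => xs.distinct
println(takesEffect(dedupe)(List(1, 1, 2))) // true: the rule still changes the input
println(takesEffect(dedupe)(List(1, 2)))    // false: already at a fixpoint

// The suite asserts the negation: once injectBloomFilter has already run
// ColumnPruning, re-applying it must be a no-op on the subquery plan.
```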

Closes apache#36047 from Flyangz/SPARK-32268-FOllOWUP.

Authored-by: Yang Liu <[email protected]>
Signed-off-by: Yuming Wang <[email protected]>
Yang Liu authored and wangyum committed Apr 4, 2022
1 parent 41a8249 commit c98725a
Showing 2 changed files with 17 additions and 1 deletion.
```diff
@@ -85,7 +85,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
     }
     val aggExp = AggregateExpression(bloomFilterAgg, Complete, isDistinct = false, None)
     val alias = Alias(aggExp, "bloomFilter")()
-    val aggregate = ConstantFolding(Aggregate(Nil, Seq(alias), filterCreationSidePlan))
+    val aggregate =
+      ConstantFolding(ColumnPruning(Aggregate(Nil, Seq(alias), filterCreationSidePlan)))
     val bloomFilterSubquery = ScalarSubquery(aggregate, Nil)
     val filter = BloomFilterMightContain(bloomFilterSubquery,
       new XxHash64(Seq(filterApplicationSideExp)))
```
```diff
@@ -255,6 +255,7 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
       planEnabled = sql(query).queryExecution.optimizedPlan
       checkAnswer(sql(query), expectedAnswer)
       if (shouldReplace) {
+        assert(!columnPruningTakesEffect(planEnabled))
         assert(getNumBloomFilters(planEnabled) > getNumBloomFilters(planDisabled))
       } else {
         assert(getNumBloomFilters(planEnabled) == getNumBloomFilters(planDisabled))
@@ -288,6 +289,20 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
     numMightContains
   }
 
+  def columnPruningTakesEffect(plan: LogicalPlan): Boolean = {
+    def takesEffect(plan: LogicalPlan): Boolean = {
+      val result = org.apache.spark.sql.catalyst.optimizer.ColumnPruning.apply(plan)
+      !result.fastEquals(plan)
+    }
+
+    plan.collectFirst {
+      case Filter(condition, _) if condition.collectFirst {
+        case subquery: org.apache.spark.sql.catalyst.expressions.ScalarSubquery
+          if takesEffect(subquery.plan) => true
+      }.nonEmpty => true
+    }.nonEmpty
+  }
+
   def assertRewroteSemiJoin(query: String): Unit = {
     checkWithAndWithoutFeatureEnabled(query, testSemiJoin = true, shouldReplace = true)
   }
```
