[SPARK-32268][SQL][TESTS][FOLLOW-UP] Use function registry in the SparkSession

### What changes were proposed in this pull request?

This PR proposes:
1. Use the function registry in the Spark Session being used
2. Move function registration into `beforeAll`

### Why are the changes needed?

Registering the functions in the global `builtin` registry, and doing so outside `beforeAll`, can affect other tests. See also https://lists.apache.org/thread/jp0ccqv10ht716g9xldm2ohdv3mpmmz1.
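
For illustration only (not part of the change itself), here is a minimal sketch of the pattern this follow-up adopts, assuming a hypothetical suite `ExampleFunctionSuite`, a hypothetical SQL name `my_upper`, and the existing `Upper` expression standing in for the Bloom filter expressions: the test-only function is registered in the suite's own SparkSession registry in `beforeAll` and dropped again in `afterAll`, so it never touches the shared `FunctionRegistry.builtin`.

```scala
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Upper}
import org.apache.spark.sql.test.SharedSparkSession

// Hypothetical suite showing the lifecycle-scoped registration pattern.
class ExampleFunctionSuite extends QueryTest with SharedSparkSession {
  private val funcId = new FunctionIdentifier("my_upper")  // illustrative test-only name

  override def beforeAll(): Unit = {
    super.beforeAll()
    // Register in this suite's session registry, not in FunctionRegistry.builtin,
    // so the function stays invisible to other suites.
    spark.sessionState.functionRegistry.registerFunction(
      funcId,
      new ExpressionInfo(classOf[Upper].getName, "my_upper"),
      (children: Seq[Expression]) => Upper(children.head))
  }

  override def afterAll(): Unit = {
    // Drop the function so the session is left as it was found.
    spark.sessionState.functionRegistry.dropFunction(funcId)
    super.afterAll()
  }

  test("my_upper resolves through the session registry") {
    checkAnswer(spark.sql("SELECT my_upper('spark')"), Row("SPARK"))
  }
}
```

The same `beforeAll`/`afterAll` scoping is what the diff below applies to `bloom_filter_agg` and `might_contain`.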

### Does this PR introduce _any_ user-facing change?

No, test-only.

### How was this patch tested?

Unit tests fixed.

Closes #36576 from HyukjinKwon/SPARK-32268-followup.

Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
HyukjinKwon committed May 17, 2022
1 parent 6d74557 commit c5351f8
Showing 1 changed file with 18 additions and 16 deletions.
```diff
@@ -18,7 +18,6 @@
 package org.apache.spark.sql
 
 import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
@@ -35,23 +34,26 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
   val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
   val funcId_might_contain = new FunctionIdentifier("might_contain")
 
-  // Register 'bloom_filter_agg' to builtin.
-  FunctionRegistry.builtin.registerFunction(funcId_bloom_filter_agg,
-    new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
-    (children: Seq[Expression]) => children.size match {
-      case 1 => new BloomFilterAggregate(children.head)
-      case 2 => new BloomFilterAggregate(children.head, children(1))
-      case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
-    })
-
-  // Register 'might_contain' to builtin.
-  FunctionRegistry.builtin.registerFunction(funcId_might_contain,
-    new ExpressionInfo(classOf[BloomFilterMightContain].getName, "might_contain"),
-    (children: Seq[Expression]) => BloomFilterMightContain(children.head, children(1)))
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    // Register 'bloom_filter_agg' to builtin.
+    spark.sessionState.functionRegistry.registerFunction(funcId_bloom_filter_agg,
+      new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
+      (children: Seq[Expression]) => children.size match {
+        case 1 => new BloomFilterAggregate(children.head)
+        case 2 => new BloomFilterAggregate(children.head, children(1))
+        case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
+      })
+
+    // Register 'might_contain' to builtin.
+    spark.sessionState.functionRegistry.registerFunction(funcId_might_contain,
+      new ExpressionInfo(classOf[BloomFilterMightContain].getName, "might_contain"),
+      (children: Seq[Expression]) => BloomFilterMightContain(children.head, children(1)))
+  }
 
   override def afterAll(): Unit = {
-    FunctionRegistry.builtin.dropFunction(funcId_bloom_filter_agg)
-    FunctionRegistry.builtin.dropFunction(funcId_might_contain)
+    spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
+    spark.sessionState.functionRegistry.dropFunction(funcId_might_contain)
     super.afterAll()
   }
 
```
