From c5351f85dec628a5c806893aa66777cbd77a4d65 Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon
Date: Tue, 17 May 2022 23:05:48 +0900
Subject: [PATCH] [SPARK-32268][SQL][TESTS][FOLLOW-UP] Use function registry in
 the SparkSession

### What changes were proposed in this pull request?

This PR proposes to:
1. Use the function registry of the SparkSession being used
2. Move function registration into `beforeAll`

### Why are the changes needed?

Registering the functions in the global `builtin` registry, rather than inside `beforeAll`, can affect other tests. See also https://lists.apache.org/thread/jp0ccqv10ht716g9xldm2ohdv3mpmmz1.
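Below is a minimal, self-contained sketch of the session-scoped registration pattern this follow-up adopts. It is an illustration only, not code from this patch: `SessionScopedFunctionDemo` and the function name `my_upper` are made-up, and the existing catalyst expression `Upper` merely stands in for a test-only expression. Since `SparkSession.sessionState` is `private[sql]`, the sketch, like the test suite itself, has to live in the `org.apache.spark.sql` package.

```scala
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Upper}

// Sketch: register a function in one session's own registry so it never
// leaks to other SparkSessions, unlike FunctionRegistry.builtin.
object SessionScopedFunctionDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    // Made-up, test-only function name; it exists only in this session.
    val funcId = new FunctionIdentifier("my_upper")

    // Same registerFunction(identifier, info, builder) shape the patch uses,
    // but against the session-local registry instead of the global builtin.
    spark.sessionState.functionRegistry.registerFunction(
      funcId,
      new ExpressionInfo(classOf[Upper].getName, "my_upper"),
      (children: Seq[Expression]) => Upper(children.head))

    spark.sql("SELECT my_upper('spark')").show() // prints SPARK

    // Symmetric cleanup, mirroring the suite's afterAll().
    spark.sessionState.functionRegistry.dropFunction(funcId)
    spark.stop()
  }
}
```

Because the registration goes through `spark.sessionState.functionRegistry` rather than `FunctionRegistry.builtin`, other sessions in the same JVM never see the function, which is exactly why moving the suite's registration there stops it from leaking into other tests.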
### Does this PR introduce _any_ user-facing change?

No, test-only.

### How was this patch tested?

Unit tests fixed.

Closes #36576 from HyukjinKwon/SPARK-32268-followup.

Authored-by: Hyukjin Kwon
Signed-off-by: Hyukjin Kwon
---
 .../sql/BloomFilterAggregateQuerySuite.scala  | 34 ++++++++++---------
 1 file changed, 18 insertions(+), 16 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala
index 7fc89ecc88ba3..05513cddccb86 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/BloomFilterAggregateQuerySuite.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql
 
 import org.apache.spark.sql.catalyst.FunctionIdentifier
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
@@ -35,23 +34,26 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
   val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
   val funcId_might_contain = new FunctionIdentifier("might_contain")
 
-  // Register 'bloom_filter_agg' to builtin.
-  FunctionRegistry.builtin.registerFunction(funcId_bloom_filter_agg,
-    new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
-    (children: Seq[Expression]) => children.size match {
-      case 1 => new BloomFilterAggregate(children.head)
-      case 2 => new BloomFilterAggregate(children.head, children(1))
-      case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
-    })
-
-  // Register 'might_contain' to builtin.
-  FunctionRegistry.builtin.registerFunction(funcId_might_contain,
-    new ExpressionInfo(classOf[BloomFilterMightContain].getName, "might_contain"),
-    (children: Seq[Expression]) => BloomFilterMightContain(children.head, children(1)))
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    // Register 'bloom_filter_agg' to builtin.
+    spark.sessionState.functionRegistry.registerFunction(funcId_bloom_filter_agg,
+      new ExpressionInfo(classOf[BloomFilterAggregate].getName, "bloom_filter_agg"),
+      (children: Seq[Expression]) => children.size match {
+        case 1 => new BloomFilterAggregate(children.head)
+        case 2 => new BloomFilterAggregate(children.head, children(1))
+        case 3 => new BloomFilterAggregate(children.head, children(1), children(2))
+      })
+
+    // Register 'might_contain' to builtin.
+    spark.sessionState.functionRegistry.registerFunction(funcId_might_contain,
+      new ExpressionInfo(classOf[BloomFilterMightContain].getName, "might_contain"),
+      (children: Seq[Expression]) => BloomFilterMightContain(children.head, children(1)))
+  }
 
   override def afterAll(): Unit = {
-    FunctionRegistry.builtin.dropFunction(funcId_bloom_filter_agg)
-    FunctionRegistry.builtin.dropFunction(funcId_might_contain)
+    spark.sessionState.functionRegistry.dropFunction(funcId_bloom_filter_agg)
+    spark.sessionState.functionRegistry.dropFunction(funcId_might_contain)
     super.afterAll()
   }