diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
index 69248fc2c237..a88d263e9dc7 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala
@@ -86,13 +86,9 @@ class TestHoodieFileIndex extends HoodieSparkClientTestBase with ScalaAssertionS
   @BeforeEach
   override def setUp() {
     setTableName("hoodie_test")
+    super.setUp()
     initPath()
-    initSparkContexts()
     spark = sqlContext.sparkSession
-    initTestDataGenerator()
-    initFileSystem()
-    initMetaClient()
-
     queryOpts = queryOpts ++ Map("path" -> basePath)
   }
 
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieParquetBloom.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieParquetBloom.scala
index 2e5e30362bb9..a6f3a0e7368b 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieParquetBloom.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieParquetBloom.scala
@@ -19,53 +19,29 @@ package org.apache.hudi
 
 import org.apache.spark.sql._
 import org.apache.spark.sql.hudi.HoodieSparkSessionExtension
-import org.apache.spark.util.{AccumulatorV2}
+import org.apache.spark.util.AccumulatorV2
 import org.apache.spark.SparkContext
-
 import org.apache.hudi.testutils.HoodieClientTestUtils.getSparkConfForTest
 import org.apache.hudi.DataSourceWriteOptions
 import org.apache.hudi.config.HoodieWriteConfig
 import org.apache.hudi.common.model.{HoodieTableType, WriteOperationType}
-
-
-import org.junit.jupiter.api.Assertions.{assertEquals}
-import org.junit.jupiter.api.{BeforeEach}
+import org.apache.hudi.testutils.HoodieSparkClientTestBase
+import org.junit.jupiter.api.Assertions.assertEquals
+import org.junit.jupiter.api.BeforeEach
 import org.junit.jupiter.params.ParameterizedTest
-import org.junit.jupiter.params.provider.{EnumSource}
-
-class TestHoodieParquetBloomFilter {
-
-  var spark: SparkSession = _
-  var sqlContext: SQLContext = _
-  var sc: SparkContext = _
+import org.junit.jupiter.params.provider.EnumSource
 
-  def initSparkContext(): Unit = {
-    val sparkConf = getSparkConfForTest(getClass.getSimpleName)
-
-    spark = SparkSession.builder()
-      .withExtensions(new HoodieSparkSessionExtension)
-      .config(sparkConf)
-      .getOrCreate()
-
-    sc = spark.sparkContext
-    sc.setLogLevel("ERROR")
-    sqlContext = spark.sqlContext
-  }
-
-  @BeforeEach
-  def setUp() {
-    initSparkContext()
-  }
+class TestHoodieParquetBloomFilter extends HoodieSparkClientTestBase with ScalaAssertionSupport {
 
   @ParameterizedTest
   @EnumSource(value = classOf[WriteOperationType], names = Array("BULK_INSERT", "INSERT", "UPSERT", "INSERT_OVERWRITE"))
   def testBloomFilter(operation: WriteOperationType): Unit = {
     // setup hadoop conf with bloom col enabled
-    spark.sparkContext.hadoopConfiguration.set("parquet.bloom.filter.enabled#bloom_col", "true")
-    spark.sparkContext.hadoopConfiguration.set("parquet.bloom.filter.expected.ndv#bloom_col", "2")
+    jsc.hadoopConfiguration.set("parquet.bloom.filter.enabled#bloom_col", "true")
+    jsc.hadoopConfiguration.set("parquet.bloom.filter.expected.ndv#bloom_col", "2")
     // ensure nothing but bloom can trigger read skip
-    spark.sql("set parquet.filter.columnindex.enabled=false")
-    spark.sql("set parquet.filter.stats.enabled=false")
+    sparkSession.sql("set parquet.filter.columnindex.enabled=false")
+    sparkSession.sql("set parquet.filter.stats.enabled=false")
 
     val basePath = java.nio.file.Files.createTempDirectory("hoodie_bloom_source_path").toAbsolutePath.toString
     val opts = Map(
@@ -75,7 +51,7 @@ class TestHoodieParquetBloomFilter {
       DataSourceWriteOptions.RECORDKEY_FIELD.key -> "_row_key",
       DataSourceWriteOptions.PARTITIONPATH_FIELD.key -> "partition"
     )
-    val inputDF = spark.sql(
+    val inputDF = sparkSession.sql(
       """select '0' as _row_key, '1' as bloom_col, '2' as partition, '3' as ts
         |union
         |select '1', '2', '3', '4'
@@ -86,19 +62,19 @@ class TestHoodieParquetBloomFilter {
       .save(basePath)
 
     val accu = new NumRowGroupsAcc
-    spark.sparkContext.register(accu)
+    sparkSession.sparkContext.register(accu)
 
     // this one shall skip partition scanning thanks to bloom when spark >=3
-    spark.read.format("hudi").load(basePath).filter("bloom_col = '3'").foreachPartition((it: Iterator[Row]) => it.foreach(_ => accu.add(0)))
+    sparkSession.read.format("hudi").load(basePath).filter("bloom_col = '3'").foreachPartition((it: Iterator[Row]) => it.foreach(_ => accu.add(0)))
     assertEquals(if (currentSparkSupportParquetBloom()) 0 else 1, accu.value)
 
     // this one will trigger one partition scan
-    spark.read.format("hudi").load(basePath).filter("bloom_col = '2'").foreachPartition((it: Iterator[Row]) => it.foreach(_ => accu.add(0)))
+    sparkSession.read.format("hudi").load(basePath).filter("bloom_col = '2'").foreachPartition((it: Iterator[Row]) => it.foreach(_ => accu.add(0)))
     assertEquals(1, accu.value)
   }
 
   def currentSparkSupportParquetBloom(): Boolean = {
-    Integer.valueOf(spark.version.charAt(0)) >= 3
+    Integer.valueOf(sparkSession.version.charAt(0)) >= 3
   }
 }
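
Both tests now obtain their Spark plumbing from HoodieSparkClientTestBase instead of hand-rolling a SparkSession. As a sketch of what the refactor assumes from that base class, inferred only from the call sites in this diff (it is not the actual HoodieSparkClientTestBase source, and the real signatures may differ):

    // Hypothetical outline inferred from the call sites in this diff.
    abstract class SparkClientTestBaseOutline {
      // JavaSparkContext: the bloom test sets parquet.bloom.* keys on its hadoopConfiguration
      protected var jsc: org.apache.spark.api.java.JavaSparkContext = _
      // SparkSession: used for sql(), read, sparkContext.register(), and the version check
      protected var sparkSession: org.apache.spark.sql.SparkSession = _
      // SQLContext: TestHoodieFileIndex derives its spark field from it
      protected var sqlContext: org.apache.spark.sql.SQLContext = _

      // One super.setUp() call replaces the per-test initSparkContexts()/
      // initTestDataGenerator()/initFileSystem()/initMetaClient() sequence.
      def setUp(): Unit
    }

With that shared setup, the subclasses keep only test-specific initialization (table name, path, query options), which is the net effect of both hunks above.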