From 4bec3114c3cb6c49eff3c2c853e58ef239d0bacf Mon Sep 17 00:00:00 2001
From: Jiaan Geng
Date: Mon, 10 Jul 2023 10:01:16 +0800
Subject: [PATCH] [SPARK-44342][SQL] Replace SQLContext with SparkSession for GenTPCDSData

### What changes were proposed in this pull request?
`SQLContext` is an old API for Spark SQL, but `GenTPCDSData` still uses it directly.
This PR replaces it with `SparkSession`.

### Why are the changes needed?
Avoid using the legacy API in `GenTPCDSData`.

### Does this PR introduce _any_ user-facing change?
'No'.
It only updates the benchmark utils.

### How was this patch tested?
Manually tested by running:
`build/sbt "sql/Test/runMain org.apache.spark.sql.GenTPCDSData --dsdgenDir ... --location ... --scaleFactor 1"`

Closes #41900 from beliefer/SPARK-44342.

Authored-by: Jiaan Geng
Signed-off-by: Kent Yao
---
 .../org/apache/spark/sql/GenTPCDSData.scala | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GenTPCDSData.scala b/sql/core/src/test/scala/org/apache/spark/sql/GenTPCDSData.scala
index b645b65ee8b35..6d5a0dc759f97 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GenTPCDSData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GenTPCDSData.scala
@@ -27,6 +27,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.functions.{col, rpad}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{CharType, StringType, StructField, StructType, VarcharType}

 // The classes in this file are basically moved from https://github.com/databricks/spark-sql-perf
@@ -120,7 +121,7 @@ class Dsdgen(dsdgenDir: String) extends Serializable {
   }
 }

-class TPCDSTables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int)
+class TPCDSTables(spark: SparkSession, dsdgenDir: String, scaleFactor: Int)
   extends TPCDSSchema with Logging with Serializable {

   private val dataGenerator = new Dsdgen(dsdgenDir)
@@ -138,7 +139,7 @@ class TPCDSTables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int)

     private def df(numPartition: Int) = {
       val generatedData = dataGenerator.generate(
-        sqlContext.sparkContext, name, numPartition, scaleFactor)
+        spark.sparkContext, name, numPartition, scaleFactor)
       val rows = generatedData.mapPartitions { iter =>
         iter.map { l =>
           val values = l.split("\\|", -1).dropRight(1).map { v =>
@@ -154,7 +155,7 @@ class TPCDSTables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int)
       }

       val stringData =
-        sqlContext.createDataFrame(
+        spark.createDataFrame(
           rows,
           StructType(schema.fields.map(f => StructField(f.name, StringType))))

@@ -210,7 +211,7 @@ class TPCDSTables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int)
            |DISTRIBUTE BY
            |  $partitionColumnString
          """.stripMargin
-        val grouped = sqlContext.sql(query)
+        val grouped = spark.sql(query)
         logInfo(s"Pre-clustering with partitioning columns with query $query.")
         grouped.write
       } else {
@@ -223,9 +224,7 @@ class TPCDSTables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int)
        // datagen speed files will be truncated to maxRecordsPerFile value, so the final
        // result will be the same.
        val numRows = data.count
-        val maxRecordPerFile = Try {
-          sqlContext.getConf("spark.sql.files.maxRecordsPerFile").toInt
-        }.getOrElse(0)
+        val maxRecordPerFile = spark.conf.get(SQLConf.MAX_RECORDS_PER_FILE)

        if (maxRecordPerFile > 0 && numRows > maxRecordPerFile) {
          val numFiles = (numRows.toDouble/maxRecordPerFile).ceil.toInt
@@ -244,7 +243,7 @@ class TPCDSTables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int)
      }
      logInfo(s"Generating table $name in database to $location with save mode $mode.")
      writer.save(location)
-      sqlContext.dropTempTable(tempTableName)
+      spark.catalog.dropTempView(tempTableName)
    }
  }
@@ -429,7 +428,7 @@ object GenTPCDSData {
      .getOrCreate()

    val tables = new TPCDSTables(
-      spark.sqlContext,
+      spark,
      dsdgenDir = config.dsdgenDir,
      scaleFactor = config.scaleFactor)
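
For reference, the replacements above follow the usual `SQLContext`-to-`SparkSession` mapping: `sqlContext.sql` becomes `spark.sql`, `sqlContext.createDataFrame` becomes `spark.createDataFrame`, `sqlContext.getConf(...)` becomes `spark.conf.get(...)`, and `sqlContext.dropTempTable` becomes `spark.catalog.dropTempView`. Below is a minimal, self-contained sketch of that mapping, separate from the patch itself; the object name, app name, and temp view name are illustrative only, and it reads the conf by its public string key rather than the `SQLConf.MAX_RECORDS_PER_FILE` entry referenced in the patch.

```scala
import org.apache.spark.sql.SparkSession

// Sketch only: object, app, and view names are made up for illustration.
object SparkSessionMigrationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("sqlcontext-to-sparksession")
      .getOrCreate()

    // sqlContext.sql(...) -> spark.sql(...)
    val df = spark.sql("SELECT 1 AS id")
    df.createOrReplaceTempView("sketch_view")

    // sqlContext.getConf("spark.sql.files.maxRecordsPerFile").toInt
    //   -> spark.conf.get(...): returns the conf's string value ("0" by default)
    val maxRecordsPerFile = spark.conf.get("spark.sql.files.maxRecordsPerFile").toLong

    // sqlContext.dropTempTable(...) -> spark.catalog.dropTempView(...)
    spark.catalog.dropTempView("sketch_view")

    println(s"rows=${df.count()}, maxRecordsPerFile=$maxRecordsPerFile")
    spark.stop()
  }
}
```

The in-tree benchmark code can reference the typed `SQLConf.MAX_RECORDS_PER_FILE` entry directly; code outside the Spark SQL module would typically read the value by its string key, as sketched above.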