# [SPARK-44342][SQL] Replace SQLContext with SparkSession for GenTPCDSData
### What changes were proposed in this pull request?
`SQLContext` is a legacy API of Spark SQL, but `GenTPCDSData` still uses it directly. This PR replaces it with `SparkSession`.
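For reference, a minimal sketch of the same operations expressed against the modern API (illustrative only, not taken from the patch; the object name and toy data are made up):

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object SqlContextMigrationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("SQLContext to SparkSession")
      .getOrCreate()

    // Everything GenTPCDSData needed from SQLContext exists on SparkSession:
    //   sqlContext.sparkContext         -> spark.sparkContext
    //   sqlContext.createDataFrame(...) -> spark.createDataFrame(...)
    //   sqlContext.sql(...)             -> spark.sql(...)
    //   sqlContext.dropTempTable(...)   -> spark.catalog.dropTempView(...)
    val rows = spark.sparkContext.parallelize(Seq(Row("a"), Row("b")))
    val schema = StructType(Seq(StructField("value", StringType)))
    val df = spark.createDataFrame(rows, schema)

    df.createOrReplaceTempView("demo")
    spark.sql("SELECT value FROM demo").show()
    spark.catalog.dropTempView("demo")

    spark.stop()
  }
}
```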

### Why are the changes needed?
To avoid using the legacy API in `GenTPCDSData`.

### Does this PR introduce _any_ user-facing change?
No. It only updates the benchmark utilities.

### How was this patch tested?
Manually tested by running:
`build/sbt "sql/Test/runMain org.apache.spark.sql.GenTPCDSData --dsdgenDir ... --location ... --scaleFactor 1"`

Closes apache#41900 from beliefer/SPARK-44342.

Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Kent Yao <[email protected]>
beliefer authored and yaooqinn committed Jul 10, 2023
1 parent 71703d6 commit 4bec311
Showing 1 changed file with 8 additions and 9 deletions.
`sql/core/src/test/scala/org/apache/spark/sql/GenTPCDSData.scala` (8 additions, 9 deletions)
```diff
@@ -27,6 +27,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.functions.{col, rpad}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{CharType, StringType, StructField, StructType, VarcharType}
 
 // The classes in this file are basically moved from https://github.com/databricks/spark-sql-perf
@@ -120,7 +121,7 @@ class Dsdgen(dsdgenDir: String) extends Serializable {
   }
 }
 
-class TPCDSTables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int)
+class TPCDSTables(spark: SparkSession, dsdgenDir: String, scaleFactor: Int)
   extends TPCDSSchema with Logging with Serializable {
 
   private val dataGenerator = new Dsdgen(dsdgenDir)
@@ -138,7 +139,7 @@ class TPCDSTables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int)
 
     private def df(numPartition: Int) = {
       val generatedData = dataGenerator.generate(
-        sqlContext.sparkContext, name, numPartition, scaleFactor)
+        spark.sparkContext, name, numPartition, scaleFactor)
       val rows = generatedData.mapPartitions { iter =>
         iter.map { l =>
           val values = l.split("\\|", -1).dropRight(1).map { v =>
@@ -154,7 +155,7 @@ class TPCDSTables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int)
       }
 
       val stringData =
-        sqlContext.createDataFrame(
+        spark.createDataFrame(
           rows,
           StructType(schema.fields.map(f => StructField(f.name, StringType))))
 
@@ -210,7 +211,7 @@ class TPCDSTables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int)
         |DISTRIBUTE BY
         | $partitionColumnString
       """.stripMargin
-      val grouped = sqlContext.sql(query)
+      val grouped = spark.sql(query)
       logInfo(s"Pre-clustering with partitioning columns with query $query.")
       grouped.write
     } else {
@@ -223,9 +224,7 @@ class TPCDSTables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int)
       // datagen speed files will be truncated to maxRecordsPerFile value, so the final
       // result will be the same.
       val numRows = data.count
-      val maxRecordPerFile = Try {
-        sqlContext.getConf("spark.sql.files.maxRecordsPerFile").toInt
-      }.getOrElse(0)
+      val maxRecordPerFile = spark.conf.get(SQLConf.MAX_RECORDS_PER_FILE)
 
       if (maxRecordPerFile > 0 && numRows > maxRecordPerFile) {
         val numFiles = (numRows.toDouble/maxRecordPerFile).ceil.toInt
@@ -244,7 +243,7 @@ class TPCDSTables(sqlContext: SQLContext, dsdgenDir: String, scaleFactor: Int)
       }
       logInfo(s"Generating table $name in database to $location with save mode $mode.")
       writer.save(location)
-      sqlContext.dropTempTable(tempTableName)
+      spark.catalog.dropTempView(tempTableName)
     }
   }
 
@@ -429,7 +428,7 @@ object GenTPCDSData {
       .getOrCreate()
 
     val tables = new TPCDSTables(
-      spark.sqlContext,
+      spark,
       dsdgenDir = config.dsdgenDir,
       scaleFactor = config.scaleFactor)
```
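One detail in the `maxRecordPerFile` hunk: `SQLConf.MAX_RECORDS_PER_FILE` is the typed entry for `spark.sql.files.maxRecordsPerFile`, and reading a typed entry falls back to its default (0, meaning no per-file limit) when the conf is unset, which is why the old `Try { ... }.getOrElse(0)` guard can be dropped. The typed overload of `spark.conf.get` is internal to Spark SQL; outside Spark's own source tree, a rough public-API equivalent would be the following sketch (object name made up; the default is supplied explicitly as a string):

```scala
import org.apache.spark.sql.SparkSession

object MaxRecordsPerFileSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()

    // Public-API read of the same conf: pass the documented default ("0")
    // instead of relying on the internal typed ConfigEntry.
    val maxRecordsPerFile =
      spark.conf.get("spark.sql.files.maxRecordsPerFile", "0").toLong

    println(s"maxRecordsPerFile = $maxRecordsPerFile") // 0 unless configured
    spark.stop()
  }
}
```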
