From 1b464ec631c81c4b99833ba81d3c556f05bc7d00 Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Sat, 4 Nov 2023 21:17:32 -0700 Subject: [PATCH] Extract common code Signed-off-by: Nghia Truong --- .../spark/rapids/GpuParquetFileFormat.scala | 24 +++++++++---------- .../nvidia/spark/rapids/GpuParquetScan.scala | 6 ++--- .../spark/rapids/datetimeRebaseUtils.scala | 19 ++++++++------- 3 files changed, 25 insertions(+), 24 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala index 827666a24cd..7b96c6ddf5e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala @@ -126,8 +126,7 @@ object GpuParquetFileFormat { if (schemaHasTimestamps) { meta.willNotWorkOnGpu("LEGACY rebase mode for int96 timestamps is not supported") } - case other => meta.willNotWorkOnGpu(s"Invalid datetime rebase mode from config: $other " + - "(must be either 'EXCEPTION', 'LEGACY', or 'CORRECTED')") + case other => meta.willNotWorkOnGpu(DateTimeRebaseUtils.invalidRebaseModeMessage(other)) } SparkShimImpl.parquetRebaseWrite(sqlConf) match { @@ -139,8 +138,7 @@ object GpuParquetFileFormat { s"session: ${SQLConf.get.sessionLocalTimeZone}). " + " Set both of the timezones to UTC to enable LEGACY rebase support.") } - case other => meta.willNotWorkOnGpu(s"Invalid datetime rebase mode from config: $other " + - "(must be either 'EXCEPTION', 'LEGACY', or 'CORRECTED')") + case other => meta.willNotWorkOnGpu(DateTimeRebaseUtils.invalidRebaseModeMessage(other)) } if (meta.canThisBeReplaced) { @@ -191,9 +189,11 @@ class GpuParquetFileFormat extends ColumnarFileFormat with Logging { val conf = ContextUtil.getConfiguration(job) val outputTimestampType = sqlConf.parquetOutputTimestampType - val dateTimeRebaseMode = sparkSession.sqlContext.getConf(SparkShimImpl.parquetRebaseWriteKey) + val dateTimeRebaseMode = DateTimeRebaseUtils.getRebaseModeFromName( + sparkSession.sqlContext.getConf(SparkShimImpl.parquetRebaseWriteKey)) val timestampRebaseMode = if (outputTimestampType.equals(ParquetOutputTimestampType.INT96)) { - sparkSession.sqlContext.getConf(SparkShimImpl.int96ParquetRebaseWriteKey) + DateTimeRebaseUtils.getRebaseModeFromName( + sparkSession.sqlContext.getConf(SparkShimImpl.int96ParquetRebaseWriteKey)) } else { dateTimeRebaseMode } @@ -300,19 +300,19 @@ class GpuParquetWriter( dataSchema: StructType, compressionType: CompressionType, outputTimestampType: String, - dateRebaseMode: String, - timestampRebaseMode: String, + dateRebaseMode: DateTimeRebaseMode, + timestampRebaseMode: DateTimeRebaseMode, context: TaskAttemptContext, parquetFieldIdEnabled: Boolean) extends ColumnarOutputWriter(context, dataSchema, "Parquet", true) { override def throwIfRebaseNeededInExceptionMode(batch: ColumnarBatch): Unit = { val cols = GpuColumnVector.extractBases(batch) cols.foreach { col => - if (dateRebaseMode.equals("EXCEPTION") && + if (dateRebaseMode == DateTimeRebaseException && DateTimeRebaseUtils.isDateRebaseNeededInWrite(col)) { throw DataSourceUtils.newRebaseExceptionInWrite("Parquet") } - else if (timestampRebaseMode.equals("EXCEPTION") && + else if (timestampRebaseMode == DateTimeRebaseException && DateTimeRebaseUtils.isTimeRebaseNeededInWrite(col)) { throw DataSourceUtils.newRebaseExceptionInWrite("Parquet") } @@ -333,14 +333,14 @@ class GpuParquetWriter( ColumnCastUtil.deepTransform(cv, Some(dt)) { case (cv, _) if cv.getType.isTimestampType => if(cv.getType == DType.TIMESTAMP_DAYS) { - if (dateRebaseMode.equals("LEGACY")) { + if (dateRebaseMode == DateTimeRebaseLegacy) { DateTimeRebase.rebaseGregorianToJulian(cv) } else { cv.copyToColumnVector() } } else { /* timestamp */ val typeMillis = ParquetOutputTimestampType.TIMESTAMP_MILLIS.toString - if (timestampRebaseMode.equals("LEGACY")) { + if (timestampRebaseMode == DateTimeRebaseLegacy) { val rebasedTimestampAsMicros = if(cv.getType == DType.TIMESTAMP_MICROSECONDS) { DateTimeRebase.rebaseGregorianToJulian(cv) } else { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala index 34fa5ab7019..6c4758ad6d7 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala @@ -215,8 +215,7 @@ object GpuParquetScan { if (schemaHasTimestamps) { meta.willNotWorkOnGpu("LEGACY rebase mode for dates and timestamps is not supported") } - case other => meta.willNotWorkOnGpu(s"Invalid datetime rebase mode from config: $other " + - "(must be either 'EXCEPTION', 'LEGACY', or 'CORRECTED')") + case other => meta.willNotWorkOnGpu(DateTimeRebaseUtils.invalidRebaseModeMessage(other)) } sqlConf.get(SparkShimImpl.parquetRebaseReadKey) match { @@ -228,8 +227,7 @@ object GpuParquetScan { if (schemaHasDates || schemaHasTimestamps) { meta.willNotWorkOnGpu("LEGACY rebase mode for dates and timestamps is not supported") } - case other => meta.willNotWorkOnGpu(s"Invalid datetime rebase mode from config: $other " + - "(must be either 'EXCEPTION', 'LEGACY', or 'CORRECTED')") + case other => meta.willNotWorkOnGpu(DateTimeRebaseUtils.invalidRebaseModeMessage(other)) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/datetimeRebaseUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/datetimeRebaseUtils.scala index beea39b6ed2..25ed4fe1751 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/datetimeRebaseUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/datetimeRebaseUtils.scala @@ -44,6 +44,16 @@ case object DateTimeRebaseLegacy extends DateTimeRebaseMode("LEGACY") case object DateTimeRebaseCorrected extends DateTimeRebaseMode("CORRECTED") object DateTimeRebaseUtils { + def invalidRebaseModeMessage(name: String): String = + s"Invalid datetime rebase mode: $name (must be either 'EXCEPTION', 'LEGACY', or 'CORRECTED')" + + def getRebaseModeFromName(name: String): DateTimeRebaseMode = name match { + case DateTimeRebaseException.value => DateTimeRebaseException + case DateTimeRebaseLegacy.value => DateTimeRebaseLegacy + case DateTimeRebaseCorrected.value => DateTimeRebaseCorrected + case _ => throw new IllegalArgumentException(invalidRebaseModeMessage(name)) + } + // Copied from Spark private val SPARK_VERSION_METADATA_KEY = "org.apache.spark.version" private val SPARK_LEGACY_DATETIME_METADATA_KEY = "org.apache.spark.legacyDateTime" @@ -66,14 +76,7 @@ object DateTimeRebaseUtils { } else { DateTimeRebaseCorrected } - }.getOrElse(modeByConfig match { - case DateTimeRebaseException.value => DateTimeRebaseException - case DateTimeRebaseLegacy.value => DateTimeRebaseLegacy - case DateTimeRebaseCorrected.value => DateTimeRebaseCorrected - case _ => throw new IllegalArgumentException( - s"Invalid datetime rebase mode from config: $modeByConfig " + - "(must be either 'EXCEPTION', 'LEGACY', or 'CORRECTED')") - }) + }.getOrElse(getRebaseModeFromName(modeByConfig)) // Check the timezone of the file if the mode is LEGACY. if (mode == DateTimeRebaseLegacy) {