Commit: Move file and rename class

ttnghia committed Nov 4, 2023
1 parent 29df7cd · commit 1b5112d

Showing 5 changed files with 15 additions and 17 deletions.
@@ -25,6 +25,7 @@
 
 import scala.collection.Seq;
 
+import com.nvidia.spark.rapids.DateTimeRebaseLegacy$;
 import com.nvidia.spark.rapids.GpuMetric;
 import com.nvidia.spark.rapids.GpuParquetUtils;
 import com.nvidia.spark.rapids.ParquetPartitionReader;
@@ -139,8 +140,8 @@ public org.apache.iceberg.io.CloseableIterator<ColumnarBatch> iterator() {
         new Path(input.location()), clippedBlocks, fileReadSchema, caseSensitive,
         partReaderSparkSchema, debugDumpPrefix, debugDumpAlways,
         maxBatchSizeRows, maxBatchSizeBytes, targetBatchSizeBytes, useChunkedReader, metrics,
-        "CORRECTED", // dateRebaseMode
-        "CORRECTED", // timestampRebaseMode
+        DateTimeRebaseLegacy$.MODULE$.toString(), // dateRebaseMode
+        DateTimeRebaseLegacy$.MODULE$.toString(), // timestampRebaseMode
         true, // hasInt96Timestamps
         false // useFieldId
     );
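The Java reader above can hand a Scala singleton's string form to the string-typed dateRebaseMode/timestampRebaseMode parameters because of how Scala compiles objects: an object Foo becomes a JVM class Foo$ whose single instance lives in the static field MODULE$. A minimal Scala sketch of that interop, using a stand-in object rather than the repo's real mode hierarchy:

// `DemoMode` stands in for the real DateTimeRebaseLegacy; from Java this same
// singleton is reached as `DemoMode$.MODULE$`, mirroring the reader's
// `DateTimeRebaseLegacy$.MODULE$.toString()` call.
case object DemoMode

object InteropDemo {
  def main(args: Array[String]): Unit = {
    // A case object's default toString is its simple name, so this prints
    // "DemoMode"; the repo's mode type may override toString to produce the
    // exact mode string the readers expect.
    println(DemoMode.toString)
  }
}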
@@ -19,7 +19,6 @@ package com.nvidia.spark.rapids
 import java.time.ZoneId
 
 import ai.rapids.cudf._
-import com.nvidia.spark.DateTimeRebaseHelper
 import com.nvidia.spark.rapids.Arm.withResource
 import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingArray
 import com.nvidia.spark.rapids.jni.DateTimeRebase
@@ -311,11 +310,11 @@ class GpuParquetWriter(
     val cols = GpuColumnVector.extractBases(batch)
     cols.foreach { col =>
       if (dateRebaseMode.equals("EXCEPTION") &&
-        DateTimeRebaseHelper.isDateRebaseNeededInWrite(col)) {
+        DateTimeRebaseUtils.isDateRebaseNeededInWrite(col)) {
         throw DataSourceUtils.newRebaseExceptionInWrite("Parquet")
       }
       else if (timestampRebaseMode.equals("EXCEPTION") &&
-        DateTimeRebaseHelper.isTimeRebaseNeededInWrite(col)) {
+        DateTimeRebaseUtils.isTimeRebaseNeededInWrite(col)) {
         throw DataSourceUtils.newRebaseExceptionInWrite("Parquet")
       }
     }
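The writer hunk above and the reader hunk below share one guard pattern: under the "EXCEPTION" rebase mode, scan every column and fail fast the moment a rebase would be required. A condensed Scala sketch of that pattern under stated assumptions: needsRebase is a hypothetical stand-in for the isDateRebaseNeededInWrite/isTimeRebaseNeededInWrite checks, and a plain IllegalStateException stands in for DataSourceUtils.newRebaseExceptionInWrite:

// Hypothetical predicate; the real checks inspect cudf column contents for
// dates/timestamps old enough to fall before the Gregorian calendar cutover.
def needsRebase(col: AnyRef): Boolean = false

// Fail fast instead of silently writing values from the wrong calendar.
def assertNoRebaseNeeded(mode: String, cols: Seq[AnyRef]): Unit =
  if (mode == "EXCEPTION" && cols.exists(needsRebase)) {
    throw new IllegalStateException("column requires datetime rebase")
  }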
@@ -31,7 +31,6 @@ import scala.collection.mutable.ArrayBuffer
 import scala.language.implicitConversions
 
 import ai.rapids.cudf._
-import com.nvidia.spark.DateTimeRebaseHelper
 import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
 import com.nvidia.spark.rapids.GpuMetric._
 import com.nvidia.spark.rapids.ParquetPartitionReader.{CopyRange, LocalCopy}
@@ -163,11 +162,11 @@ object GpuParquetScan {
     (0 until table.getNumberOfColumns).foreach { i =>
       val col = table.getColumn(i)
       if (dateRebaseMode.equals("EXCEPTION") &&
-        DateTimeRebaseHelper.isDateRebaseNeededInRead(col)) {
+        DateTimeRebaseUtils.isDateRebaseNeededInRead(col)) {
         throw DataSourceUtils.newRebaseExceptionInRead("Parquet")
       }
       else if (timestampRebaseMode.equals("EXCEPTION") &&
-        DateTimeRebaseHelper.isTimeRebaseNeededInRead(col)) {
+        DateTimeRebaseUtils.isTimeRebaseNeededInRead(col)) {
         throw DataSourceUtils.newRebaseExceptionInRead("Parquet")
       }
     }
@@ -740,11 +739,11 @@ private case class GpuParquetFileFilterHandler(
       (clipped, clippedSchema)
     }
 
-    val dateRebaseModeForThisFile = DateTimeRebaseHelper.datetimeRebaseMode(
+    val dateRebaseModeForThisFile = DateTimeRebaseUtils.datetimeRebaseMode(
       footer.getFileMetaData.getKeyValueMetaData.get, datetimeRebaseMode)
     val hasInt96Timestamps = isParquetTimeInInt96(fileSchema)
     val timestampRebaseModeForThisFile = if (hasInt96Timestamps) {
-      DateTimeRebaseHelper.int96RebaseMode(
+      DateTimeRebaseUtils.int96RebaseMode(
        footer.getFileMetaData.getKeyValueMetaData.get, int96RebaseMode)
     } else {
       dateRebaseModeForThisFile
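GpuParquetFileFilterHandler resolves rebase modes per file: the date mode is always derived from the file's key-value metadata, while the timestamp mode only takes the separate INT96 path when the file actually stores INT96 timestamps. A simplified Scala sketch of that decision; resolveDateMode and resolveInt96Mode are hypothetical stand-ins for DateTimeRebaseUtils.datetimeRebaseMode and DateTimeRebaseUtils.int96RebaseMode, whose real signatures also take the Parquet footer's key-value metadata:

// Hypothetical stand-ins for the real metadata-aware helpers.
def resolveDateMode(configMode: String): String = configMode
def resolveInt96Mode(configMode: String): String = configMode

def rebaseModesForFile(
    hasInt96Timestamps: Boolean,
    dateConfigMode: String,
    int96ConfigMode: String): (String, String) = {
  val dateMode = resolveDateMode(dateConfigMode)
  // INT96 timestamps carry their own rebase config; other timestamp encodings
  // simply reuse the date decision, as in the handler above.
  val timestampMode =
    if (hasInt96Timestamps) resolveInt96Mode(int96ConfigMode) else dateMode
  (dateMode, timestampMode)
}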
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-package com.nvidia.spark
+package com.nvidia.spark.rapids
 
 import ai.rapids.cudf.{ColumnView, DType, Scalar}
 import com.nvidia.spark.rapids.Arm.withResource
@@ -43,7 +43,7 @@ case object DateTimeRebaseLegacy extends DateTimeRebaseMode
  */
 case object DateTimeRebaseCorrected extends DateTimeRebaseMode
 
-object DateTimeRebaseHelper {
+object DateTimeRebaseUtils {
   // Copied from Spark
   private val SPARK_VERSION_METADATA_KEY = "org.apache.spark.version"
   private val SPARK_LEGACY_DATETIME_METADATA_KEY = "org.apache.spark.legacyDateTime"
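The two constants copied from Spark are the keys under which a Parquet file's writer records the Spark version and a legacy-calendar flag. A hedged Scala sketch of how such metadata typically drives the mode choice; modeFromFileMeta and its exact rules are illustrative assumptions, not the repo's real logic:

import java.util.{Map => JMap}

object ModeDetectionSketch {
  private val SPARK_VERSION_METADATA_KEY = "org.apache.spark.version"
  private val SPARK_LEGACY_DATETIME_METADATA_KEY = "org.apache.spark.legacyDateTime"

  // Assumed rule of thumb: a file that flags legacy datetimes needs the LEGACY
  // rebase; a file that records a Spark version but no flag is CORRECTED; a
  // file with neither key was not written by Spark 3.x, so fall back to
  // whatever the session config asks for.
  def modeFromFileMeta(meta: JMap[String, String], configMode: String): String =
    if (meta.containsKey(SPARK_LEGACY_DATETIME_METADATA_KEY)) "LEGACY"
    else if (meta.containsKey(SPARK_VERSION_METADATA_KEY)) "CORRECTED"
    else configMode
}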
@@ -17,22 +17,21 @@
 package com.nvidia.spark.rapids
 
 import ai.rapids.cudf.ColumnVector
-import com.nvidia.spark.DateTimeRebaseHelper
 import com.nvidia.spark.rapids.Arm.withResource
 import org.scalatest.funsuite.AnyFunSuite
 
 class RebaseHelperSuite extends AnyFunSuite {
   test("all null timestamp days column rebase check") {
     withResource(ColumnVector.timestampDaysFromBoxedInts(null, null, null)) { c =>
-      assertResult(false)(DateTimeRebaseHelper.isDateRebaseNeededInWrite(c))
-      assertResult(false)(DateTimeRebaseHelper.isDateRebaseNeededInRead(c))
+      assertResult(false)(DateTimeRebaseUtils.isDateRebaseNeededInWrite(c))
+      assertResult(false)(DateTimeRebaseUtils.isDateRebaseNeededInRead(c))
     }
   }
 
   test("all null timestamp microseconds column rebase check") {
     withResource(ColumnVector.timestampMicroSecondsFromBoxedLongs(null, null, null)) { c =>
-      assertResult(false)(DateTimeRebaseHelper.isTimeRebaseNeededInWrite(c))
-      assertResult(false)(DateTimeRebaseHelper.isTimeRebaseNeededInRead(c))
+      assertResult(false)(DateTimeRebaseUtils.isTimeRebaseNeededInWrite(c))
+      assertResult(false)(DateTimeRebaseUtils.isTimeRebaseNeededInRead(c))
     }
   }
 }
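The suite leans on Arm.withResource to make sure each cudf ColumnVector is closed whether or not the assertions throw. A minimal Scala sketch of that pattern (the repo's Arm object offers more overloads than this):

// Run a block against an AutoCloseable and always close it afterwards, even
// when the block throws; the block's result is passed through to the caller.
def withResource[R <: AutoCloseable, T](resource: R)(block: R => T): T =
  try {
    block(resource)
  } finally {
    resource.close()
  }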
