Support rebase checking for nested dates and timestamps (NVIDIA#9617)
* Add check for nested types

* Recursively check for rebasing

* Extract common code

* Allow nested type in rebase check

* Enable nested timestamp in roundtrip test

* Fix another test

Signed-off-by: Nghia Truong <[email protected]>

* Update sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala

* Remove comment

---------

Signed-off-by: Nghia Truong <[email protected]>
ttnghia authored Nov 3, 2023
1 parent 614f873 commit 401d0d8
Showing 4 changed files with 54 additions and 102 deletions.
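In short, the change makes the rebase check recursive so that dates and timestamps buried inside LIST and STRUCT columns are inspected as well. A condensed sketch of the idea, simplified from the RebaseHelper.scala diff below (the helper name needsRebase is illustrative, not the plugin's exact code):

import ai.rapids.cudf.{ColumnView, DType, Scalar}
import com.nvidia.spark.rapids.Arm.withResource

// True if any value of `checkType` anywhere in the (possibly nested) column
// is earlier than `minGood` and therefore needs calendar rebasing.
def needsRebase(column: ColumnView, checkType: DType, minGood: Scalar): Boolean =
  column.getType match {
    case `checkType` =>
      // Leaf column of the interesting type: compare the values directly.
      withResource(column.lessThan(minGood)) { hasBad =>
        withResource(hasBad.any()) { anyBad =>
          anyBad.isValid && anyBad.getBoolean
        }
      }
    case DType.LIST | DType.STRUCT =>
      // Nested column: recurse into every child view.
      (0 until column.getNumChildren).exists { i =>
        withResource(column.getChildColumnView(i)) { child =>
          needsRebase(child, checkType, minGood)
        }
      }
    case _ => false // Other types never need rebasing.
  }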
41 changes: 10 additions & 31 deletions integration_tests/src/main/python/parquet_test.py
@@ -312,35 +312,14 @@ def test_parquet_pred_push_round_trip(spark_tmp_path, parquet_gen, read_func, v1

parquet_ts_write_options = ['INT96', 'TIMESTAMP_MICROS', 'TIMESTAMP_MILLIS']


# Once https://github.com/NVIDIA/spark-rapids/issues/1126 is fixed delete this test and merge it
# into test_ts_read_round_trip nested timestamps and dates are not supported right now.
@pytest.mark.parametrize('gen', [ArrayGen(TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))),
# Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with timestamp_gen
@pytest.mark.parametrize('gen', [TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc)),
ArrayGen(TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))),
ArrayGen(ArrayGen(TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))))], ids=idfn)
@pytest.mark.parametrize('ts_write', parquet_ts_write_options)
@pytest.mark.parametrize('ts_rebase', ['CORRECTED', 'LEGACY'])
@pytest.mark.parametrize('reader_confs', reader_opt_confs)
@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
@pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/1126')
def test_parquet_ts_read_round_trip_nested(gen, spark_tmp_path, ts_write, ts_rebase, v1_enabled_list, reader_confs):
data_path = spark_tmp_path + '/PARQUET_DATA'
with_cpu_session(
lambda spark : unary_op_df(spark, gen).write.parquet(data_path),
conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase,
'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase,
'spark.sql.parquet.outputTimestampType': ts_write})
all_confs = copy_and_update(reader_confs, {'spark.sql.sources.useV1SourceList': v1_enabled_list})
assert_gpu_and_cpu_are_equal_collect(
lambda spark : spark.read.parquet(data_path),
conf=all_confs)

# Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with
# timestamp_gen
@pytest.mark.parametrize('gen', [TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))], ids=idfn)
@pytest.mark.parametrize('ts_write', parquet_ts_write_options)
@pytest.mark.parametrize('ts_rebase', ['CORRECTED', 'LEGACY'])
@pytest.mark.parametrize('reader_confs', reader_opt_confs)
@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
def test_ts_read_round_trip(gen, spark_tmp_path, ts_write, ts_rebase, v1_enabled_list, reader_confs):
data_path = spark_tmp_path + '/PARQUET_DATA'
with_cpu_session(
@@ -358,10 +337,10 @@ def readParquetCatchException(spark, data_path):
df = spark.read.parquet(data_path).collect()
assert e_info.match(r".*SparkUpgradeException.*")

# Once https://github.com/NVIDIA/spark-rapids/issues/1126 is fixed nested timestamps and dates should be added in
# Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with
# timestamp_gen
@pytest.mark.parametrize('gen', [TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))], ids=idfn)
# Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with timestamp_gen
@pytest.mark.parametrize('gen', [TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc)),
ArrayGen(TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))),
ArrayGen(ArrayGen(TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))))], ids=idfn)
@pytest.mark.parametrize('ts_write', parquet_ts_write_options)
@pytest.mark.parametrize('ts_rebase', ['LEGACY'])
@pytest.mark.parametrize('reader_confs', reader_opt_confs)
@@ -1003,7 +982,7 @@ def test_parquet_reading_from_unaligned_pages_basic_filters_with_nulls(spark_tmp


conf_for_parquet_aggregate_pushdown = {
"spark.sql.parquet.aggregatePushdown": "true",
"spark.sql.parquet.aggregatePushdown": "true",
"spark.sql.sources.useV1SourceList": ""
}

@@ -1490,15 +1469,15 @@ def test_parquet_read_count(spark_tmp_path):
def test_read_case_col_name(spark_tmp_path, read_func, v1_enabled_list, reader_confs, col_name):
all_confs = copy_and_update(reader_confs, {
'spark.sql.sources.useV1SourceList': v1_enabled_list})
gen_list =[('k0', LongGen(nullable=False, min_val=0, max_val=0)),
gen_list =[('k0', LongGen(nullable=False, min_val=0, max_val=0)),
('k1', LongGen(nullable=False, min_val=1, max_val=1)),
('k2', LongGen(nullable=False, min_val=2, max_val=2)),
('k3', LongGen(nullable=False, min_val=3, max_val=3)),
('v0', LongGen()),
('v1', LongGen()),
('v2', LongGen()),
('v3', LongGen())]

gen = StructGen(gen_list, nullable=False)
data_path = spark_tmp_path + '/PAR_DATA'
reader = read_func(data_path)
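For reference, the rebase behavior these tests exercise is controlled by the Spark configs that appear in the test bodies above. A minimal Scala illustration of writing a nested timestamp column under those settings (the path and data are placeholders, not taken from the repo; assumes an existing SparkSession named spark):

import java.sql.Timestamp
import spark.implicits._

// Rebase/output settings mirrored from the round-trip tests above.
spark.conf.set("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "LEGACY")
spark.conf.set("spark.sql.legacy.parquet.int96RebaseModeInWrite", "LEGACY")
spark.conf.set("spark.sql.parquet.outputTimestampType", "INT96")

// An array<timestamp> column holding a pre-Gregorian value that needs rebasing.
val df = Seq(Seq(Timestamp.valueOf("1500-01-01 00:00:00"))).toDF("ts_array")
val path = "/tmp/parquet_rebase_demo" // placeholder path
df.write.mode("overwrite").parquet(path)

// With this commit the plugin checks the nested column recursively on read
// instead of rejecting nested dates/timestamps outright.
spark.read.parquet(path).show(truncate = false)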
66 changes: 31 additions & 35 deletions sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala
@@ -16,62 +16,58 @@

package com.nvidia.spark

import ai.rapids.cudf.{ColumnVector, DType, Scalar}
import ai.rapids.cudf.{ColumnView, DType, Scalar}
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.shims.SparkShimImpl

import org.apache.spark.sql.catalyst.util.RebaseDateTime
import org.apache.spark.sql.rapids.execution.TrampolineUtil

object RebaseHelper {
private[this] def isDateRebaseNeeded(column: ColumnVector,
startDay: Int): Boolean = {
// TODO update this for nested column checks
// https://github.com/NVIDIA/spark-rapids/issues/1126
private[this] def isRebaseNeeded(column: ColumnView, checkType: DType,
minGood: Scalar): Boolean = {
val dtype = column.getType
if (dtype == DType.TIMESTAMP_DAYS) {
val hasBad = withResource(Scalar.timestampDaysFromInt(startDay)) {
column.lessThan
}
val anyBad = withResource(hasBad) {
_.any()
}
withResource(anyBad) { _ =>
anyBad.isValid && anyBad.getBoolean
}
} else {
false
}
}
require(!dtype.hasTimeResolution || dtype == DType.TIMESTAMP_MICROSECONDS)

private[this] def isTimeRebaseNeeded(column: ColumnVector,
startTs: Long): Boolean = {
val dtype = column.getType
if (dtype.hasTimeResolution) {
require(dtype == DType.TIMESTAMP_MICROSECONDS)
withResource(
Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, startTs)) { minGood =>
dtype match {
case `checkType` =>
withResource(column.lessThan(minGood)) { hasBad =>
withResource(hasBad.any()) { a =>
a.isValid && a.getBoolean
withResource(hasBad.any()) { anyBad =>
anyBad.isValid && anyBad.getBoolean
}
}
}
} else {
false

case DType.LIST | DType.STRUCT => (0 until column.getNumChildren).exists(i =>
withResource(column.getChildColumnView(i)) { child =>
isRebaseNeeded(child, checkType, minGood)
})

case _ => false
}
}

private[this] def isDateRebaseNeeded(column: ColumnView, startDay: Int): Boolean = {
withResource(Scalar.timestampDaysFromInt(startDay)) { minGood =>
isRebaseNeeded(column, DType.TIMESTAMP_DAYS, minGood)
}
}

private[this] def isTimeRebaseNeeded(column: ColumnView, startTs: Long): Boolean = {
withResource(Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, startTs)) { minGood =>
isRebaseNeeded(column, DType.TIMESTAMP_MICROSECONDS, minGood)
}
}

def isDateRebaseNeededInRead(column: ColumnVector): Boolean =
def isDateRebaseNeededInRead(column: ColumnView): Boolean =
isDateRebaseNeeded(column, RebaseDateTime.lastSwitchJulianDay)

def isTimeRebaseNeededInRead(column: ColumnVector): Boolean =
def isTimeRebaseNeededInRead(column: ColumnView): Boolean =
isTimeRebaseNeeded(column, RebaseDateTime.lastSwitchJulianTs)

def isDateRebaseNeededInWrite(column: ColumnVector): Boolean =
def isDateRebaseNeededInWrite(column: ColumnView): Boolean =
isDateRebaseNeeded(column, RebaseDateTime.lastSwitchGregorianDay)

def isTimeRebaseNeededInWrite(column: ColumnVector): Boolean =
def isTimeRebaseNeededInWrite(column: ColumnView): Boolean =
isTimeRebaseNeeded(column, RebaseDateTime.lastSwitchGregorianTs)

def newRebaseExceptionInRead(format: String): Exception = {
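With the signatures widened from ColumnVector to ColumnView, the same public helpers apply equally to top-level columns and to nested child views. A minimal usage sketch (anyTimeRebaseNeeded is a hypothetical caller mirroring how the scan and writer code below use the helpers):

import ai.rapids.cudf.Table
import com.nvidia.spark.RebaseHelper

// True if any column of the table holds timestamps that need rebasing on read;
// nested children are handled inside RebaseHelper itself.
def anyTimeRebaseNeeded(table: Table): Boolean =
  (0 until table.getNumberOfColumns).exists { i =>
    RebaseHelper.isTimeRebaseNeededInRead(table.getColumn(i))
  }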
@@ -310,11 +310,9 @@ class GpuParquetWriter(
override def throwIfRebaseNeededInExceptionMode(batch: ColumnarBatch): Unit = {
val cols = GpuColumnVector.extractBases(batch)
cols.foreach { col =>
// if col is a day
if (dateRebaseMode.equals("EXCEPTION") && RebaseHelper.isDateRebaseNeededInWrite(col)) {
throw DataSourceUtils.newRebaseExceptionInWrite("Parquet")
}
// if col is a time
else if (timestampRebaseMode.equals("EXCEPTION") &&
RebaseHelper.isTimeRebaseNeededInWrite(col)) {
throw DataSourceUtils.newRebaseExceptionInWrite("Parquet")
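When the EXCEPTION-mode check above fires during a write, the usual remedy is to choose an explicit rebase mode; for example (config keys taken from the tests above, the values are up to the user):

spark.conf.set("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "CORRECTED")
spark.conf.set("spark.sql.legacy.parquet.int96RebaseModeInWrite", "CORRECTED")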
@@ -165,11 +165,9 @@ object GpuParquetScan {
hasInt96Timestamps: Boolean): Unit = {
(0 until table.getNumberOfColumns).foreach { i =>
val col = table.getColumn(i)
// if col is a day
if (!isCorrectedDateTimeRebase && RebaseHelper.isDateRebaseNeededInRead(col)) {
throw DataSourceUtils.newRebaseExceptionInRead("Parquet")
}
// if col is a time
else if (hasInt96Timestamps && !isCorrectedInt96Rebase ||
!hasInt96Timestamps && !isCorrectedDateTimeRebase) {
if (RebaseHelper.isTimeRebaseNeededInRead(col)) {
@@ -201,21 +199,6 @@

FileFormatChecks.tag(meta, readSchema, ParquetFormatType, ReadFileOp)

val schemaHasTimestamps = readSchema.exists { field =>
TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[TimestampType])
}
def isTsOrDate(dt: DataType) : Boolean = dt match {
case TimestampType | DateType => true
case _ => false
}
val schemaMightNeedNestedRebase = readSchema.exists { field =>
if (DataTypeUtils.isNestedType(field.dataType)) {
TrampolineUtil.dataTypeExistsRecursively(field.dataType, isTsOrDate)
} else {
false
}
}

// Currently timestamp conversion is not supported.
// If support needs to be added then we need to follow the logic in Spark's
// ParquetPartitionReaderFactory and VectorizedColumnReader which essentially
@@ -225,35 +208,32 @@
// were written in that timezone and convert them to UTC timestamps.
// Essentially this should boil down to a vector subtract of the scalar delta
// between the configured timezone's delta from UTC on the timestamp data.
val schemaHasTimestamps = readSchema.exists { field =>
TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[TimestampType])
}
if (schemaHasTimestamps && sparkSession.sessionState.conf.isParquetINT96TimestampConversion) {
meta.willNotWorkOnGpu("GpuParquetScan does not support int96 timestamp conversion")
}

val schemaHasDates = readSchema.exists { field =>
TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[DateType])
}

sqlConf.get(SparkShimImpl.int96ParquetRebaseReadKey) match {
case "EXCEPTION" => if (schemaMightNeedNestedRebase) {
meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
s"${SparkShimImpl.int96ParquetRebaseReadKey} is EXCEPTION")
}
case "CORRECTED" => // Good
case "EXCEPTION" | "CORRECTED" => // Good
case "LEGACY" => // really is EXCEPTION for us...
if (schemaMightNeedNestedRebase) {
meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
s"${SparkShimImpl.int96ParquetRebaseReadKey} is LEGACY")
if (schemaHasTimestamps) {
meta.willNotWorkOnGpu("LEGACY rebase mode for dates and timestamps is not supported")
}
case other =>
meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode")
}

sqlConf.get(SparkShimImpl.parquetRebaseReadKey) match {
case "EXCEPTION" => if (schemaMightNeedNestedRebase) {
meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
s"${SparkShimImpl.parquetRebaseReadKey} is EXCEPTION")
}
case "CORRECTED" => // Good
case "EXCEPTION" | "CORRECTED" => // Good
case "LEGACY" => // really is EXCEPTION for us...
if (schemaMightNeedNestedRebase) {
meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
s"${SparkShimImpl.parquetRebaseReadKey} is LEGACY")
if (schemaHasDates || schemaHasTimestamps) {
meta.willNotWorkOnGpu("LEGACY rebase mode for dates and timestamps is not supported")
}
case other =>
meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode")
@@ -2918,4 +2898,3 @@ object ParquetPartitionReader {
block
}
}
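Taken together, the GpuParquetScan changes replace the old "nested date/timestamp might need rebasing" fallback with a simpler policy: EXCEPTION and CORRECTED read modes stay on the GPU, with per-column checks done at read time, while LEGACY falls back to the CPU whenever the schema contains dates or timestamps. A rough paraphrase of that decision, merging the INT96 and non-INT96 branches for brevity (tagRebaseSupport is illustrative, not the literal plugin code):

// Returns the reason the scan cannot run on the GPU, or None if it can.
def tagRebaseSupport(rebaseMode: String,
    schemaHasDates: Boolean,
    schemaHasTimestamps: Boolean): Option[String] =
  rebaseMode match {
    case "EXCEPTION" | "CORRECTED" =>
      None // Supported; EXCEPTION mode may still throw SparkUpgradeException per column at read time.
    case "LEGACY" if schemaHasDates || schemaHasTimestamps =>
      Some("LEGACY rebase mode for dates and timestamps is not supported")
    case "LEGACY" =>
      None // No date or timestamp columns, so nothing would ever need rebasing.
    case other =>
      Some(s"$other is not a supported read rebase mode")
  }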
