diff --git a/integration_tests/src/main/python/parquet_test.py b/integration_tests/src/main/python/parquet_test.py
index b3e04b91d93..51b327dfd70 100644
--- a/integration_tests/src/main/python/parquet_test.py
+++ b/integration_tests/src/main/python/parquet_test.py
@@ -312,35 +312,14 @@ def test_parquet_pred_push_round_trip(spark_tmp_path, parquet_gen, read_func, v1
 
 parquet_ts_write_options = ['INT96', 'TIMESTAMP_MICROS', 'TIMESTAMP_MILLIS']
 
-
-# Once https://github.com/NVIDIA/spark-rapids/issues/1126 is fixed delete this test and merge it
-# into test_ts_read_round_trip nested timestamps and dates are not supported right now.
-@pytest.mark.parametrize('gen', [ArrayGen(TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))),
+# Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with timestamp_gen
+@pytest.mark.parametrize('gen', [TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc)),
+                                 ArrayGen(TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))),
                                  ArrayGen(ArrayGen(TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))))], ids=idfn)
 @pytest.mark.parametrize('ts_write', parquet_ts_write_options)
 @pytest.mark.parametrize('ts_rebase', ['CORRECTED', 'LEGACY'])
 @pytest.mark.parametrize('reader_confs', reader_opt_confs)
 @pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
-@pytest.mark.xfail(reason='https://github.com/NVIDIA/spark-rapids/issues/1126')
-def test_parquet_ts_read_round_trip_nested(gen, spark_tmp_path, ts_write, ts_rebase, v1_enabled_list, reader_confs):
-    data_path = spark_tmp_path + '/PARQUET_DATA'
-    with_cpu_session(
-            lambda spark : unary_op_df(spark, gen).write.parquet(data_path),
-            conf={'spark.sql.legacy.parquet.datetimeRebaseModeInWrite': ts_rebase,
-                  'spark.sql.legacy.parquet.int96RebaseModeInWrite': ts_rebase,
-                  'spark.sql.parquet.outputTimestampType': ts_write})
-    all_confs = copy_and_update(reader_confs, {'spark.sql.sources.useV1SourceList': v1_enabled_list})
-    assert_gpu_and_cpu_are_equal_collect(
-            lambda spark : spark.read.parquet(data_path),
-            conf=all_confs)
-
-# Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with
-# timestamp_gen
-@pytest.mark.parametrize('gen', [TimestampGen(start=datetime(1900, 1, 1, tzinfo=timezone.utc))], ids=idfn)
-@pytest.mark.parametrize('ts_write', parquet_ts_write_options)
-@pytest.mark.parametrize('ts_rebase', ['CORRECTED', 'LEGACY'])
-@pytest.mark.parametrize('reader_confs', reader_opt_confs)
-@pytest.mark.parametrize('v1_enabled_list', ["", "parquet"])
 def test_ts_read_round_trip(gen, spark_tmp_path, ts_write, ts_rebase, v1_enabled_list, reader_confs):
     data_path = spark_tmp_path + '/PARQUET_DATA'
     with_cpu_session(
@@ -358,10 +337,10 @@ def readParquetCatchException(spark, data_path):
         df = spark.read.parquet(data_path).collect()
     assert e_info.match(r".*SparkUpgradeException.*")
 
-# Once https://github.com/NVIDIA/spark-rapids/issues/1126 is fixed nested timestamps and dates should be added in
-# Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with
-# timestamp_gen
-@pytest.mark.parametrize('gen', [TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))], ids=idfn)
+# Once https://github.com/NVIDIA/spark-rapids/issues/132 is fixed replace this with timestamp_gen
+@pytest.mark.parametrize('gen', [TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc)),
+                                 ArrayGen(TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))),
+                                 ArrayGen(ArrayGen(TimestampGen(start=datetime(1590, 1, 1, tzinfo=timezone.utc))))], ids=idfn)
 @pytest.mark.parametrize('ts_write', parquet_ts_write_options)
 @pytest.mark.parametrize('ts_rebase', ['LEGACY'])
 @pytest.mark.parametrize('reader_confs', reader_opt_confs)
@@ -1003,7 +982,7 @@ def test_parquet_reading_from_unaligned_pages_basic_filters_with_nulls(spark_tmp
 
 
 conf_for_parquet_aggregate_pushdown = {
-    "spark.sql.parquet.aggregatePushdown": "true", 
+    "spark.sql.parquet.aggregatePushdown": "true",
     "spark.sql.sources.useV1SourceList": ""
 }
 
@@ -1490,7 +1469,7 @@ def test_parquet_read_count(spark_tmp_path):
 def test_read_case_col_name(spark_tmp_path, read_func, v1_enabled_list, reader_confs, col_name):
     all_confs = copy_and_update(reader_confs, {
         'spark.sql.sources.useV1SourceList': v1_enabled_list})
-    gen_list =[('k0', LongGen(nullable=False, min_val=0, max_val=0)), 
+    gen_list =[('k0', LongGen(nullable=False, min_val=0, max_val=0)),
               ('k1', LongGen(nullable=False, min_val=1, max_val=1)),
               ('k2', LongGen(nullable=False, min_val=2, max_val=2)),
               ('k3', LongGen(nullable=False, min_val=3, max_val=3)),
@@ -1498,7 +1477,7 @@ def test_read_case_col_name(spark_tmp_path, read_func, v1_enabled_list, reader_c
               ('v1', LongGen()),
               ('v2', LongGen()),
               ('v3', LongGen())]
-    
+
     gen = StructGen(gen_list, nullable=False)
     data_path = spark_tmp_path + '/PAR_DATA'
     reader = read_func(data_path)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala b/sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala
index 81fb4c79ddc..e928be79f7b 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala
@@ -16,7 +16,7 @@
 
 package com.nvidia.spark
 
-import ai.rapids.cudf.{ColumnVector, DType, Scalar}
+import ai.rapids.cudf.{ColumnView, DType, Scalar}
 import com.nvidia.spark.rapids.Arm.withResource
 import com.nvidia.spark.rapids.shims.SparkShimImpl
 
@@ -24,54 +24,50 @@ import org.apache.spark.sql.catalyst.util.RebaseDateTime
 import org.apache.spark.sql.rapids.execution.TrampolineUtil
 
 object RebaseHelper {
-  private[this] def isDateRebaseNeeded(column: ColumnVector,
-      startDay: Int): Boolean = {
-    // TODO update this for nested column checks
-    // https://github.com/NVIDIA/spark-rapids/issues/1126
+  private[this] def isRebaseNeeded(column: ColumnView, checkType: DType,
+      minGood: Scalar): Boolean = {
     val dtype = column.getType
-    if (dtype == DType.TIMESTAMP_DAYS) {
-      val hasBad = withResource(Scalar.timestampDaysFromInt(startDay)) {
-        column.lessThan
-      }
-      val anyBad = withResource(hasBad) {
-        _.any()
-      }
-      withResource(anyBad) { _ =>
-        anyBad.isValid && anyBad.getBoolean
-      }
-    } else {
-      false
-    }
-  }
+    require(!dtype.hasTimeResolution || dtype == DType.TIMESTAMP_MICROSECONDS)
 
-  private[this] def isTimeRebaseNeeded(column: ColumnVector,
-      startTs: Long): Boolean = {
-    val dtype = column.getType
-    if (dtype.hasTimeResolution) {
-      require(dtype == DType.TIMESTAMP_MICROSECONDS)
-      withResource(
-        Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, startTs)) { minGood =>
+    dtype match {
+      case `checkType` =>
         withResource(column.lessThan(minGood)) { hasBad =>
-          withResource(hasBad.any()) { a =>
-            a.isValid && a.getBoolean
+          withResource(hasBad.any()) { anyBad =>
+            anyBad.isValid && anyBad.getBoolean
           }
         }
-      }
-    } else {
-      false
+
+      case DType.LIST | DType.STRUCT => (0 until column.getNumChildren).exists(i =>
+        withResource(column.getChildColumnView(i)) { child =>
+          isRebaseNeeded(child, checkType, minGood)
+        })
+
+      case _ => false
+    }
+  }
+
+  private[this] def isDateRebaseNeeded(column: ColumnView, startDay: Int): Boolean = {
+    withResource(Scalar.timestampDaysFromInt(startDay)) { minGood =>
+      isRebaseNeeded(column, DType.TIMESTAMP_DAYS, minGood)
+    }
+  }
+
+  private[this] def isTimeRebaseNeeded(column: ColumnView, startTs: Long): Boolean = {
+    withResource(Scalar.timestampFromLong(DType.TIMESTAMP_MICROSECONDS, startTs)) { minGood =>
+      isRebaseNeeded(column, DType.TIMESTAMP_MICROSECONDS, minGood)
     }
   }
 
-  def isDateRebaseNeededInRead(column: ColumnVector): Boolean =
+  def isDateRebaseNeededInRead(column: ColumnView): Boolean =
     isDateRebaseNeeded(column, RebaseDateTime.lastSwitchJulianDay)
 
-  def isTimeRebaseNeededInRead(column: ColumnVector): Boolean =
+  def isTimeRebaseNeededInRead(column: ColumnView): Boolean =
     isTimeRebaseNeeded(column, RebaseDateTime.lastSwitchJulianTs)
 
-  def isDateRebaseNeededInWrite(column: ColumnVector): Boolean =
+  def isDateRebaseNeededInWrite(column: ColumnView): Boolean =
     isDateRebaseNeeded(column, RebaseDateTime.lastSwitchGregorianDay)
 
-  def isTimeRebaseNeededInWrite(column: ColumnVector): Boolean =
+  def isTimeRebaseNeededInWrite(column: ColumnView): Boolean =
     isTimeRebaseNeeded(column, RebaseDateTime.lastSwitchGregorianTs)
 
   def newRebaseExceptionInRead(format: String): Exception = {
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala
index cff38e3e3ec..585012d2a2e 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala
@@ -310,11 +310,9 @@ class GpuParquetWriter(
   override def throwIfRebaseNeededInExceptionMode(batch: ColumnarBatch): Unit = {
     val cols = GpuColumnVector.extractBases(batch)
     cols.foreach { col =>
-      // if col is a day
       if (dateRebaseMode.equals("EXCEPTION") && RebaseHelper.isDateRebaseNeededInWrite(col)) {
         throw DataSourceUtils.newRebaseExceptionInWrite("Parquet")
       }
-      // if col is a time
       else if (timestampRebaseMode.equals("EXCEPTION") &&
                RebaseHelper.isTimeRebaseNeededInWrite(col)) {
         throw DataSourceUtils.newRebaseExceptionInWrite("Parquet")
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
index 34f6777c26f..6d00c2b6393 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScan.scala
@@ -165,11 +165,9 @@ object GpuParquetScan {
       hasInt96Timestamps: Boolean): Unit = {
     (0 until table.getNumberOfColumns).foreach { i =>
       val col = table.getColumn(i)
-      // if col is a day
       if (!isCorrectedDateTimeRebase && RebaseHelper.isDateRebaseNeededInRead(col)) {
         throw DataSourceUtils.newRebaseExceptionInRead("Parquet")
       }
-      // if col is a time
       else if (hasInt96Timestamps && !isCorrectedInt96Rebase ||
           !hasInt96Timestamps && !isCorrectedDateTimeRebase) {
         if (RebaseHelper.isTimeRebaseNeededInRead(col)) {
@@ -201,21 +199,6 @@ object GpuParquetScan {
 
     FileFormatChecks.tag(meta, readSchema, ParquetFormatType, ReadFileOp)
 
-    val schemaHasTimestamps = readSchema.exists { field =>
-      TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[TimestampType])
-    }
-    def isTsOrDate(dt: DataType) : Boolean = dt match {
-      case TimestampType | DateType => true
-      case _ => false
-    }
-    val schemaMightNeedNestedRebase = readSchema.exists { field =>
-      if (DataTypeUtils.isNestedType(field.dataType)) {
-        TrampolineUtil.dataTypeExistsRecursively(field.dataType, isTsOrDate)
-      } else {
-        false
-      }
-    }
-
     // Currently timestamp conversion is not supported.
     // If support needs to be added then we need to follow the logic in Spark's
     // ParquetPartitionReaderFactory and VectorizedColumnReader which essentially
@@ -225,35 +208,32 @@ object GpuParquetScan {
     // were written in that timezone and convert them to UTC timestamps.
     // Essentially this should boil down to a vector subtract of the scalar delta
     // between the configured timezone's delta from UTC on the timestamp data.
+    val schemaHasTimestamps = readSchema.exists { field =>
+      TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[TimestampType])
+    }
     if (schemaHasTimestamps && sparkSession.sessionState.conf.isParquetINT96TimestampConversion) {
       meta.willNotWorkOnGpu("GpuParquetScan does not support int96 timestamp conversion")
     }
 
+    val schemaHasDates = readSchema.exists { field =>
+      TrampolineUtil.dataTypeExistsRecursively(field.dataType, _.isInstanceOf[DateType])
+    }
+
     sqlConf.get(SparkShimImpl.int96ParquetRebaseReadKey) match {
-      case "EXCEPTION" => if (schemaMightNeedNestedRebase) {
-        meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
-          s"${SparkShimImpl.int96ParquetRebaseReadKey} is EXCEPTION")
-      }
-      case "CORRECTED" => // Good
+      case "EXCEPTION" | "CORRECTED" => // Good
       case "LEGACY" => // really is EXCEPTION for us...
-        if (schemaMightNeedNestedRebase) {
-          meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
-            s"${SparkShimImpl.int96ParquetRebaseReadKey} is LEGACY")
+        if (schemaHasTimestamps) {
+          meta.willNotWorkOnGpu("LEGACY rebase mode for dates and timestamps is not supported")
         }
       case other =>
         meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode")
     }
 
     sqlConf.get(SparkShimImpl.parquetRebaseReadKey) match {
-      case "EXCEPTION" => if (schemaMightNeedNestedRebase) {
-        meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
-          s"${SparkShimImpl.parquetRebaseReadKey} is EXCEPTION")
-      }
-      case "CORRECTED" => // Good
+      case "EXCEPTION" | "CORRECTED" => // Good
      case "LEGACY" => // really is EXCEPTION for us...
-        if (schemaMightNeedNestedRebase) {
-          meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
-            s"${SparkShimImpl.parquetRebaseReadKey} is LEGACY")
+        if (schemaHasDates || schemaHasTimestamps) {
+          meta.willNotWorkOnGpu("LEGACY rebase mode for dates and timestamps is not supported")
         }
       case other =>
         meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode")
@@ -2918,4 +2898,3 @@ object ParquetPartitionReader {
     block
   }
 }
-