From 2805b95619d2c5bd687a342677da53f48dd89487 Mon Sep 17 00:00:00 2001
From: Chong Gao
Date: Thu, 7 Dec 2023 18:39:15 +0800
Subject: [PATCH] Support from_unixtime via Gpu for non-UTC time zone (#9814)

Signed-off-by: Chong Gao
---
 integration_tests/src/main/python/conftest.py | 18 +++++++++++++++
 .../src/main/python/date_time_test.py         | 23 ++++++++++++++++++-
 .../nvidia/spark/rapids/GpuOverrides.scala    |  3 ++-
 .../com/nvidia/spark/rapids/RapidsMeta.scala  |  5 ++--
 .../sql/rapids/datetimeExpressions.scala      | 14 +++++++++--
 5 files changed, 57 insertions(+), 6 deletions(-)

diff --git a/integration_tests/src/main/python/conftest.py b/integration_tests/src/main/python/conftest.py
index a9b2f6146ec..210034a5ded 100644
--- a/integration_tests/src/main/python/conftest.py
+++ b/integration_tests/src/main/python/conftest.py
@@ -86,6 +86,24 @@ def is_utc():
 def is_not_utc():
     return not is_utc()
 
+# key is a time zone id, value is the cached boolean support result
+_support_info_cache_for_time_zone = {}
+
+def is_supported_time_zone():
+    """
+    Check whether the current test TZ is supported; forwards to the Java GpuTimeZoneDB.
+    """
+    tz = get_test_tz()
+    if tz in _support_info_cache_for_time_zone:
+        # already cached
+        return _support_info_cache_for_time_zone[tz]
+    else:
+        jvm = spark_jvm()
+        support = jvm.com.nvidia.spark.rapids.jni.GpuTimeZoneDB.isSupportedTimeZone(tz)
+        # cache the support info to avoid repeated py4j round trips
+        _support_info_cache_for_time_zone[tz] = support
+        return support
+
 _is_nightly_run = False
 _is_precommit_run = False
 
diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py
index 4c08d7d3935..8af0b94b77f 100644
--- a/integration_tests/src/main/python/date_time_test.py
+++ b/integration_tests/src/main/python/date_time_test.py
@@ -14,7 +14,7 @@
 import pytest
 
 from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_gpu_and_cpu_error
-from conftest import is_not_utc
+from conftest import is_supported_time_zone
 from data_gen import *
 from datetime import date, datetime, timezone
 from marks import ignore_order, incompat, allow_non_gpu
@@ -474,6 +474,27 @@ def test_date_format(data_gen, date_format):
     assert_gpu_and_cpu_are_equal_collect(
         lambda spark : unary_op_df(spark, data_gen).selectExpr("date_format(a, '{}')".format(date_format)))
 
+@pytest.mark.parametrize('date_format', supported_date_formats, ids=idfn)
+# from 0001-02-01 to 9999-12-30 to avoid 'year 0 is out of range'
+@pytest.mark.parametrize('data_gen', [LongGen(min_val=int(datetime(1, 2, 1).timestamp()), max_val=int(datetime(9999, 12, 30).timestamp()))], ids=idfn)
+@pytest.mark.skipif(not is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported")
+def test_from_unixtime(data_gen, date_format):
+    conf = {'spark.rapids.sql.nonUTC.enabled': True}
+    assert_gpu_and_cpu_are_equal_collect(
+        lambda spark : unary_op_df(spark, data_gen, length=5).selectExpr("from_unixtime(a, '{}')".format(date_format)),
+        conf)
+
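The new `is_supported_time_zone()` helper forwards the check to the JVM over py4j and memoizes the result per time zone. For reference, a minimal Scala sketch of the same check done directly against `GpuTimeZoneDB` (the `Asia/Shanghai` zone is only illustrative; the two call shapes are the ones this patch itself uses — a raw TZ string from conftest.py and a resolved `ZoneId` from RapidsMeta.scala):

```scala
import java.time.ZoneId
import com.nvidia.spark.rapids.jni.GpuTimeZoneDB

object TimeZoneSupportCheck {
  def main(args: Array[String]): Unit = {
    val tz = "Asia/Shanghai" // illustrative; the tests read this from the TZ env var
    // String form: what conftest.py calls through py4j
    println(s"supported by name: ${GpuTimeZoneDB.isSupportedTimeZone(tz)}")
    // ZoneId form: what BaseExprMeta calls with getZoneId()
    println(s"supported by id:   ${GpuTimeZoneDB.isSupportedTimeZone(ZoneId.of(tz))}")
  }
}
```

Caching on the Python side matters because every parametrized test case re-evaluates the skipif condition, and each uncached call is a py4j round trip.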
+@allow_non_gpu('ProjectExec')
+@pytest.mark.parametrize('date_format', supported_date_formats, ids=idfn)
+# from 0001-02-01 to 9999-12-30 to avoid 'year 0 is out of range'
+@pytest.mark.parametrize('data_gen', [LongGen(min_val=int(datetime(1, 2, 1).timestamp()), max_val=int(datetime(9999, 12, 30).timestamp()))], ids=idfn)
+@pytest.mark.skipif(is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported")
+def test_from_unixtime_fall_back(data_gen, date_format):
+    conf = {'spark.rapids.sql.nonUTC.enabled': True}
+    assert_gpu_fallback_collect(lambda spark : unary_op_df(spark, data_gen, length=5).selectExpr("from_unixtime(a, '{}')".format(date_format)),
+        'ProjectExec',
+        conf)
+
 unsupported_date_formats = ['F']
 @pytest.mark.parametrize('date_format', unsupported_date_formats, ids=idfn)
 @pytest.mark.parametrize('data_gen', date_n_time_gens, ids=idfn)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 2a77fdbc06c..74430ae8e90 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -1780,9 +1780,10 @@ object GpuOverrides extends Logging {
             .withPsNote(TypeEnum.STRING, "Only a limited number of formats are supported"),
           TypeSig.STRING)),
       (a, conf, p, r) => new UnixTimeExprMeta[FromUnixTime](a, conf, p, r) {
+        override def isTimeZoneSupported = true
         override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
           // passing the already converted strf string for a little optimization
-          GpuFromUnixTime(lhs, rhs, strfFormat)
+          GpuFromUnixTime(lhs, rhs, strfFormat, a.timeZoneId)
       }),
     expr[FromUTCTimestamp](
       "Render the input UTC timestamp in the input timezone",
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
index e5cf75e9a13..369209cc8fc 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
@@ -20,6 +20,7 @@ import java.time.ZoneId
 
 import scala.collection.mutable
 
+import com.nvidia.spark.rapids.jni.GpuTimeZoneDB
 import com.nvidia.spark.rapids.shims.{DistributionUtil, SparkShimImpl}
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, Cast, ComplexTypeMergingExpression, Expression, QuaternaryExpression, String2TrimExpression, TernaryExpression, TimeZoneAwareExpression, UnaryExpression, UTCTimestamp, WindowExpression, WindowFunction}
@@ -1130,7 +1131,7 @@ abstract class BaseExprMeta[INPUT <: Expression](
     if (!isTimeZoneSupported) return checkUTCTimezone(this)
 
     // Level 4 check
-    if (!TimeZoneDB.isSupportedTimezone(getZoneId())) {
+    if (!GpuTimeZoneDB.isSupportedTimeZone(getZoneId())) {
       willNotWorkOnGpu(TimeZoneDB.timezoneNotSupportedStr(this.wrapped.getClass.toString))
     }
   }
@@ -1201,7 +1202,7 @@ abstract class BaseExprMeta[INPUT <: Expression](
 
   // Level 3 timezone checking flag, need to override to true when supports timezone in functions
   // Useless if it's not timezone related expression defined in [[needTimeZoneCheck]]
-  val isTimeZoneSupported: Boolean = false
+  def isTimeZoneSupported: Boolean = false
 
   /**
   * Timezone check which only allows UTC timezone. This is consistent with previous behavior.
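Changing `isTimeZoneSupported` from a `val` to a `def` is what lets each expression meta, such as the `FromUnixTime` override above, opt in. A simplified sketch of the check ladder this feeds into, assuming the surrounding `BaseExprMeta` members shown in the hunks above (the wrapping method name here is illustrative):

```scala
// Level 3: the expression meta must declare timezone support by overriding
// isTimeZoneSupported; otherwise only UTC is allowed (the previous behavior).
// Level 4: the concrete zone must be present in the spark-rapids-jni
// transition database, checked via GpuTimeZoneDB.
def checkTimeZoneSupport(): Unit = {
  if (!isTimeZoneSupported) {
    checkUTCTimezone(this) // non-UTC zones fall back to the CPU
  } else if (!GpuTimeZoneDB.isSupportedTimeZone(getZoneId())) {
    willNotWorkOnGpu(TimeZoneDB.timezoneNotSupportedStr(this.wrapped.getClass.toString))
  }
}
```

Note the removed line tested `TimeZoneDB.isSupportedTimezone` with a negation; dropping the `!` would have flagged exactly the supported zones, so the replacement keeps the negation while switching to the JNI-backed `GpuTimeZoneDB`.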
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
index 8f6c591787f..d85fe582eff 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
@@ -1009,8 +1009,18 @@ case class GpuFromUnixTime(
   override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = {
     // we aren't using rhs as it was already converted in the GpuOverrides while creating the
     // expressions map and passed down here as strfFormat
-    withResource(lhs.getBase.asTimestampSeconds) { tsVector =>
-      tsVector.asStrings(strfFormat)
+    withResource(lhs.getBase.asTimestampSeconds) { secondsVector =>
+      withResource(secondsVector.asTimestampMicroseconds) { tsVector =>
+        if (GpuOverrides.isUTCTimezone(zoneId)) {
+          // UTC time zone: format directly
+          tsVector.asStrings(strfFormat)
+        } else {
+          // Non-UTC time zone: shift the UTC instants into the target zone before formatting
+          withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(tsVector, zoneId.normalized())) {
+            shifted => shifted.asStrings(strfFormat)
+          }
+        }
+      }
     }
   }
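End to end, the effect of the patch is that a `from_unixtime` projection under a supported non-UTC session time zone stays on the GPU instead of falling back to CPU `ProjectExec`. A usage sketch (the `spark.rapids.sql.nonUTC.enabled` key is taken from the tests above; the zone choice is illustrative, and the session is assumed to run with the RAPIDS Accelerator enabled):

```scala
// Run in spark-shell with the RAPIDS Accelerator plugin on the classpath.
spark.conf.set("spark.sql.session.timeZone", "Asia/Shanghai")
spark.conf.set("spark.rapids.sql.nonUTC.enabled", "true")

// Epoch second 0 is 1970-01-01 00:00:00 UTC, i.e. 08:00:00 in Asia/Shanghai (UTC+8).
spark.sql("SELECT from_unixtime(0, 'yyyy-MM-dd HH:mm:ss') AS ts").show()
```

The micro-seconds cast before the zone shift is what lets `GpuTimeZoneDB.fromUtcTimestampToTimestamp` operate at the plugin's standard timestamp resolution; both branches then share the same `asStrings(strfFormat)` formatting path.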