Skip to content

Commit

Permalink
Support from_unixtime via Gpu for non-UTC time zone (#9814)
Browse files Browse the repository at this point in the history
Signed-off-by: Chong Gao <[email protected]>
  • Loading branch information
res-life authored Dec 7, 2023
1 parent 6affa02 commit 2805b95
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 6 deletions.
18 changes: 18 additions & 0 deletions integration_tests/src/main/python/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,24 @@ def is_utc():
def is_not_utc():
return not is_utc()

# maps a time zone to the cached boolean that records whether it is supported
_support_info_cache_for_time_zone = {}

def is_supported_time_zone():
    """
    Tell whether the time zone used by the current test run is supported.

    The check is forwarded to the Java GpuTimeZoneDB. The answer is stored in
    _support_info_cache_for_time_zone, keyed by time zone, so the JVM is asked
    only the first time a given time zone is seen.
    """
    tz = get_test_tz()
    if tz not in _support_info_cache_for_time_zone:
        # not cached yet: ask the JVM side and remember the answer
        time_zone_db = spark_jvm().com.nvidia.spark.rapids.jni.GpuTimeZoneDB
        _support_info_cache_for_time_zone[tz] = time_zone_db.isSupportedTimeZone(tz)
    return _support_info_cache_for_time_zone[tz]

_is_nightly_run = False
_is_precommit_run = False

Expand Down
23 changes: 22 additions & 1 deletion integration_tests/src/main/python/date_time_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import pytest
from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_fallback_collect, assert_gpu_and_cpu_error
from conftest import is_not_utc
from conftest import is_supported_time_zone
from data_gen import *
from datetime import date, datetime, timezone
from marks import ignore_order, incompat, allow_non_gpu
Expand Down Expand Up @@ -474,6 +474,27 @@ def test_date_format(data_gen, date_format):
assert_gpu_and_cpu_are_equal_collect(
lambda spark : unary_op_df(spark, data_gen).selectExpr("date_format(a, '{}')".format(date_format)))

@pytest.mark.parametrize('date_format', supported_date_formats, ids=idfn)
# generated seconds stay within 0001-02-01 .. 9999-12-30 to avoid 'year 0 is out of range'
@pytest.mark.parametrize(
    'data_gen',
    [LongGen(min_val=int(datetime(1, 2, 1).timestamp()),
             max_val=int(datetime(9999, 12, 30).timestamp()))],
    ids=idfn)
@pytest.mark.skipif(not is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported")
def test_from_unixtime(data_gen, date_format):
    """
    CPU and GPU must produce the same from_unixtime output for each supported
    date format when 'spark.rapids.sql.nonUTC.enabled' is turned on.
    Skipped when is_supported_time_zone() reports False.
    """
    non_utc_conf = {'spark.rapids.sql.nonUTC.enabled': True}
    sql_expr = "from_unixtime(a, '{}')".format(date_format)
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: unary_op_df(spark, data_gen, length=5).selectExpr(sql_expr),
        non_utc_conf)

@allow_non_gpu('ProjectExec')
@pytest.mark.parametrize('date_format', supported_date_formats, ids=idfn)
# generated seconds stay within 0001-02-01 .. 9999-12-30 to avoid 'year 0 is out of range'
@pytest.mark.parametrize(
    'data_gen',
    [LongGen(min_val=int(datetime(1, 2, 1).timestamp()),
             max_val=int(datetime(9999, 12, 30).timestamp()))],
    ids=idfn)
@pytest.mark.skipif(is_supported_time_zone(), reason="not all time zones are supported now, refer to https://github.com/NVIDIA/spark-rapids/issues/6839, please update after all time zones are supported")
def test_from_unixtime_fall_back(data_gen, date_format):
    """
    Counterpart of the from_unixtime equality test: it runs only when
    is_supported_time_zone() reports False and asserts that 'ProjectExec'
    falls back, with 'spark.rapids.sql.nonUTC.enabled' turned on.
    """
    non_utc_conf = {'spark.rapids.sql.nonUTC.enabled': True}
    sql_expr = "from_unixtime(a, '{}')".format(date_format)
    assert_gpu_fallback_collect(
        lambda spark: unary_op_df(spark, data_gen, length=5).selectExpr(sql_expr),
        'ProjectExec',
        non_utc_conf)

unsupported_date_formats = ['F']
@pytest.mark.parametrize('date_format', unsupported_date_formats, ids=idfn)
@pytest.mark.parametrize('data_gen', date_n_time_gens, ids=idfn)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1780,9 +1780,10 @@ object GpuOverrides extends Logging {
.withPsNote(TypeEnum.STRING, "Only a limited number of formats are supported"),
TypeSig.STRING)),
(a, conf, p, r) => new UnixTimeExprMeta[FromUnixTime](a, conf, p, r) {
override def isTimeZoneSupported = true
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
// passing the already converted strf string for a little optimization
GpuFromUnixTime(lhs, rhs, strfFormat)
GpuFromUnixTime(lhs, rhs, strfFormat, a.timeZoneId)
}),
expr[FromUTCTimestamp](
"Render the input UTC timestamp in the input timezone",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import java.time.ZoneId

import scala.collection.mutable

import com.nvidia.spark.rapids.jni.GpuTimeZoneDB
import com.nvidia.spark.rapids.shims.{DistributionUtil, SparkShimImpl}

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, Cast, ComplexTypeMergingExpression, Expression, QuaternaryExpression, String2TrimExpression, TernaryExpression, TimeZoneAwareExpression, UnaryExpression, UTCTimestamp, WindowExpression, WindowFunction}
Expand Down Expand Up @@ -1130,7 +1131,7 @@ abstract class BaseExprMeta[INPUT <: Expression](
if (!isTimeZoneSupported) return checkUTCTimezone(this)

// Level 4 check
if (TimeZoneDB.isSupportedTimezone(getZoneId())) {
if (!GpuTimeZoneDB.isSupportedTimeZone(getZoneId())) {
willNotWorkOnGpu(TimeZoneDB.timezoneNotSupportedStr(this.wrapped.getClass.toString))
}
}
Expand Down Expand Up @@ -1201,7 +1202,7 @@ abstract class BaseExprMeta[INPUT <: Expression](

// Level 3 timezone checking flag: override it to true in expressions that support timezones.
// It has no effect unless the expression is timezone related, as defined in [[needTimeZoneCheck]]
val isTimeZoneSupported: Boolean = false
def isTimeZoneSupported: Boolean = false

/**
* Timezone check which only allows UTC timezone. This is consistent with previous behavior.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1009,8 +1009,18 @@ case class GpuFromUnixTime(
override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = {
// we aren't using rhs as it was already converted in the GpuOverrides while creating the
// expressions map and passed down here as strfFormat
withResource(lhs.getBase.asTimestampSeconds) { tsVector =>
tsVector.asStrings(strfFormat)
withResource(lhs.getBase.asTimestampSeconds) { secondsVector =>
withResource(secondsVector.asTimestampMicroseconds) { tsVector =>

This comment has been minimized.

Copy link
@winningsix

winningsix Dec 8, 2023

Collaborator
if (GpuOverrides.isUTCTimezone(zoneId)) {
// UTC time zone
tsVector.asStrings(strfFormat)
} else {
// Non-UTC TZ
withResource(GpuTimeZoneDB.fromUtcTimestampToTimestamp(tsVector, zoneId.normalized())) {
shifted => shifted.asStrings(strfFormat)
}
}
}
}
}

Expand Down

0 comments on commit 2805b95

Please sign in to comment.