diff --git a/velox/functions/sparksql/DateTimeFunctions.h b/velox/functions/sparksql/DateTimeFunctions.h index f714f46fdadf..e112be5f109b 100644 --- a/velox/functions/sparksql/DateTimeFunctions.h +++ b/velox/functions/sparksql/DateTimeFunctions.h @@ -16,6 +16,7 @@ #include "velox/functions/lib/DateTimeFormatter.h" #include "velox/functions/lib/TimeUtils.h" +#include "velox/functions/prestosql/DateTimeImpl.h" #include "velox/type/tz/TimeZoneMap.h" namespace facebook::velox::functions::sparksql { @@ -152,4 +153,29 @@ struct UnixTimestampParseWithFormatFunction bool invalidFormat_{false}; }; +struct DateAddFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE bool call( + out_type& result, + const arg_type& date, + const int32_t value) { + result = addToDate(date, DateTimeUnit::kDay, value); + return true; + } + + FOLLY_ALWAYS_INLINE bool call( + out_type& result, + const arg_type& date, + const int16_t value) { + result = addToDate(date, DateTimeUnit::kDay, (int32_t)value); + return true; + } + + FOLLY_ALWAYS_INLINE bool + call(out_type& result, const arg_type& date, const int8_t value) { + result = addToDate(date, DateTimeUnit::kDay, (int32_t)value); + return true; + } +}; } // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index 2352fdd5614f..e0a596f605cc 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -222,6 +222,9 @@ void registerFunctions(const std::string& prefix) { {prefix + "year_of_week"}); registerFunction( {prefix + "year_of_week"}); + registerFunction({"date_add"}); + registerFunction({"date_add"}); + registerFunction({"date_add"}); } } // namespace sparksql diff --git a/velox/functions/sparksql/tests/DateTimeFunctionsTest.cpp b/velox/functions/sparksql/tests/DateTimeFunctionsTest.cpp index 457d08960ac7..d8fadf2494d1 100644 --- a/velox/functions/sparksql/tests/DateTimeFunctionsTest.cpp +++ b/velox/functions/sparksql/tests/DateTimeFunctionsTest.cpp @@ -29,6 +29,12 @@ class DateTimeFunctionsTest : public SparkFunctionBaseTest { {core::QueryConfig::kAdjustTimestampToTimezone, "true"}, }); } + + Date parseDate(const std::string& dateStr) { + Date returnDate; + parseTo(dateStr, returnDate); + return returnDate; + } }; TEST_F(DateTimeFunctionsTest, year) { @@ -142,5 +148,42 @@ TEST_F(DateTimeFunctionsTest, toUnixTimestamp) { EXPECT_THROW(evaluateOnce("to_unix_timestamp()"), VeloxUserError); } +TEST_F(DateTimeFunctionsTest, dateAdd) { + const auto dateAddInt32 = [&](std::optional date, + std::optional value) { + return evaluateOnce("date_add(c0, c1)", date, value); + }; + const auto dateAddInt16 = [&](std::optional date, + std::optional value) { + return evaluateOnce("date_add(c0, c1)", date, value); + }; + const auto dateAddInt8 = [&](std::optional date, + std::optional value) { + return evaluateOnce("date_add(c0, c1)", date, value); + }; + + // Check null behaviors + EXPECT_EQ(std::nullopt, dateAddInt32(std::nullopt, 1)); + EXPECT_EQ(std::nullopt, dateAddInt16(std::nullopt, 1)); + EXPECT_EQ(std::nullopt, dateAddInt8(std::nullopt, 1)); + + // Simple tests + EXPECT_EQ(parseDate("2019-03-01"), dateAddInt32(parseDate("2019-02-28"), 1)); + EXPECT_EQ(parseDate("2019-03-01"), dateAddInt16(parseDate("2019-02-28"), 1)); + EXPECT_EQ(parseDate("2019-03-01"), dateAddInt8(parseDate("2019-02-28"), 1)); + + // Account for the last day of a year-month + EXPECT_EQ( + parseDate("2020-02-29"), dateAddInt32(parseDate("2019-01-30"), 395)); + EXPECT_EQ( + parseDate("2020-02-29"), dateAddInt16(parseDate("2019-01-30"), 395)); + + // Check for negative intervals + EXPECT_EQ( + parseDate("2019-02-28"), dateAddInt32(parseDate("2020-02-29"), -366)); + EXPECT_EQ( + parseDate("2019-02-28"), dateAddInt16(parseDate("2020-02-29"), -366)); +} + } // namespace } // namespace facebook::velox::functions::sparksql::test