diff --git a/velox/type/Conversions.h b/velox/type/Conversions.h index 50503d41d8b2..36da8e4136e6 100644 --- a/velox/type/Conversions.h +++ b/velox/type/Conversions.h @@ -21,12 +21,17 @@ #include #include #include "velox/common/base/Exceptions.h" +#include "velox/type/DecimalUtil.h" #include "velox/type/TimestampConversion.h" #include "velox/type/Type.h" namespace facebook::velox::util { -template +template < + TypeKind KIND, + typename = void, + bool TRUNCATE = false, + bool ALLOW_DECIMAL = false> struct Converter { template // nullOutput API requires that the user has already set nullOutput to @@ -37,12 +42,26 @@ struct Converter { VELOX_UNSUPPORTED( "Conversion to {} is not supported", TypeTraits::name); } + + template + static typename TypeTraits::NativeType + cast(T val, bool& nullOutput, const TypePtr& toType) { + VELOX_UNSUPPORTED( + "Conversion of {} to {} is not supported", + CppToType::name, + TypeTraits::name); + } }; template <> struct Converter { using T = bool; + template + static T cast(const From& v, bool& nullOutput, const TypePtr& toType) { + VELOX_NYI(); + } + template static T cast(const From& v, bool& nullOutput) { return folly::to(v); @@ -69,7 +88,7 @@ struct Converter { } }; -template +template struct Converter< KIND, std::enable_if_t< @@ -77,16 +96,53 @@ struct Converter< KIND == TypeKind::SMALLINT || KIND == TypeKind::INTEGER || KIND == TypeKind::BIGINT || KIND == TypeKind::HUGEINT, void>, - TRUNCATE> { + TRUNCATE, + ALLOW_DECIMAL> { using T = typename TypeTraits::NativeType; + template + static T cast(const From& v, bool& nullOutput, const TypePtr& toType) { + VELOX_NYI(); + } + + // from long decimal cast to some type + static T cast(const int128_t& d, bool& nullOutput, const TypePtr& fromType) { + const auto& decimalType = fromType->asLongDecimal(); + auto scale0Decimal = DecimalUtil::rescaleWithRoundUp( + d, + decimalType.precision(), + decimalType.scale(), + decimalType.precision(), + 0, + false, + false); + return cast(scale0Decimal.value(), nullOutput); + } + + // from short decimal cast to some type + static T cast(const int64_t& d, bool& nullOutput, const TypePtr& fromType) { + const auto& decimalType = fromType->asShortDecimal(); + auto scale0Decimal = DecimalUtil::rescaleWithRoundUp( + d, + decimalType.precision(), + decimalType.scale(), + decimalType.precision(), + 0, + false, + false); + return cast(scale0Decimal.value(), nullOutput); + } + template static T cast(const From& v, bool& nullOutput) { VELOX_UNSUPPORTED( "Conversion to {} is not supported", TypeTraits::name); } - static T convertStringToInt(const folly::StringPiece& v, bool& nullOutput) { + static T convertStringToInt( + const folly::StringPiece& v, + const bool allowDecimal, + bool& nullOutput) { // Handling boolean target case fist because it is in this scope if constexpr (std::is_same_v) { return folly::to(v); @@ -110,6 +166,10 @@ struct Converter< } if (negative) { for (; index < len; index++) { + // Allow decimal and ignore the fractional part. + if (v[index] == '.' && allowDecimal) { + break; + } if (!std::isdigit(v[index])) { return -1; } @@ -121,6 +181,9 @@ struct Converter< } } else { for (; index < len; index++) { + if (v[index] == '.' && allowDecimal) { + break; + } if (!std::isdigit(v[index])) { return -1; } @@ -140,7 +203,7 @@ struct Converter< static T cast(const folly::StringPiece& v, bool& nullOutput) { try { if constexpr (TRUNCATE) { - return convertStringToInt(v, nullOutput); + return convertStringToInt(v, ALLOW_DECIMAL, nullOutput); } else { return folly::to(v); } @@ -152,7 +215,8 @@ struct Converter< static T cast(const StringView& v, bool& nullOutput) { try { if constexpr (TRUNCATE) { - return convertStringToInt(folly::StringPiece(v), nullOutput); + return convertStringToInt( + folly::StringPiece(v), ALLOW_DECIMAL, nullOutput); } else { return folly::to(folly::StringPiece(v)); } @@ -164,7 +228,7 @@ struct Converter< static T cast(const std::string& v, bool& nullOutput) { try { if constexpr (TRUNCATE) { - return convertStringToInt(v, nullOutput); + return convertStringToInt(v, ALLOW_DECIMAL, nullOutput); } else { return folly::to(v); } @@ -221,7 +285,9 @@ struct Converter< if (v > LimitType::maxLimit()) { return LimitType::max(); } - if (v < LimitType::minLimit()) { + // bool type's min is 0, but spark expects true for casting negative float + // data. + if (!std::is_same_v && v < LimitType::minLimit()) { return LimitType::min(); } return LimitType::cast(v); @@ -241,7 +307,9 @@ struct Converter< if (v > LimitType::maxLimit()) { return LimitType::max(); } - if (v < LimitType::minLimit()) { + // bool type's min is 0, but spark expects true for casting negative float + // data. + if (!std::is_same_v && v < LimitType::minLimit()) { return LimitType::min(); } return LimitType::cast(v); @@ -284,15 +352,39 @@ struct Converter< return folly::to(v); } } + + static T cast(const int128_t& v, bool& nullOutput) { + if constexpr (TRUNCATE) { + return T(v); + } else { + return static_cast(v); + } + } }; -template +template struct Converter< KIND, std::enable_if_t, - TRUNCATE> { + TRUNCATE, + ALLOW_DECIMAL> { using T = typename TypeTraits::NativeType; + template + static T cast(const From& v, bool& nullOutput, const TypePtr& toType) { + VELOX_NYI(); + } + + static T cast(const int64_t& v, bool& nullOutput, const TypePtr& fromType) { + auto decimalType = fromType->asShortDecimal(); + return DecimalUtil::toDoubleValue(v, decimalType.scale()); + } + + static T cast(const int128_t& v, bool& nullOutput, const TypePtr& fromType) { + auto decimalType = fromType->asLongDecimal(); + return DecimalUtil::toDoubleValue(v, decimalType.scale()); + } + template static T cast(const From& v, bool& nullOutput) { try { @@ -358,10 +450,31 @@ struct Converter< VELOX_UNSUPPORTED( "Conversion of Timestamp to Real or Double is not supported"); } + + static T cast(const int128_t& d, bool& nullOutput) { + VELOX_UNSUPPORTED( + "Conversion of int128_t to Real or Double is not supported"); + } }; -template -struct Converter { +template +struct Converter { + template + static std::string + cast(const T& v, bool& nullOutput, const TypePtr& fromType) { + VELOX_NYI(); + } + + static std::string + cast(const int64_t& v, bool& nullOutput, const TypePtr& fromType) { + return DecimalUtil::toString(v, fromType); + } + + static std::string + cast(const int128_t& v, bool& nullOutput, const TypePtr& fromType) { + return DecimalUtil::toString(v, fromType); + } + template static std::string cast(const T& val, bool& nullOutput) { if constexpr ( @@ -390,6 +503,11 @@ template <> struct Converter { using T = typename TypeTraits::NativeType; + template + static T cast(const From& v, bool& nullOutput, const TypePtr& toType) { + VELOX_NYI(); + } + template static T cast(const From& /* v */, bool& nullOutput) { VELOX_UNSUPPORTED("Conversion to Timestamp is not supported"); @@ -415,9 +533,15 @@ struct Converter { }; // Allow conversions from string to DATE type. -template -struct Converter { +template +struct Converter { using T = typename TypeTraits::NativeType; + + template + static T cast(const From& v, bool& nullOutput, const TypePtr& toType) { + VELOX_NYI(); + } + template static T cast(const From& /* v */, bool& nullOutput) { VELOX_UNSUPPORTED("Conversion to Date is not supported"); diff --git a/velox/type/DecimalUtil.h b/velox/type/DecimalUtil.h index 0d722603a361..7ca462721c4b 100644 --- a/velox/type/DecimalUtil.h +++ b/velox/type/DecimalUtil.h @@ -94,7 +94,9 @@ class DecimalUtil { const int fromPrecision, const int fromScale, const int toPrecision, - const int toScale) { + const int toScale, + bool nullOnOverflow = false, + bool roundUp = true) { int128_t rescaledValue = inputValue; auto scaleDifference = toScale - fromScale; bool isOverflow = false; @@ -108,18 +110,59 @@ class DecimalUtil { const auto scalingFactor = DecimalUtil::kPowersOfTen[scaleDifference]; rescaledValue /= scalingFactor; int128_t remainder = inputValue % scalingFactor; - if (inputValue >= 0 && remainder >= scalingFactor / 2) { + if (roundUp && inputValue >= 0 && remainder >= scalingFactor / 2) { ++rescaledValue; - } else if (remainder <= -scalingFactor / 2) { + } else if (roundUp && remainder <= -scalingFactor / 2) { --rescaledValue; } } // Check overflow. + if (rescaledValue < -DecimalUtil::kPowersOfTen[toPrecision] || + rescaledValue > DecimalUtil::kPowersOfTen[toPrecision] || isOverflow) { + if (nullOnOverflow) { + return std::nullopt; + } else { + VELOX_USER_FAIL( + "Cannot cast DECIMAL '{}' to DECIMAL({},{})", + DecimalUtil::toString( + inputValue, DECIMAL(fromPrecision, fromScale)), + toPrecision, + toScale); + } + } + return static_cast(rescaledValue); + } + + template + inline static std::optional rescaleDouble( + const TInput inputValue, + const int toPrecision, + const int toScale) { + static_assert( + std::is_same_v || std::is_same_v); + + // Multiply decimal with the scale + auto unscaled = inputValue * DecimalUtil::kPowersOfTen[toScale]; + + bool isOverflow = std::isnan(unscaled); + + unscaled = std::round(unscaled); + + // convert scaled double to int128 + int32_t sign = unscaled < 0 ? -1 : 1; + auto unscaled_abs = std::abs(unscaled); + + uint64_t high_bits = static_cast(std::ldexp(unscaled_abs, -64)); + uint64_t low_bits = static_cast( + unscaled_abs - std::ldexp(static_cast(high_bits), 64)); + + auto rescaledValue = HugeInt::build(high_bits, low_bits); + if (rescaledValue < -DecimalUtil::kPowersOfTen[toPrecision] || rescaledValue > DecimalUtil::kPowersOfTen[toPrecision] || isOverflow) { VELOX_USER_FAIL( "Cannot cast DECIMAL '{}' to DECIMAL({},{})", - DecimalUtil::toString(inputValue, DECIMAL(fromPrecision, fromScale)), + inputValue, toPrecision, toScale); } @@ -249,6 +292,226 @@ class DecimalUtil { } } + inline static int32_t FirstNonzeroLongNum( + const std::vector& mag, + int32_t length) { + int32_t fn = 0; + int32_t i; + for (i = length - 1; i >= 0 && mag[i] == 0; i--) + ; + fn = length - i - 1; + return fn; + } + + inline static int32_t GetInt( + int32_t n, + int32_t sig, + const std::vector& mag, + int32_t length) { + if (n < 0) + return 0; + if (n >= length) + return sig < 0 ? -1 : 0; + + int32_t magInt = mag[length - n - 1]; + return ( + sig >= 0 ? magInt + : (n <= FirstNonzeroLongNum(mag, length) ? -magInt : ~magInt)); + } + + inline static int32_t GetNumberOfLeadingZeros(uint32_t i) { + // TODO: we can get faster implementation by gcc build-in function + // HD, Figure 5-6 + if (i == 0) + return 32; + int32_t n = 1; + if (i >> 16 == 0) { + n += 16; + i <<= 16; + } + if (i >> 24 == 0) { + n += 8; + i <<= 8; + } + if (i >> 28 == 0) { + n += 4; + i <<= 4; + } + if (i >> 30 == 0) { + n += 2; + i <<= 2; + } + n -= i >> 31; + return n; + } + + inline static int32_t GetBitLengthForInt(uint32_t n) { + return 32 - GetNumberOfLeadingZeros(n); + } + + inline static int32_t GetBitCount(uint32_t i) { + // HD, Figure 5-2 + i = i - ((i >> 1) & 0x55555555); + i = (i & 0x33333333) + ((i >> 2) & 0x33333333); + i = (i + (i >> 4)) & 0x0f0f0f0f; + i = i + (i >> 8); + i = i + (i >> 16); + return i & 0x3f; + } + + inline static int32_t + GetBitLength(int32_t sig, const std::vector& mag, int32_t len) { + int32_t n = -1; + if (len == 0) { + n = 0; + } else { + // Calculate the bit length of the magnitude + int32_t mag_bit_length = + ((len - 1) << 5) + GetBitLengthForInt((uint32_t)mag[0]); + if (sig < 0) { + // Check if magnitude is a power of two + bool pow2 = (GetBitCount((uint32_t)mag[0]) == 1); + for (int i = 1; i < len && pow2; i++) + pow2 = (mag[i] == 0); + + n = (pow2 ? mag_bit_length - 1 : mag_bit_length); + } else { + n = mag_bit_length; + } + } + return n; + } + + static std::vector + ConvertMagArray(int64_t new_high, uint64_t new_low, int32_t* size) { + std::vector mag; + int64_t orignal_low = new_low; + int64_t orignal_high = new_high; + mag.push_back(new_high >>= 32); + mag.push_back((uint32_t)orignal_high); + mag.push_back(new_low >>= 32); + mag.push_back((uint32_t)orignal_low); + + int32_t start = 0; + // remove the front 0 + for (int32_t i = 0; i < 4; i++) { + if (mag[i] == 0) + start++; + if (mag[i] != 0) + break; + } + + int32_t length = 4 - start; + std::vector new_mag; + // get the mag after remove the high 0 + for (int32_t i = start; i < 4; i++) { + new_mag.push_back(mag[i]); + } + + *size = length; + return new_mag; + } + + /* + * This method refer to the BigInterger#toByteArray() method in Java side. + */ + inline static char* ToByteArray(int128_t value) { + int128_t new_value; + int32_t sig; + if (value > 0) { + new_value = value; + sig = 1; + } else if (value < 0) { + new_value = std::abs(value); + sig = -1; + } else { + new_value = value; + sig = 0; + } + + int64_t new_high; + uint64_t new_low; + + int128_t orignal_value = new_value; + new_high = new_value >> 64; + new_low = (uint64_t)orignal_value; + + std::vector mag; + int32_t size; + mag = ConvertMagArray(new_high, new_low, &size); + + std::vector final_mag; + for (auto i = 0; i < size; i++) { + final_mag.push_back(mag[i]); + } + + int32_t byte_length = GetBitLength(sig, final_mag, size) / 8 + 1; + + char* out = new char[16]; + uint32_t next_int = 0; + for (int32_t i = byte_length - 1, bytes_copied = 4, int_index = 0; i >= 0; + i--) { + if (bytes_copied == 4) { + next_int = GetInt(int_index++, sig, final_mag, size); + bytes_copied = 1; + } else { + next_int >>= 8; + bytes_copied++; + } + + out[i] = (uint8_t)next_int; + } + return out; + } + + inline static double toDoubleValue(int128_t value, uint8_t scale) { + int128_t new_value; + int32_t sig; + if (value > 0) { + new_value = value; + sig = 1; + } else if (value < 0) { + new_value = std::abs(value); + sig = -1; + } else { + new_value = value; + sig = 0; + } + + int64_t new_high; + uint64_t new_low; + + int128_t orignal_value = new_value; + new_high = new_value >> 64; + new_low = (uint64_t)orignal_value; + + double unscaled = static_cast(new_low) + + std::ldexp(static_cast(new_high), 64); + + // scale double. + return (unscaled * sig) / DecimalUtil::kPowersOfTen[scale]; + } + + template + inline static int numDigits(T number) { + int digits = 0; + if (number < 0) + digits = 1; // remove this line if '-' counts as a digit + while (number) { + number /= 10; + digits++; + } + return digits; + } + + static constexpr double double10pow[] = { + 1.0e0, 1.0e1, 1.0e2, 1.0e3, 1.0e4, 1.0e5, 1.0e6, 1.0e7, + 1.0e8, 1.0e9, 1.0e10, 1.0e11, 1.0e12, 1.0e13, 1.0e14, 1.0e15, + 1.0e16, 1.0e17, 1.0e18, 1.0e19, 1.0e20, 1.0e21, 1.0e22}; + static constexpr __uint128_t kOverflowMultiplier = ((__uint128_t)1 << 127); + static constexpr long kLongMinValue = 0x8000000000000000L; + static constexpr long kLONG_MASK = 0xffffffffL; + }; // DecimalUtil } // namespace facebook::velox diff --git a/velox/type/DecimalUtilOp.h b/velox/type/DecimalUtilOp.h new file mode 100644 index 000000000000..7da4f131d61b --- /dev/null +++ b/velox/type/DecimalUtilOp.h @@ -0,0 +1,431 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include "velox/common/base/CheckedArithmetic.h" +#include "velox/common/base/Exceptions.h" +#include "velox/type/DecimalUtil.h" +#include "velox/type/Type.h" + +namespace facebook::velox { + +using int128_t = __int128_t; +using boost::multiprecision::int256_t; +using uint128_t = __uint128_t; + +static inline int64_t convertToInt64(int256_t in, bool* overflow) { + int64_t result; + constexpr int256_t UINT64_MASK = std::numeric_limits::max(); + + int256_t inAbs = abs(in); + bool isNegative = in < 0; + + uint128_t unsignResult = (inAbs & UINT64_MASK).convert_to(); + inAbs >>= 64; + + if (inAbs > 0) { + // we've shifted in by 128-bit, so nothing should be left. + *overflow = true; + } else if (unsignResult > INT64_MAX) { + // the high-bit must not be set (signed 128-bit). + *overflow = true; + } else { + result = static_cast(unsignResult); + } + return isNegative ? -result : result; +} + +static inline int128_t convertToInt128(int256_t in, bool* overflow) { + int128_t result; +#ifndef INT128_MAX + int128_t INT128_MAX = int128_t(int128_t(-1L)) >> 1; +#endif + constexpr int256_t UINT128_MASK = std::numeric_limits::max(); + + int256_t inAbs = abs(in); + bool isNegative = in < 0; + + uint128_t unsignResult = (inAbs & UINT128_MASK).convert_to(); + inAbs >>= 128; + + if (inAbs > 0) { + // we've shifted in by 128-bit, so nothing should be left. + *overflow = true; + } else if (unsignResult > INT128_MAX) { + *overflow = true; + } else { + result = static_cast(unsignResult); + } + return isNegative ? -result : result; +} + +class DecimalUtilOp { + public: + inline static int32_t maxBitsRequiredIncreaseAfterScaling(int32_t scale_by) { + // We rely on the following formula: + // bits_required(x * 10^y) <= bits_required(x) + floor(log2(10^y)) + 1 + // We precompute floor(log2(10^x)) + 1 for x = 0, 1, 2...75, 76 + + static const int32_t floor_log2_plus_one[] = { + 0, 4, 7, 10, 14, 17, 20, 24, 27, 30, 34, 37, 40, + 44, 47, 50, 54, 57, 60, 64, 67, 70, 74, 77, 80, 84, + 87, 90, 94, 97, 100, 103, 107, 110, 113, 117, 120, 123, 127, + 130, 133, 137, 140, 143, 147, 150, 153, 157, 160, 163, 167, 170, + 173, 177, 180, 183, 187, 190, 193, 196, 200, 203, 206, 210, 213, + 216, 220, 223, 226, 230, 233, 236, 240, 243, 246, 250, 253}; + return floor_log2_plus_one[scale_by]; + } + + template + inline static int32_t maxBitsRequiredAfterScaling( + const A& num, + uint8_t aRescale) { + auto value = num; + auto valueAbs = std::abs(value); + int32_t num_occupied = 0; + if constexpr (std::is_same_v) { + num_occupied = 64 - bits::countLeadingZeros(valueAbs); + } else { + num_occupied = 128 - bits::countLeadingZerosUint128(std::abs(num)); + } + + return num_occupied + maxBitsRequiredIncreaseAfterScaling(aRescale); + } + + // If we have a number with 'numLz' leading zeros, and we scale it up by + // 10^scale_by, + // this function returns the minimum number of leading zeros the result can + // have. + inline static int32_t minLeadingZerosAfterScaling( + int32_t numLz, + int32_t scaleBy) { + int32_t result = numLz - maxBitsRequiredIncreaseAfterScaling(scaleBy); + return result; + } + + template + inline static int32_t + minLeadingZeros(const A& a, const B& b, uint8_t aScale, uint8_t bScale) { + auto x_value_abs = std::abs(a); + + auto y_value_abs = std::abs(b); + int32_t x_lz = 0; + int32_t y_lz = 0; + if constexpr (std::is_same_v) { + x_lz = bits::countLeadingZerosUint128(std::abs(a)); + } else { + x_lz = bits::countLeadingZeros(a); + } + if constexpr (std::is_same_v) { + y_lz = bits::countLeadingZerosUint128(std::abs(b)); + } else { + y_lz = bits::countLeadingZeros(b); + } + if (aScale < bScale) { + x_lz = minLeadingZerosAfterScaling(x_lz, bScale - aScale); + } else if (aScale > bScale) { + y_lz = minLeadingZerosAfterScaling(y_lz, aScale - bScale); + } + return std::min(x_lz, y_lz); + } + + template + inline static R divideWithRoundUp( + R& r, + const A& a, + const B& b, + bool noRoundUp, + uint8_t aRescale, + uint8_t /*bRescale*/, + bool* overflow) { + if (b == 0) { + *overflow = true; + return R(-1); + } + int resultSign = 1; + R unsignedDividendRescaled(a); + int aSign = 1; + int bSign = 1; + if (a < 0) { + resultSign = -1; + unsignedDividendRescaled *= -1; + aSign = -1; + } + R unsignedDivisor(b); + if (b < 0) { + resultSign *= -1; + unsignedDivisor *= -1; + bSign = -1; + } + auto bitsRequiredAfterScaling = maxBitsRequiredAfterScaling(a, aRescale); + if (bitsRequiredAfterScaling <= 127) { + unsignedDividendRescaled = checkedMultiply( + unsignedDividendRescaled, R(DecimalUtil::kPowersOfTen[aRescale])); + if (*overflow) { + return R(-1); + } + R quotient = unsignedDividendRescaled / unsignedDivisor; + R remainder = unsignedDividendRescaled % unsignedDivisor; + if (!noRoundUp && remainder * 2 >= unsignedDivisor) { + ++quotient; + } + r = quotient * resultSign; + return remainder; + } else if constexpr ( + std::is_same_v || std::is_same_v) { + // Derives from Arrow BasicDecimal128 Divide + if (aRescale > 38 && bitsRequiredAfterScaling > 255) { + *overflow = true; + return R(-1); + } + int256_t aLarge = a; + int256_t x_large_scaled_up = aLarge * DecimalUtil::kPowersOfTen[aRescale]; + int256_t y_large = b; + int256_t result_large = x_large_scaled_up / y_large; + int256_t remainder_large = x_large_scaled_up % y_large; + // Since we are scaling up and then, scaling down, round-up the result (+1 + // for +ve, -1 for -ve), if the remainder is >= 2 * divisor. + if (abs(2 * remainder_large) >= abs(y_large)) { + // x +ve and y +ve, result is +ve => (1 ^ 1) + 1 = 0 + 1 = +1 + // x +ve and y -ve, result is -ve => (-1 ^ 1) + 1 = -2 + 1 = -1 + // x +ve and y -ve, result is -ve => (1 ^ -1) + 1 = -2 + 1 = -1 + // x -ve and y -ve, result is +ve => (-1 ^ -1) + 1 = 0 + 1 = +1 + result_large += (aSign ^ bSign) + 1; + } + if constexpr (std::is_same_v) { + auto result = convertToInt64(result_large, overflow); + auto remainder = convertToInt64(remainder_large, overflow); + if (!(result >= DecimalUtil::kLongDecimalMin && + result <= DecimalUtil::kLongDecimalMax)) { + *overflow = true; + } else { + r = result; + } + return remainder; + } + if constexpr (std::is_same_v) { + auto result = convertToInt128(result_large, overflow); + auto remainder = convertToInt128(remainder_large, overflow); + if (!(result >= DecimalUtil::kLongDecimalMin && + result <= DecimalUtil::kLongDecimalMax)) { + *overflow = true; + } else { + r = result; + } + return remainder; + } + } else { + VELOX_FAIL("Should not reach here in DecimalUtilOp.h"); + } + } + + // Convert a number of scientific notation to normal. + inline static std::string getNormalNumber(const std::string& value) { + size_t dotPos = value.find('.'); + size_t expPos = value.find('E'); + if (expPos == std::string::npos) { + return value; + } + + std::string ints; + std::string digits; + // Get the integers and digits from the base number. + if (dotPos == std::string::npos) { + ints = value.substr(0, expPos); + digits = ""; + } else { + ints = value.substr(0, dotPos); + digits = value.substr(dotPos + 1, expPos - dotPos - 1); + } + + size_t pos = value.find("E+"); + // Handle number with positive exponent. + if (pos != std::string::npos) { + int exponent = std::stoi(value.substr(pos + 2, value.length())); + std::string number = ints; + if (exponent >= digits.length()) { + // Dot is not needed. + number = ints + digits; + for (int i = 0; i < exponent - digits.length(); i++) { + number += '0'; + } + } else { + number += digits.substr(0, exponent) + '.' + + digits.substr(exponent + 1, digits.length()); + } + return number; + } + pos = value.find("E-"); + if (pos != std::string::npos) { + int exponent = std::stoi(value.substr(pos + 2, value.length())); + std::string number; + if (exponent < ints.length()) { + number = ints.substr(0, ints.length() - exponent) + '.' + + ints.substr(ints.length() - exponent + 1, ints.length()); + } else { + number = "0."; + for (int i = 0; i < exponent - ints.length(); i++) { + number += '0'; + } + number += ints; + number += digits; + } + return number; + } + return value; + } + + // Round double to certain precision with half up. + inline static double roundTo(double value, int precision) { + int charsNeeded = 1 + snprintf(NULL, 0, "%.*f", (int)precision, value); + char* buffer = reinterpret_cast(malloc(charsNeeded)); + double nextValue; + if (value < 0) { + nextValue = nextafter(value, value - 0.1); + } else { + nextValue = nextafter(value, value + 0.1); + } + snprintf(buffer, charsNeeded, "%.*f", (int)precision, nextValue); + return atof(buffer); + } + + // return unscaled value and scale + inline static std::pair splitVarChar( + const StringView& value, + int toScale) { + std::string s = getNormalNumber(value.str()); + size_t pos = s.find('.'); + if (pos == std::string::npos) { + return {s.substr(0, pos), 0}; + } else if (toScale < s.length() - pos - 1) { + // If toScale is less than scales.length(), the string scales will be cut + // and rounded. + std::string roundedValue = std::to_string(roundTo(std::stod(s), toScale)); + pos = roundedValue.find('.'); + std::string scales = roundedValue.substr(pos + 1, toScale); + return {roundedValue.substr(0, pos) + scales, scales.length()}; + } else { + std::string scales = s.substr(pos + 1, s.length()); + return {s.substr(0, pos) + scales, scales.length()}; + } + } + + static int128_t convertStringToInt128( + const std::string& value, + bool& nullOutput) { + // Handling integer target cases + const char* v = value.c_str(); + nullOutput = true; + bool negative = false; + int128_t result = 0; + int index = 0; + int len = value.size(); + if (len == 0) { + return -1; + } + // Setting negative flag + if (v[0] == '-') { + if (len == 1) { + return -1; + } + negative = true; + index = 1; + } + if (negative) { + for (; index < len; index++) { + if (!std::isdigit(v[index])) { + return -1; + } + result = result * 10 - (v[index] - '0'); + // Overflow check + if (result > 0) { + return -1; + } + } + } else { + for (; index < len; index++) { + if (!std::isdigit(v[index])) { + return -1; + } + result = result * 10 + (v[index] - '0'); + // Overflow check + if (result < 0) { + return -1; + } + } + } + // Final result + nullOutput = false; + return result; + } + + template + inline static std::optional rescaleVarchar( + const StringView& inputValue, + const int toPrecision, + const int toScale) { + static_assert( + std::is_same_v || std::is_same_v); + auto [unscaledStr, fromScale] = splitVarChar(inputValue, toScale); + uint8_t fromPrecision = unscaledStr.size(); + VELOX_CHECK_LE(fromPrecision, LongDecimalType::kMaxPrecision); + if (fromPrecision <= 18) { + int64_t fromUnscaledValue = folly::to(unscaledStr); + return DecimalUtil::rescaleWithRoundUp( + fromUnscaledValue, + fromPrecision, + fromScale, + toPrecision, + toScale, + false, + true); + } else { + bool nullOutput = true; + int128_t decimalValue = convertStringToInt128(unscaledStr, nullOutput); + if (nullOutput) { + VELOX_USER_FAIL( + "Cannot cast StringView '{}' to DECIMAL({},{})", + inputValue, + toPrecision, + toScale); + } + return DecimalUtil::rescaleWithRoundUp( + decimalValue, + fromPrecision, + fromScale, + toPrecision, + toScale, + false, + true); + } + } + + template + inline static std::optional rescaleDouble( + const TInput inputValue, + const int toPrecision, + const int toScale) { + static_assert( + std::is_same_v || std::is_same_v); + auto str = velox::to(inputValue); + auto stringView = StringView(str.c_str(), str.size()); + return rescaleVarchar(stringView, toPrecision, toScale); + } +}; +} // namespace facebook::velox diff --git a/velox/type/Filter.cpp b/velox/type/Filter.cpp index e289d2cec313..e0dc57da2ad2 100644 --- a/velox/type/Filter.cpp +++ b/velox/type/Filter.cpp @@ -64,6 +64,9 @@ std::string Filter::toString() const { case FilterKind::kDoubleRange: strKind = "DoubleRange"; break; + case FilterKind::kDoubleValues: + strKind = "DoubleValues"; + break; case FilterKind::kFloatRange: strKind = "FloatRange"; break; @@ -114,6 +117,7 @@ std::unordered_map filterKindNames() { {FilterKind::kNegatedBigintValuesUsingBitmask, "kNegatedBigintValuesUsingBitmask"}, {FilterKind::kDoubleRange, "kDoubleRange"}, + {FilterKind::kDoubleValues, "kDoubleValues"}, {FilterKind::kFloatRange, "kFloatRange"}, {FilterKind::kBytesRange, "kBytesRange"}, {FilterKind::kNegatedBytesRange, "kNegatedBytesRange"}, @@ -1009,6 +1013,33 @@ std::unique_ptr notNullOrTrue(bool nullAllowed) { return std::make_unique(); } +} // namespace + +std::unique_ptr createDoubleValues( + const std::vector& values, + bool nullAllowed) { + if (values.empty()) { + return nullOrFalse(nullAllowed); + } + if (values.size() == 1) { + return std::make_unique( + values.front(), + false, + false, + values.front(), + false, + false, + nullAllowed); + } + double min = values.front(); + double max = values.front(); + for (const auto& value : values) { + min = (value < min) ? value : min; + max = (value > max) ? value : max; + } + return std::make_unique(min, max, values, nullAllowed); +} + std::unique_ptr createBigintValuesFilter( const std::vector& values, bool nullAllowed, @@ -1065,7 +1096,6 @@ std::unique_ptr createBigintValuesFilter( return std::make_unique( min, max, values, nullAllowed); } -} // namespace std::unique_ptr createBigintValues( const std::vector& values, @@ -1079,6 +1109,133 @@ std::unique_ptr createNegatedBigintValues( return createBigintValuesFilter(values, nullAllowed, true); } +DoubleValues::DoubleValues( + double min, + double max, + const std::vector& values, + bool nullAllowed) + : Filter(true, nullAllowed, FilterKind::kDoubleValues), + min_(min), + max_(max) { + VELOX_CHECK(min < max, "min must be less than max"); + VELOX_CHECK(values.size() > 1, "values must contain at least 2 entries"); + + bitmask_.resize(toInt64(max) - toInt64(min) + 1); + + for (double value : values) { + bitmask_[toInt64(value) - toInt64(min)] = true; + } +} + +bool DoubleValues::testDouble(double value) const { + if (value < min_ || value > max_) { + return false; + } + return bitmask_[toInt64(value) - toInt64(min_)]; +} + +std::vector DoubleValues::values() const { + std::vector values; + for (int i = 0; i < bitmask_.size(); i++) { + if (bitmask_[i]) { + values.push_back(min_ + i); + } + } + return values; +} + +bool DoubleValues::testDoubleRange(double min, double max, bool hasNull) const { + if (hasNull && nullAllowed_) { + return true; + } + + if (toInt64(min) == toInt64(max)) { + return testDouble(min); + } + + return !(min > max_ || max < min_); +} + +std::unique_ptr DoubleValues::mergeWith(const Filter* other) const { + switch (other->kind()) { + case FilterKind::kAlwaysTrue: + case FilterKind::kAlwaysFalse: + case FilterKind::kIsNull: + return other->mergeWith(this); + case FilterKind::kIsNotNull: + return std::make_unique(*this, false); + case FilterKind::kDoubleRange: { + auto otherRange = dynamic_cast(other); + + auto min = std::max(min_, otherRange->lower()); + auto max = std::min(max_, otherRange->upper()); + + return mergeWith(min, max, other); + } + case FilterKind::kFloatRange: { + auto otherRange = dynamic_cast(other); + + auto min = std::max(min_, otherRange->lower()); + auto max = std::min(max_, otherRange->upper()); + + return mergeWith(min, max, other); + } + case FilterKind::kDoubleValues: { + auto otherValues = dynamic_cast(other); + + auto min = std::max(min_, otherValues->min_); + auto max = std::min(max_, otherValues->max_); + + return mergeWith(min, max, other); + } + default: + VELOX_UNREACHABLE(); + } +} + +std::unique_ptr +DoubleValues::mergeWith(double min, double max, const Filter* other) const { + bool bothNullAllowed = nullAllowed_ && other->testNull(); + + std::vector valuesToKeep; + for (auto i = min; i <= max; ++i) { + if (bitmask_[toInt64(i) - toInt64(min_)] && other->testDouble(i)) { + valuesToKeep.push_back(i); + } + } + return createDoubleValues(valuesToKeep, bothNullAllowed); +} + +folly::dynamic DoubleValues::serialize() const { + auto obj = Filter::serializeBase("DoubleValues"); + obj["min"] = min_; + obj["max"] = max_; + folly::dynamic bitmask = folly::dynamic::array; + for (auto v : bitmask_) { + bitmask.push_back(v); + } + obj["bitmask"] = bitmask; + return obj; +} + +bool DoubleValues::testingEquals(const Filter& other) const { + auto otherDoubleValues = dynamic_cast(&other); + bool res = otherDoubleValues != nullptr && Filter::testingBaseEquals(other) && + min_ == otherDoubleValues->min_ && max_ == otherDoubleValues->max_ && + bitmask_.size() == otherDoubleValues->bitmask_.size(); + if (!res) { + return false; + } + // values_ can be compared pair-wise since they are sorted. + for (size_t i = 0; i < bitmask_.size(); ++i) { + if (bitmask_.at(i) != otherDoubleValues->bitmask_.at(i)) { + return false; + } + } + + return true; +} + BigintMultiRange::BigintMultiRange( std::vector> ranges, bool nullAllowed) @@ -1152,25 +1309,27 @@ bool BytesRange::testBytesRange( if (lowerUnbounded_) { // min > upper_ + int compare = compareRanges(min->data(), min->length(), upper_); return min.has_value() && - compareRanges(min->data(), min->length(), upper_) < 0; + (compare < 0 || (!upperExclusive_ && compare == 0)); } if (upperUnbounded_) { // max < lower_ + int compare = compareRanges(max->data(), max->length(), lower_); return max.has_value() && - compareRanges(max->data(), max->length(), lower_) > 0; + (compare > 0 || (!lowerExclusive_ && compare == 0)); } // min > upper_ - if (min.has_value() && - compareRanges(min->data(), min->length(), upper_) > 0) { + int compare = compareRanges(min->data(), min->length(), upper_); + if (min.has_value() && (compare > 0 || (compare == 0 && upperExclusive_))) { return false; } // max < lower_ - if (max.has_value() && - compareRanges(max->data(), max->length(), lower_) < 0) { + compare = compareRanges(max->data(), max->length(), lower_); + if (max.has_value() && (compare < 0 || (compare == 0 && lowerExclusive_))) { return false; } return true; diff --git a/velox/type/Filter.h b/velox/type/Filter.h index f9e8a8733ae7..d0ed94d1f405 100644 --- a/velox/type/Filter.h +++ b/velox/type/Filter.h @@ -47,6 +47,7 @@ enum class FilterKind { kNegatedBigintValuesUsingHashTable, kNegatedBigintValuesUsingBitmask, kDoubleRange, + kDoubleValues, kFloatRange, kBytesRange, kNegatedBytesRange, @@ -622,7 +623,27 @@ class BigintRange final : public Filter { folly::dynamic serialize() const override; static FilterPtr create(const folly::dynamic& obj); - + + BigintRange( + int64_t lower, + bool lowerUnbounded, + bool lowerExclusive, + int64_t upper, + bool upperUnbounded, + bool upperExclusive, + bool nullAllowed) + : Filter(true, nullAllowed, FilterKind::kBigintRange), + lower_(lowerExclusive ? lower + 1 : lower), + upper_(upperExclusive ? upper - 1 : upper), + lower32_( + std::max(lower_, std::numeric_limits::min())), + upper32_( + std::min(upper_, std::numeric_limits::max())), + lower16_( + std::max(lower_, std::numeric_limits::min())), + upper16_( + std::min(upper_, std::numeric_limits::max())), + isSingleValue_(upper_ == lower_) {} std::unique_ptr clone( std::optional nullAllowed = std::nullopt) const final { if (nullAllowed) { @@ -1125,6 +1146,60 @@ class NegatedBigintValuesUsingBitmask final : public Filter { std::unique_ptr nonNegated_; }; +class DoubleValues final : public Filter { + public: + /// @param min Minimum value. + /// @param max Maximum value. + /// @param values A list of unique values that pass the filter. Must contain + /// at least two entries. + /// @param nullAllowed Null values are passing the filter if true. + DoubleValues( + double min, + double max, + const std::vector& values, + bool nullAllowed); + + DoubleValues(const DoubleValues& other, bool nullAllowed) + : Filter(true, nullAllowed, FilterKind::kDoubleValues), + bitmask_(other.bitmask_), + min_(other.min_), + max_(other.max_) {} + + folly::dynamic serialize() const override; + + std::unique_ptr clone( + std::optional nullAllowed = std::nullopt) const final { + if (nullAllowed) { + return std::make_unique(*this, nullAllowed.value()); + } else { + return std::make_unique(*this); + } + } + + std::vector values() const; + + bool testDouble(double value) const final; + + bool testDoubleRange(double min, double max, bool hasNull) const final; + + std::unique_ptr mergeWith(const Filter* other) const final; + + bool testingEquals(const Filter& other) const final; + + private: + std::unique_ptr mergeWith(double min, double max, const Filter* other) + const; + + int64_t toInt64(double value) const { + int64_t converted = (int64_t)(value + 0.5); + return converted; + } + + std::vector bitmask_; + const double min_; + const double max_; +}; + /// Base class for range filters on floating point and string data types. class AbstractRange : public Filter { public: @@ -1859,6 +1934,11 @@ class MultiRange final : public Filter { filters_(std::move(filters)), nanAllowed_(nanAllowed) {} + MultiRange(std::vector> filters, bool nullAllowed) + : Filter(true, nullAllowed, FilterKind::kMultiRange), + filters_(std::move(filters)), + nanAllowed_(true) {} + folly::dynamic serialize() const override; static FilterPtr create(const folly::dynamic& obj); @@ -1940,4 +2020,8 @@ std::unique_ptr createNegatedBigintValues( const std::vector& values, bool nullAllowed); +std::unique_ptr createDoubleValues( + const std::vector& values, + bool nullAllowed); + } // namespace facebook::velox::common diff --git a/velox/type/Subfield.cpp b/velox/type/Subfield.cpp index 40be316a0902..9962289cfc11 100644 --- a/velox/type/Subfield.cpp +++ b/velox/type/Subfield.cpp @@ -18,8 +18,8 @@ namespace facebook::velox::common { -Subfield::Subfield(const std::string& path) { - Tokenizer tokenizer(path); +Subfield::Subfield(const std::string& path, bool dotAsRegular) { + Tokenizer tokenizer(path, dotAsRegular); VELOX_CHECK(tokenizer.hasNext(), "Column name is missing: {}", path); auto firstElement = tokenizer.next(); diff --git a/velox/type/Subfield.h b/velox/type/Subfield.h index 25869ee750a2..2cc8279be76a 100644 --- a/velox/type/Subfield.h +++ b/velox/type/Subfield.h @@ -191,7 +191,7 @@ class Subfield { }; public: - explicit Subfield(const std::string& path); + explicit Subfield(const std::string& path, bool dotAsRegular = false); explicit Subfield(std::vector>&& path); diff --git a/velox/type/Timestamp.h b/velox/type/Timestamp.h index 7e93451baf4a..83e823b6f5c4 100644 --- a/velox/type/Timestamp.h +++ b/velox/type/Timestamp.h @@ -34,8 +34,11 @@ struct Timestamp { public: enum class Precision : int { kMilliseconds = 3, kNanoseconds = 9 }; constexpr Timestamp() : seconds_(0), nanos_(0) {} - constexpr Timestamp(int64_t seconds, uint64_t nanos) - : seconds_(seconds), nanos_(nanos) {} + Timestamp(int64_t seconds, uint64_t nanos) { + constexpr const uint64_t kNanosPerSecond{1000000000}; + seconds_ = seconds + nanos / kNanosPerSecond; + nanos_ = nanos - (nanos / kNanosPerSecond) * kNanosPerSecond; + } // Returns the current unix timestamp (ms precision). static Timestamp now(); diff --git a/velox/type/TimestampConversion.cpp b/velox/type/TimestampConversion.cpp index 1df2e7d70cb1..a657a250110e 100644 --- a/velox/type/TimestampConversion.cpp +++ b/velox/type/TimestampConversion.cpp @@ -177,12 +177,17 @@ bool tryParseDateString( if (!characterIsDigit(buf[pos])) { return false; } + int yearSegStart = pos; // First parse the year. for (; pos < len && characterIsDigit(buf[pos]); pos++) { year = checkedPlus((buf[pos] - '0'), checkedMultiply(year, 10)); if (year > kMaxYear) { break; } + // Align with spark, year digits should not be greater than 7. + if (pos - yearSegStart + 1 > 7) { + return false; + } } if (yearneg) { year = checkedNegate(year); @@ -191,13 +196,20 @@ bool tryParseDateString( } } - if (pos >= len) { + // No month or day. + if (pos == len) { + daysSinceEpoch = daysSinceEpochFromDate(year, 1, 1); + return true; + } + + if (pos > len) { return false; } // Fetch the separator. sep = buf[pos++]; - if (sep != ' ' && sep != '-' && sep != '/' && sep != '\\') { + // For spark, "/" separtor is not supported. + if (sep != ' ' && sep != '-' && sep != '\\') { // Invalid separator. return false; } @@ -207,7 +219,13 @@ bool tryParseDateString( return false; } - if (pos >= len) { + // No day. + if (pos == len) { + daysSinceEpoch = daysSinceEpochFromDate(year, month, 1); + return true; + } + + if (pos > len) { return false; } @@ -240,13 +258,25 @@ bool tryParseDateString( // In strict mode, check remaining string for non-space characters. if (strict) { - // Skip trailing spaces. - while (pos < len && characterIsSpace(buf[pos])) { + // Check for an optional trailing 'T' followed by optional digits. + if (pos < len && buf[pos] == 'T') { pos++; - } - // Check position. if end was not reached, non-space chars remaining. - if (pos < len) { - return false; + while (pos < len && characterIsDigit(buf[pos])) { + pos++; + } + } else { + // Skip trailing spaces. + while (pos < len && characterIsSpace(buf[pos])) { + pos++; + } + // Skip trailing digits after spaces. + while (pos < len && characterIsDigit(buf[pos])) { + pos++; + } + // Check position. if end was not reached, non-space chars remaining. + if (pos < len) { + return false; + } } } else { // In non-strict mode, check for any direct trailing digits. @@ -499,6 +529,19 @@ int64_t fromDateString(const char* str, size_t len) { size_t pos = 0; if (!tryParseDateString(str, len, pos, daysSinceEpoch, true)) { + if (len == 19) { + // Timestamp format: (YYYY-MM-DD HH:MM:SS). + std::string input(str); + size_t strLen = 10; + std::string leadingStr = input.substr(0, strLen); + if (!tryParseDateString( + leadingStr.c_str(), strLen, pos, daysSinceEpoch, true)) { + VELOX_USER_FAIL( + "Unable to parse date value: \"{}\", expected format is (YYYY-MM-DD)", + std::string(leadingStr, strLen)); + } + return daysSinceEpoch; + } VELOX_USER_FAIL( "Unable to parse date value: \"{}\", expected format is (YYYY-MM-DD)", std::string(str, len)); diff --git a/velox/type/Tokenizer.cpp b/velox/type/Tokenizer.cpp index 53c4d9d595b2..548332d96214 100644 --- a/velox/type/Tokenizer.cpp +++ b/velox/type/Tokenizer.cpp @@ -17,7 +17,8 @@ namespace facebook::velox::common { -Tokenizer::Tokenizer(const std::string& path) : path_(path) { +Tokenizer::Tokenizer(const std::string& path, bool dotAsRegular) + : path_(path), dotAsRegular_(dotAsRegular) { state = State::kNotReady; index_ = 0; } @@ -54,7 +55,7 @@ std::unique_ptr Tokenizer::computeNext() { return nullptr; } - if (tryMatch(DOT)) { + if (!dotAsRegular_ && tryMatch(DOT)) { std::unique_ptr token = matchPathSegment(); firstSegment = false; return token; @@ -144,8 +145,12 @@ std::unique_ptr Tokenizer::matchUnquotedSubscript() { } bool Tokenizer::isUnquotedPathCharacter(char c) { - return c == ':' || c == '$' || c == '-' || c == '/' || c == '@' || c == '|' || - c == '#' || isUnquotedSubscriptCharacter(c); + bool unquoted = c == ':' || c == '$' || c == '-' || c == '/' || c == '@' || + c == '|' || c == '#' || isUnquotedSubscriptCharacter(c); + if (dotAsRegular_) { + return unquoted || c == '.'; + } + return unquoted; } bool Tokenizer::isUnquotedSubscriptCharacter(char c) { diff --git a/velox/type/Tokenizer.h b/velox/type/Tokenizer.h index 56380f0aead2..c4e88794a93a 100644 --- a/velox/type/Tokenizer.h +++ b/velox/type/Tokenizer.h @@ -35,7 +35,7 @@ class Tokenizer { kFailed, }; - explicit Tokenizer(const std::string& path); + explicit Tokenizer(const std::string& path, bool dotAsRegular = false); bool hasNext(); @@ -51,6 +51,8 @@ class Tokenizer { const char UNICODE_CARET = '^'; const std::string path_; + // Whether to treat dot as regular charactor. + bool dotAsRegular_; int index_; State state; bool firstSegment = true; diff --git a/velox/type/Type.h b/velox/type/Type.h index 9cbeb0a1c955..77fc958f5cc1 100644 --- a/velox/type/Type.h +++ b/velox/type/Type.h @@ -45,6 +45,10 @@ namespace facebook::velox { using int128_t = __int128_t; +struct __attribute__((__packed__)) int96_t { + int32_t days; + uint64_t nanos; +}; /// Velox type system supports a small set of SQL-compatible composeable types: /// BOOLEAN, TINYINT, SMALLINT, INTEGER, BIGINT, HUGEINT, REAL, DOUBLE, VARCHAR, diff --git a/velox/type/Variant.cpp b/velox/type/Variant.cpp index 37c9ec39ee29..8ed37f113a27 100644 --- a/velox/type/Variant.cpp +++ b/velox/type/Variant.cpp @@ -34,6 +34,85 @@ const folly::json::serialization_opts& getOpts() { } } // namespace +std::optional isFloatingPointType(const variant& value) { + if (!value.hasValue()) { + return std::nullopt; + } + + if (value.kind() == TypeKind::REAL || value.kind() == TypeKind::DOUBLE) { + return true; + } else { + if (value.kind() == TypeKind::ARRAY) { + auto elements = value.value(); + if (elements.empty()) { + return false; + } else { + for (const auto& element : elements) { + auto result = isFloatingPointType(element); + if (result.has_value()) { + return result; + } + } + return std::nullopt; + } + } else if (value.kind() == TypeKind::MAP) { + auto pairs = value.value(); + if (pairs.empty()) { + return false; + } else { + std::optional floatingKey = std::nullopt; + for (const auto& pair : pairs) { + auto result = isFloatingPointType(pair.first); + if (result.has_value()) { + floatingKey = result; + break; + } + } + + std::optional floatingValue = std::nullopt; + for (const auto& pair : pairs) { + auto result = isFloatingPointType(pair.second); + if (result.has_value()) { + floatingValue = result; + break; + } + } + + if ((floatingKey.has_value() && *floatingKey) || + (floatingValue.has_value() && *floatingValue)) { + return true; + } + if (!floatingKey.has_value() || !floatingValue.has_value()) { + return std::nullopt; + } + return false; + } + } else if (value.kind() == TypeKind::ROW) { + auto children = value.value(); + if (children.empty()) { + return false; + } else { + bool undetermined = false; + for (const auto& child : children) { + auto result = isFloatingPointType(child); + if (result.has_value() && *result) { + return true; + } + if (!result.has_value()) { + undetermined = true; + } + } + if (undetermined) { + return std::nullopt; + } + return false; + } + } else { + return false; + } + } +} + template bool evaluateNullEquality(const variant& a, const variant& b) { if constexpr (nullEqualsNull) { @@ -682,6 +761,57 @@ bool variant::equalsWithEpsilon(const variant& other) const { } if ((kind_ == TypeKind::REAL) or (kind_ == TypeKind::DOUBLE)) { return equalsFloatingPointWithEpsilon(*this, other); + } else if (kind_ == TypeKind::ARRAY) { + auto elements = value(); + auto otherElements = other.value(); + if (elements.size() != otherElements.size()) { + return false; + } + if (elements.empty()) { + return true; + } else { + for (auto i = 0; i < elements.size(); ++i) { + if (!elements[i].equalsWithEpsilon(otherElements[i])) { + return false; + } + } + return true; + } + } else if (kind_ == TypeKind::MAP) { + auto pairs = value(); + auto otherPairs = other.value(); + if (pairs.size() != otherPairs.size()) { + return false; + } + if (pairs.empty()) { + return true; + } else { + auto iter = pairs.begin(); + auto otherIter = otherPairs.begin(); + for (; iter != pairs.end(); ++iter, ++otherIter) { + if (!iter->first.equalsWithEpsilon(otherIter->first) || + !iter->second.equalsWithEpsilon(otherIter->second)) { + return false; + } + } + return true; + } + } else if (kind_ == TypeKind::ROW) { + auto children = value(); + auto otherChildren = other.value(); + if (children.size() != otherChildren.size()) { + return false; + } + if (children.empty()) { + return true; + } else { + for (auto i = 0; i < children.size(); ++i) { + if (!children[i].equalsWithEpsilon(otherChildren[i])) { + return false; + } + } + return true; + } } return VELOX_DYNAMIC_TYPE_DISPATCH_ALL(equals, kind_, *this, other); diff --git a/velox/type/Variant.h b/velox/type/Variant.h index d0d6323dc17a..2b44c5112377 100644 --- a/velox/type/Variant.h +++ b/velox/type/Variant.h @@ -435,6 +435,10 @@ class variant { } void checkIsKind(TypeKind kind) const { + // Integer is compatible for getting the value from a date variant. + if (kind_ == TypeKind::DATE && kind == TypeKind::INTEGER) { + return; + } if (kind_ != kind) { // Error path outlined to encourage inlining of the branch. throwCheckIsKindError(kind); @@ -630,4 +634,8 @@ struct VariantConverter { } }; +// Return true if value is of a floating-point type or a complex type that +// contains a floating-point-typed child. +std::optional isFloatingPointType(const variant& value); + } // namespace facebook::velox diff --git a/velox/type/tests/TimestampConversionTest.cpp b/velox/type/tests/TimestampConversionTest.cpp index ff47c6fda7a1..edad734549d6 100644 --- a/velox/type/tests/TimestampConversionTest.cpp +++ b/velox/type/tests/TimestampConversionTest.cpp @@ -83,7 +83,8 @@ TEST(DateTimeUtilTest, fromDateString) { EXPECT_EQ(-719162, fromDateString(" \t \n 00001-1-1 \n")); // Different separators. - EXPECT_EQ(-719162, fromDateString("1/1/1")); + // Illegal date format for spark. + // EXPECT_EQ(-719162, fromDateString("1/1/1")); EXPECT_EQ(-719162, fromDateString("1 1 1")); EXPECT_EQ(-719162, fromDateString("1\\1\\1")); @@ -94,7 +95,7 @@ TEST(DateTimeUtilTest, fromDateString) { TEST(DateTimeUtilTest, fromDateStrInvalid) { EXPECT_THROW(fromDateString(""), VeloxUserError); EXPECT_THROW(fromDateString(" "), VeloxUserError); - EXPECT_THROW(fromDateString("2000"), VeloxUserError); + EXPECT_EQ(fromDateString("2000"), 10957); // Different separators. EXPECT_THROW(fromDateString("2000/01-01"), VeloxUserError); @@ -102,11 +103,11 @@ TEST(DateTimeUtilTest, fromDateStrInvalid) { // Trailing characters. EXPECT_THROW(fromDateString("2000-01-01 asdf"), VeloxUserError); - EXPECT_THROW(fromDateString("2000-01-01 0"), VeloxUserError); + EXPECT_EQ(fromDateString("2000-01-01 0"), 10957); // Too large of a year. - EXPECT_THROW(fromDateString("1000000"), VeloxUserError); - EXPECT_THROW(fromDateString("-1000000"), VeloxUserError); + EXPECT_EQ(fromDateString("1000000"), 364522972); + EXPECT_EQ(fromDateString("-1000000"), -365962028); } TEST(DateTimeUtilTest, fromTimeString) { diff --git a/velox/vector/BaseVector.cpp b/velox/vector/BaseVector.cpp index 4e9b6de95517..b748310e7f81 100644 --- a/velox/vector/BaseVector.cpp +++ b/velox/vector/BaseVector.cpp @@ -616,7 +616,7 @@ VectorPtr BaseVector::createConstant( variant value, vector_size_t size, velox::memory::MemoryPool* pool) { - VELOX_CHECK_EQ(type->kind(), value.kind()); + VELOX_CHECK(compatibleKind(type->kind(), value.kind())); return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH_ALL( newConstant, value.kind(), type, value, size, pool); } diff --git a/velox/vector/BaseVector.h b/velox/vector/BaseVector.h index 912961cd5245..a2520f0524b9 100644 --- a/velox/vector/BaseVector.h +++ b/velox/vector/BaseVector.h @@ -671,7 +671,11 @@ class BaseVector { // two unknowns but values cannot be assigned into an unknown 'left' // from a not-unknown 'right'. static bool compatibleKind(TypeKind left, TypeKind right) { - return left == right || right == TypeKind::UNKNOWN; + // Vectors of VARCHAR and VARBINARY are compatible with each other. + bool varcharAndBinary = + (left == TypeKind::VARCHAR && right == TypeKind::VARBINARY) || + (left == TypeKind::VARBINARY && right == TypeKind::VARCHAR); + return left == right || right == TypeKind::UNKNOWN || varcharAndBinary; } /// Returns a brief summary of the vector. If 'recursive' is true, includes a diff --git a/velox/vector/ComplexVectorStream.h b/velox/vector/ComplexVectorStream.h new file mode 100644 index 000000000000..14acd4b00bbe --- /dev/null +++ b/velox/vector/ComplexVectorStream.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/vector/ComplexVector.h" + +namespace facebook::velox { + +class RowVectorStream { + public: + RowVectorStream() {} + + virtual ~RowVectorStream() = default; + + virtual bool hasNext() = 0; + + virtual RowVectorPtr next() = 0; +}; + +} // namespace facebook::velox diff --git a/velox/vector/VectorStream.h b/velox/vector/VectorStream.h index 0dd7184641d3..b494c831f56d 100644 --- a/velox/vector/VectorStream.h +++ b/velox/vector/VectorStream.h @@ -39,6 +39,19 @@ class VectorSerializer { const RowVectorPtr& vector, const folly::Range& ranges) = 0; + // Usage + // append(vector, ranges); + // vector_size_t size = serializedSize(); + // OutputStream* stream = allocateBuffer(size); + // flush(); + // + // So we can allocate memory for flush OutputStream size + // The return value is accurate without compress, + // Return the maximum required size with different compress codec + virtual vector_size_t maxSerializedSize() { + VELOX_NYI("{} unsupported", __FUNCTION__); + }; + // Writes the contents to 'stream' in wire format virtual void flush(OutputStream* stream) = 0; }; diff --git a/velox/vector/arrow/Bridge.cpp b/velox/vector/arrow/Bridge.cpp index f9cb1da3ca0c..d488e2e29133 100644 --- a/velox/vector/arrow/Bridge.cpp +++ b/velox/vector/arrow/Bridge.cpp @@ -242,7 +242,7 @@ const char* exportArrowFormatStr( case TypeKind::TIMESTAMP: // TODO: need to figure out how we'll map this since in Velox we currently // store timestamps as two int64s (epoch in sec and nanos). - return "ttn"; // time64 [nanoseconds] + return "tsu:"; // timestamp [microseconds] case TypeKind::DATE: return "tdD"; // date32[days] // Complex/nested types. @@ -363,6 +363,7 @@ void exportValues( out.n_buffers = 2; // Short decimals need to be converted to 128 bit values as they are mapped // to Arrow Decimal128. + // Timestamps need to be converted to micros. if (!rows.changed() && !vec.type()->isShortDecimal()) { holder.setBuffer(1, vec.values()); return; @@ -377,6 +378,32 @@ void exportValues( holder.setBuffer(1, values); } +void exportTimestamps( + const BaseVector& vec, + const Selection& rows, + ArrowArray& out, + memory::MemoryPool* pool, + VeloxToArrowBridgeHolder& holder) { + out.n_buffers = 2; + auto size = vec.type()->cppSizeInBytes(); + auto values = AlignedBuffer::allocate( + checkedMultiply(out.length, size), pool); + const Buffer& buf = *vec.values(); + const auto& tsSrc = buf.as(); + Buffer& outBuffer = *values; + auto dst = outBuffer.asMutable(); + vector_size_t j = 0; // index into dst + rows.apply([&](vector_size_t i) { + int64_t value = 0; + if (!vec.mayHaveNulls() || !vec.isNullAt(i)) { + // The use of toMicros on null causes integer overflow. + value = tsSrc[i].toMicros(); + } + memcpy(dst + (j++) * sizeof(int64_t), &value, sizeof(int64_t)); + }); + holder.setBuffer(1, values); +} + void exportStrings( const FlatVector& vec, const Selection& rows, @@ -430,6 +457,9 @@ void exportFlat( case TypeKind::DOUBLE: exportValues(vec, rows, out, pool, holder); break; + case TypeKind::TIMESTAMP: + exportTimestamps(vec, rows, out, pool, holder); + break; case TypeKind::VARCHAR: case TypeKind::VARBINARY: exportStrings( @@ -727,6 +757,11 @@ void exportToArrow(const VectorPtr& vec, ArrowSchema& arrowSchema) { TypePtr importFromArrow(const ArrowSchema& arrowSchema) { const char* format = arrowSchema.format; VELOX_CHECK_NOT_NULL(format); + std::string formatStr(format); + // TODO: Timezone and unit are not handled. + if (formatStr.rfind("ts", 0) == 0) { + return TIMESTAMP(); + } switch (format[0]) { case 'b': @@ -929,8 +964,7 @@ VectorPtr createStringFlatVector( std::vector stringViewBuffers; if (shouldAcquireStringBuffer) { - stringViewBuffers.emplace_back( - wrapInBufferView(values, offsets[length + 1])); + stringViewBuffers.emplace_back(wrapInBufferView(values, offsets[length])); } return std::make_shared>( @@ -951,6 +985,62 @@ VectorPtr importFromArrowImpl( memory::MemoryPool* pool, bool isViewer); +VectorPtr createDecimalVector( + memory::MemoryPool* pool, + const TypePtr& type, + BufferPtr nulls, + const ArrowSchema& arrowSchema, + const ArrowArray& arrowArray, + WrapInBufferViewFunc wrapInBufferView) { + auto valueBuf = wrapInBufferView( + arrowArray.buffers[1], arrowArray.length * sizeof(int128_t)); + + auto src = valueBuf->as(); + + VectorPtr base = BaseVector::create(type, arrowArray.length, pool); + base->setNulls(nulls); + + auto flatVector = std::dynamic_pointer_cast>(base); + + for (int i = 0; i < arrowArray.length; i++) { + if (!base->isNullAt(i)) { + int128_t result; + memcpy(&result, src + i * sizeof(int128_t), sizeof(int128_t)); + flatVector->set(i, static_cast(result)); + } + } + + return flatVector; +} + +VectorPtr createTimestampVector( + memory::MemoryPool* pool, + const TypePtr& type, + BufferPtr nulls, + const ArrowSchema& arrowSchema, + const ArrowArray& arrowArray, + WrapInBufferViewFunc wrapInBufferView) { + auto valueBuf = wrapInBufferView( + arrowArray.buffers[1], arrowArray.length * sizeof(Timestamp)); + + auto src = valueBuf->as(); + + VectorPtr base = BaseVector::create(type, arrowArray.length, pool); + base->setNulls(nulls); + + auto flatVector = std::dynamic_pointer_cast>(base); + + for (int i = 0; i < arrowArray.length; i++) { + if (!base->isNullAt(i)) { + int64_t result; + memcpy(&result, src + i * sizeof(int64_t), sizeof(int64_t)); + flatVector->set(i, Timestamp::fromMicros(result)); + } + } + + return flatVector; +} + RowVectorPtr createRowVector( memory::MemoryPool* pool, const RowTypePtr& rowType, @@ -1151,6 +1241,15 @@ VectorPtr importFromArrowImpl( return createMapVector( pool, type, nulls, arrowSchema, arrowArray, isViewer, wrapInBufferView); } + if (type->isShortDecimal()) { + return createDecimalVector( + pool, type, nulls, arrowSchema, arrowArray, wrapInBufferView); + } + if (type->kind() == TypeKind::TIMESTAMP) { + return createTimestampVector( + pool, type, nulls, arrowSchema, arrowArray, wrapInBufferView); + } + // Other primitive types. VELOX_CHECK( type->isPrimitiveType(), @@ -1163,7 +1262,7 @@ VectorPtr importFromArrowImpl( auto values = wrapInBufferView( arrowArray.buffers[1], arrowArray.length * type->cppSizeInBytes()); - return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( + return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH_ALL( createFlatVector, type->kind(), pool, diff --git a/velox/vector/arrow/tests/ArrowBridgeArrayTest.cpp b/velox/vector/arrow/tests/ArrowBridgeArrayTest.cpp index da093027471c..3cf3a5e32883 100644 --- a/velox/vector/arrow/tests/ArrowBridgeArrayTest.cpp +++ b/velox/vector/arrow/tests/ArrowBridgeArrayTest.cpp @@ -657,7 +657,7 @@ TEST_F(ArrowBridgeArrayExportTest, unsupported) { // Timestamps. vector = vectorMaker_.flatVectorNullable({}); - EXPECT_THROW(exportToArrow(vector, arrowArray, pool_.get()), VeloxException); + exportToArrow(vector, arrowArray, pool_.get()); // Constant encoding. vector = BaseVector::createConstant(INTEGER(), variant(10), 10, pool_.get()); diff --git a/velox/vector/arrow/tests/ArrowBridgeSchemaTest.cpp b/velox/vector/arrow/tests/ArrowBridgeSchemaTest.cpp index 9b3a1ad3c8b9..d0f4f467933f 100644 --- a/velox/vector/arrow/tests/ArrowBridgeSchemaTest.cpp +++ b/velox/vector/arrow/tests/ArrowBridgeSchemaTest.cpp @@ -120,7 +120,7 @@ TEST_F(ArrowBridgeSchemaExportTest, scalar) { testScalarType(VARCHAR(), "u"); testScalarType(VARBINARY(), "z"); - testScalarType(TIMESTAMP(), "ttn"); + testScalarType(TIMESTAMP(), "tsu:"); testScalarType(DATE(), "tdD"); testScalarType(DECIMAL(10, 4), "d:10,4");