diff --git a/velox/common/base/BitUtil.h b/velox/common/base/BitUtil.h index fbe8600d0163..7b70e763700f 100644 --- a/velox/common/base/BitUtil.h +++ b/velox/common/base/BitUtil.h @@ -693,6 +693,13 @@ inline int32_t countLeadingZeros(uint64_t word) { return __builtin_clzll(word); } +inline int32_t countLeadingZerosUint128(__uint128_t word) { + uint64_t hi = word >> 64; + uint64_t lo = static_cast(word); + return (hi == 0) ? 64 + bits::countLeadingZeros(lo) + : bits::countLeadingZeros(hi); +} + inline uint64_t nextPowerOfTwo(uint64_t size) { if (size == 0) { return 0; diff --git a/velox/expression/CastExpr.cpp b/velox/expression/CastExpr.cpp index 2e20683e6b57..13727a0ce025 100644 --- a/velox/expression/CastExpr.cpp +++ b/velox/expression/CastExpr.cpp @@ -27,6 +27,7 @@ #include "velox/expression/StringWriter.h" #include "velox/external/date/tz.h" #include "velox/functions/lib/RowsTranslationUtil.h" +#include "velox/type/DecimalUtilOp.h" #include "velox/vector/ComplexVector.h" #include "velox/vector/FunctionVector.h" #include "velox/vector/SelectivityVector.h" @@ -202,6 +203,30 @@ void applyDoubleToDecimalCastKernel( } }); } + +template +void applyVarCharToDecimalCastKernel( + const SelectivityVector& rows, + const BaseVector& input, + exec::EvalCtx& context, + const TypePtr& toType, + VectorPtr castResult) { + auto sourceVector = input.as>(); + auto castResultRawBuffer = + castResult->asUnchecked>()->mutableRawValues(); + const auto& toPrecisionScale = getDecimalPrecisionScale(*toType); + context.applyToSelectedNoThrow(rows, [&](vector_size_t row) { + auto rescaledValue = DecimalUtilOp::rescaleVarchar( + sourceVector->valueAt(row), + toPrecisionScale.first, + toPrecisionScale.second); + if (rescaledValue.has_value()) { + castResultRawBuffer[row] = rescaledValue.value(); + } else { + castResult->setNull(row, true); + } + }); +} } // namespace template @@ -608,255 +633,262 @@ VectorPtr CastExpr::applyDecimal( applyDecimalCastKernel( rows, input, context, fromType, toType, castResult); } + } break; + case TypeKind::LONG_DECIMAL: + applyDecimalCastKernel( + rows, input, context, fromType, toType, castResult); break; - case TypeKind::LONG_DECIMAL: - applyDecimalCastKernel( - rows, input, context, fromType, toType, castResult); - break; - case TypeKind::TINYINT: - applyIntToDecimalCastKernel( + case TypeKind::TINYINT: + applyIntToDecimalCastKernel( + rows, input, context, toType, castResult); + break; + case TypeKind::SMALLINT: + applyIntToDecimalCastKernel( + rows, input, context, toType, castResult); + break; + case TypeKind::INTEGER: { + if (toType->kind() == TypeKind::SHORT_DECIMAL) { + applyBigintToDecimalCastKernel( rows, input, context, toType, castResult); - break; - case TypeKind::SMALLINT: - applyIntToDecimalCastKernel( + } else { + applyBigintToDecimalCastKernel( rows, input, context, toType, castResult); - break; - case TypeKind::INTEGER: { - if (toType->kind() == TypeKind::SHORT_DECIMAL) { - applyBigintToDecimalCastKernel( - rows, input, context, toType, castResult); - } else { - applyBigintToDecimalCastKernel( - rows, input, context, toType, castResult); - } - break; - } - case TypeKind::DATE: { - if (toType->kind() == TypeKind::SHORT_DECIMAL) { - applyDateToDecimalCastKernel( - rows, input, context, toType, castResult); - } else { - applyDateToDecimalCastKernel( - rows, input, context, toType, castResult); - } - break; } - case TypeKind::BIGINT: { - if (toType->kind() == TypeKind::SHORT_DECIMAL) { - applyBigintToDecimalCastKernel( - rows, input, context, toType, castResult); - } else { - applyBigintToDecimalCastKernel( - rows, input, context, toType, castResult); - } - break; + break; + } + case TypeKind::DATE: { + if (toType->kind() == TypeKind::SHORT_DECIMAL) { + applyDateToDecimalCastKernel( + rows, input, context, toType, castResult); + } else { + applyDateToDecimalCastKernel( + rows, input, context, toType, castResult); } - case TypeKind::REAL: { - if (toType->kind() == TypeKind::SHORT_DECIMAL) { - applyDoubleToDecimalCastKernel( - rows, input, context, toType, castResult); - } else { - applyDoubleToDecimalCastKernel( - rows, input, context, toType, castResult); - } - break; + break; + } + case TypeKind::BIGINT: { + if (toType->kind() == TypeKind::SHORT_DECIMAL) { + applyBigintToDecimalCastKernel( + rows, input, context, toType, castResult); + } else { + applyBigintToDecimalCastKernel( + rows, input, context, toType, castResult); } - case TypeKind::DOUBLE: { - if (toType->kind() == TypeKind::SHORT_DECIMAL) { - applyDoubleToDecimalCastKernel( - rows, input, context, toType, castResult); - } else { - applyDoubleToDecimalCastKernel( - rows, input, context, toType, castResult); - } - break; + break; + } + case TypeKind::REAL: { + if (toType->kind() == TypeKind::SHORT_DECIMAL) { + applyDoubleToDecimalCastKernel( + rows, input, context, toType, castResult); + } else { + applyDoubleToDecimalCastKernel( + rows, input, context, toType, castResult); } - default: - VELOX_UNSUPPORTED( - "Cast from {} to {} is not supported", - fromType->toString(), - toType->toString()); + break; } - return castResult; - } - - void CastExpr::applyPeeled( - const SelectivityVector& rows, - const BaseVector& input, - exec::EvalCtx& context, - const TypePtr& fromType, - const TypePtr& toType, - VectorPtr& result) { - if (castFromOperator_ || castToOperator_) { - VELOX_CHECK_NE( - fromType, - toType, - "Attempting to cast from {} to itself.", - fromType->toString()); - - if (castToOperator_) { - castToOperator_->castTo(input, context, rows, toType, result); + case TypeKind::DOUBLE: { + if (toType->kind() == TypeKind::SHORT_DECIMAL) { + applyDoubleToDecimalCastKernel( + rows, input, context, toType, castResult); } else { - castFromOperator_->castFrom(input, context, rows, toType, result); + applyDoubleToDecimalCastKernel( + rows, input, context, toType, castResult); } - } else { - switch (toType->kind()) { - case TypeKind::MAP: - result = applyMap( - rows, - input.asUnchecked(), - context, - fromType->asMap(), - toType->asMap()); - break; - case TypeKind::ARRAY: - result = applyArray( - rows, - input.asUnchecked(), - context, - fromType->asArray(), - toType->asArray()); - break; - case TypeKind::ROW: - result = applyRow( - rows, - input.asUnchecked(), - context, - fromType->asRow(), - toType); - break; - case TypeKind::SHORT_DECIMAL: - result = applyDecimal( - rows, input, context, fromType, toType); - break; - case TypeKind::LONG_DECIMAL: - result = applyDecimal( - rows, input, context, fromType, toType); - break; - default: { - // Handle primitive type conversions. - VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( - applyCast, - toType->kind(), - fromType, - toType, - rows, - context, - input, - result); - } + break; + } + case TypeKind::VARCHAR: { + if (toType->kind() == TypeKind::SHORT_DECIMAL) { + applyVarCharToDecimalCastKernel( + rows, input, context, toType, castResult); + } else { + applyVarCharToDecimalCastKernel( + rows, input, context, toType, castResult); } + break; } + default: + VELOX_UNSUPPORTED( + "Cast from {} to {} is not supported", + fromType->toString(), + toType->toString()); } + return castResult; +} - void CastExpr::apply( - const SelectivityVector& rows, - const VectorPtr& input, - exec::EvalCtx& context, - const TypePtr& fromType, - const TypePtr& toType, - VectorPtr& result) { - LocalDecodedVector decoded(context, *input, rows); - auto* rawNulls = decoded->nulls(); - - LocalSelectivityVector nonNullRows(*context.execCtx(), rows.end()); - *nonNullRows = rows; - if (rawNulls) { - nonNullRows->deselectNulls(rawNulls, rows.begin(), rows.end()); - } - - VectorPtr localResult; - if (!nonNullRows->hasSelections()) { - localResult = - BaseVector::createNullConstant(toType, rows.end(), context.pool()); - } else if (decoded->isIdentityMapping()) { - applyPeeled( - *nonNullRows, - *decoded->base(), - context, - fromType, - toType, - localResult); +void CastExpr::applyPeeled( + const SelectivityVector& rows, + const BaseVector& input, + exec::EvalCtx& context, + const TypePtr& fromType, + const TypePtr& toType, + VectorPtr& result) { + if (castFromOperator_ || castToOperator_) { + VELOX_CHECK_NE( + fromType, + toType, + "Attempting to cast from {} to itself.", + fromType->toString()); + + if (castToOperator_) { + castToOperator_->castTo(input, context, rows, toType, result); } else { - ScopedContextSaver saver; - LocalSelectivityVector newRowsHolder(*context.execCtx()); - - LocalDecodedVector localDecoded(context); - std::vector peeledVectors; - auto peeledEncoding = PeeledEncoding::Peel( - {input}, *nonNullRows, localDecoded, true, peeledVectors); - VELOX_CHECK_EQ(peeledVectors.size(), 1); - auto newRows = - peeledEncoding->translateToInnerRows(*nonNullRows, newRowsHolder); - // Save context and set the peel. - context.saveAndReset(saver, *nonNullRows); - context.setPeeledEncoding(peeledEncoding); - applyPeeled( - *newRows, *peeledVectors[0], context, fromType, toType, localResult); - - localResult = context.getPeeledEncoding()->wrap( - toType, context.pool(), localResult, *nonNullRows); - } - context.moveOrCopyResult(localResult, *nonNullRows, result); - context.releaseVector(localResult); - - // If there are nulls in input, add nulls to the result at the same rows. - VELOX_CHECK_NOT_NULL(result); - if (rawNulls) { - Expr::addNulls( - rows, nonNullRows->asRange().bits(), context, toType, result); + castFromOperator_->castFrom(input, context, rows, toType, result); + } + } else { + switch (toType->kind()) { + case TypeKind::MAP: + result = applyMap( + rows, + input.asUnchecked(), + context, + fromType->asMap(), + toType->asMap()); + break; + case TypeKind::ARRAY: + result = applyArray( + rows, + input.asUnchecked(), + context, + fromType->asArray(), + toType->asArray()); + break; + case TypeKind::ROW: + result = applyRow( + rows, + input.asUnchecked(), + context, + fromType->asRow(), + toType); + break; + case TypeKind::SHORT_DECIMAL: + result = applyDecimal( + rows, input, context, fromType, toType); + break; + case TypeKind::LONG_DECIMAL: + result = applyDecimal( + rows, input, context, fromType, toType); + break; + default: { + // Handle primitive type conversions. + VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH( + applyCast, + toType->kind(), + fromType, + toType, + rows, + context, + input, + result); + } } } +} - void CastExpr::evalSpecialForm( - const SelectivityVector& rows, EvalCtx& context, VectorPtr& result) { - VectorPtr input; - inputs_[0]->eval(rows, context, input); - auto fromType = inputs_[0]->type(); - auto toType = std::const_pointer_cast(type_); +void CastExpr::apply( + const SelectivityVector& rows, + const VectorPtr& input, + exec::EvalCtx& context, + const TypePtr& fromType, + const TypePtr& toType, + VectorPtr& result) { + LocalDecodedVector decoded(context, *input, rows); + auto* rawNulls = decoded->nulls(); - apply(rows, input, context, fromType, toType, result); - // Return 'input' back to the vector pool in 'context' so it can be reused. - context.releaseVector(input); + LocalSelectivityVector nonNullRows(*context.execCtx(), rows.end()); + *nonNullRows = rows; + if (rawNulls) { + nonNullRows->deselectNulls(rawNulls, rows.begin(), rows.end()); } - std::string CastExpr::toString(bool recursive) const { - std::stringstream out; - out << "cast("; - if (recursive) { - appendInputs(out); - } else { - out << inputs_[0]->toString(false); - } - out << " as " << type_->toString() << ")"; - return out.str(); + VectorPtr localResult; + if (!nonNullRows->hasSelections()) { + localResult = + BaseVector::createNullConstant(toType, rows.end(), context.pool()); + } else if (decoded->isIdentityMapping()) { + applyPeeled( + *nonNullRows, *decoded->base(), context, fromType, toType, localResult); + } else { + ScopedContextSaver saver; + LocalSelectivityVector newRowsHolder(*context.execCtx()); + + LocalDecodedVector localDecoded(context); + std::vector peeledVectors; + auto peeledEncoding = PeeledEncoding::Peel( + {input}, *nonNullRows, localDecoded, true, peeledVectors); + VELOX_CHECK_EQ(peeledVectors.size(), 1); + auto newRows = + peeledEncoding->translateToInnerRows(*nonNullRows, newRowsHolder); + // Save context and set the peel. + context.saveAndReset(saver, *nonNullRows); + context.setPeeledEncoding(peeledEncoding); + applyPeeled( + *newRows, *peeledVectors[0], context, fromType, toType, localResult); + + localResult = context.getPeeledEncoding()->wrap( + toType, context.pool(), localResult, *nonNullRows); } - - std::string CastExpr::toSql(std::vector * complexConstants) const { - std::stringstream out; - out << "cast("; - appendInputsSql(out, complexConstants); - out << " as "; - toTypeSql(type_, out); - out << ")"; - return out.str(); + context.moveOrCopyResult(localResult, *nonNullRows, result); + context.releaseVector(localResult); + + // If there are nulls in input, add nulls to the result at the same rows. + VELOX_CHECK_NOT_NULL(result); + if (rawNulls) { + Expr::addNulls( + rows, nonNullRows->asRange().bits(), context, toType, result); } +} - TypePtr CastCallToSpecialForm::resolveType( - const std::vector& /* argTypes */) { - VELOX_FAIL("CAST expressions do not support type resolution."); - } +void CastExpr::evalSpecialForm( + const SelectivityVector& rows, + EvalCtx& context, + VectorPtr& result) { + VectorPtr input; + inputs_[0]->eval(rows, context, input); + auto fromType = inputs_[0]->type(); + auto toType = std::const_pointer_cast(type_); + + apply(rows, input, context, fromType, toType, result); + // Return 'input' back to the vector pool in 'context' so it can be reused. + context.releaseVector(input); +} - ExprPtr CastCallToSpecialForm::constructSpecialForm( - const TypePtr& type, - std::vector&& compiledChildren, - bool trackCpuUsage) { - VELOX_CHECK_EQ( - compiledChildren.size(), - 1, - "CAST statements expect exactly 1 argument, received {}", - compiledChildren.size()); - return std::make_shared( - type, std::move(compiledChildren[0]), trackCpuUsage); +std::string CastExpr::toString(bool recursive) const { + std::stringstream out; + out << "cast("; + if (recursive) { + appendInputs(out); + } else { + out << inputs_[0]->toString(false); } + out << " as " << type_->toString() << ")"; + return out.str(); +} + +std::string CastExpr::toSql(std::vector* complexConstants) const { + std::stringstream out; + out << "cast("; + appendInputsSql(out, complexConstants); + out << " as "; + toTypeSql(type_, out); + out << ")"; + return out.str(); +} + +TypePtr CastCallToSpecialForm::resolveType( + const std::vector& /* argTypes */) { + VELOX_FAIL("CAST expressions do not support type resolution."); +} + +ExprPtr CastCallToSpecialForm::constructSpecialForm( + const TypePtr& type, + std::vector&& compiledChildren, + bool trackCpuUsage) { + VELOX_CHECK_EQ( + compiledChildren.size(), + 1, + "CAST statements expect exactly 1 argument, received {}", + compiledChildren.size()); + return std::make_shared( + type, std::move(compiledChildren[0]), trackCpuUsage); +} } // namespace facebook::velox::exec diff --git a/velox/expression/tests/CastExprTest.cpp b/velox/expression/tests/CastExprTest.cpp index eebd5864cad6..2a8fbd4a6aa0 100644 --- a/velox/expression/tests/CastExprTest.cpp +++ b/velox/expression/tests/CastExprTest.cpp @@ -930,6 +930,26 @@ TEST_F(CastExprTest, integerToDecimal) { testIntToDecimalCasts(); } +TEST_F(CastExprTest, varcharToDecimal) { + // varchar to short decimal +// auto input = makeFlatVector({"-3", "177"}); +// testComplexCast( +// "c0", input, makeShortDecimalFlatVector({-300, 17700}, DECIMAL(6, 2))); + +// // varchar to long decimal +// auto input2 = makeFlatVector( +// {"-300000001234567891234.5", "1771234.5678912345678"}); +// testComplexCast( +// "c0", input2, makeLongDecimalFlatVector({-300, 17700}, DECIMAL(32, 7))); + + auto input3 = makeFlatVector({"9999999999.99", "9999999999.99"}); + testComplexCast( + "c0", input3, makeLongDecimalFlatVector( + {-30'000'000'000, + -20'000'000'000}, + DECIMAL(12, 2))); +} + TEST_F(CastExprTest, castInTry) { // Test try(cast(array(varchar) as array(bigint))) whose input vector is // wrapped in dictinary encoding. The row of ["2a"] should trigger an error diff --git a/velox/functions/prestosql/DecimalArithmetic.cpp b/velox/functions/prestosql/DecimalArithmetic.cpp index d65be3c875a7..b0da476ec60a 100644 --- a/velox/functions/prestosql/DecimalArithmetic.cpp +++ b/velox/functions/prestosql/DecimalArithmetic.cpp @@ -33,8 +33,22 @@ class DecimalBaseFunction : public exec::VectorFunction { DecimalBaseFunction( uint8_t aRescale, uint8_t bRescale, + uint8_t aPrecision, + uint8_t aScale, + uint8_t bPrecision, + uint8_t bScale, + uint8_t rPrecision, + uint8_t rScale, const TypePtr& resultType) - : aRescale_(aRescale), bRescale_(bRescale), resultType_(resultType) {} + : aRescale_(aRescale), + bRescale_(bRescale), + aPrecision_(aPrecision), + aScale_(aScale), + bPrecision_(bPrecision), + bScale_(bScale), + rPrecision_(rPrecision), + rScale_(rScale), + resultType_(resultType) {} void apply( const SelectivityVector& rows, @@ -49,8 +63,23 @@ class DecimalBaseFunction : public exec::VectorFunction { auto flatValues = args[1]->asUnchecked>(); auto rawValues = flatValues->mutableRawValues(); context.applyToSelectedNoThrow(rows, [&](auto row) { + bool overflow = false; Operation::template apply( - rawResults[row], constant, rawValues[row], aRescale_, bRescale_); + rawResults[row], + constant, + rawValues[row], + aRescale_, + bRescale_, + aPrecision_, + aScale_, + bPrecision_, + bScale_, + rPrecision_, + rScale_, + &overflow); + if (overflow) { + result->setNull(row, true); + } }); } else if (args[0]->isFlatEncoding() && args[1]->isConstantEncoding()) { // Fast path for (flat, const). @@ -58,8 +87,23 @@ class DecimalBaseFunction : public exec::VectorFunction { auto constant = args[1]->asUnchecked>()->valueAt(0); auto rawValues = flatValues->mutableRawValues(); context.applyToSelectedNoThrow(rows, [&](auto row) { + bool overflow = false; Operation::template apply( - rawResults[row], rawValues[row], constant, aRescale_, bRescale_); + rawResults[row], + rawValues[row], + constant, + aRescale_, + bRescale_, + aPrecision_, + aScale_, + bPrecision_, + bScale_, + rPrecision_, + rScale_, + &overflow); + if (overflow) { + result->setNull(row, true); + } }); } else if (args[0]->isFlatEncoding() && args[1]->isFlatEncoding()) { // Fast path for (flat, flat). @@ -67,9 +111,25 @@ class DecimalBaseFunction : public exec::VectorFunction { auto rawA = flatA->mutableRawValues(); auto flatB = args[1]->asUnchecked>(); auto rawB = flatB->mutableRawValues(); + context.applyToSelectedNoThrow(rows, [&](auto row) { + bool overflow = false; Operation::template apply( - rawResults[row], rawA[row], rawB[row], aRescale_, bRescale_); + rawResults[row], + rawA[row], + rawB[row], + aRescale_, + bRescale_, + aPrecision_, + aScale_, + bPrecision_, + bScale_, + rPrecision_, + rScale_, + &overflow); + if (overflow) { + result->setNull(row, true); + } }); } else { // Fast path if one or more arguments are encoded. @@ -77,12 +137,23 @@ class DecimalBaseFunction : public exec::VectorFunction { auto a = decodedArgs.at(0); auto b = decodedArgs.at(1); context.applyToSelectedNoThrow(rows, [&](auto row) { + bool overflow = false; Operation::template apply( rawResults[row], a->valueAt(row), b->valueAt(row), aRescale_, - bRescale_); + bRescale_, + aPrecision_, + aScale_, + bPrecision_, + bScale_, + rPrecision_, + rScale_, + &overflow); + if (overflow) { + result->setNull(row, true); + } }); } } @@ -102,6 +173,12 @@ class DecimalBaseFunction : public exec::VectorFunction { const uint8_t aRescale_; const uint8_t bRescale_; + const uint8_t aPrecision_; + const uint8_t aScale_; + const uint8_t bPrecision_; + const uint8_t bScale_; + const uint8_t rPrecision_; + const uint8_t rScale_; const TypePtr resultType_; }; @@ -160,8 +237,19 @@ class DecimalUnaryBaseFunction : public exec::VectorFunction { class Addition { public: template - inline static void - apply(R& r, const A& a, const B& b, uint8_t aRescale, uint8_t bRescale) + inline static void apply( + R& r, + const A& a, + const B& b, + uint8_t aRescale, + uint8_t bRescale, + uint8_t /* aPrecision */, + uint8_t /* aScale */, + uint8_t /* bPrecision */, + uint8_t /* bScale */, + uint8_t /* rPrecision */, + uint8_t /* rScale */, + bool* overflow) #if defined(__has_feature) #if __has_feature(__address_sanitizer__) __attribute__((__no_sanitize__("signed-integer-overflow"))) @@ -181,7 +269,10 @@ class Addition { VELOX_ARITHMETIC_ERROR( "Decimal overflow: {} + {}", a.unscaledValue(), b.unscaledValue()); } - r = checkedPlus(R(aRescaled), R(bRescaled)); + auto res = R(aRescaled).plus(R(bRescaled), overflow); + if (!*overflow) { + r = res; + } } inline static uint8_t @@ -201,13 +292,38 @@ class Addition { std::max(aScale, bScale) + 1), std::max(aScale, bScale)}; } + + inline static std::pair adjustPrecisionScale( + const uint8_t rPrecision, + const uint8_t rScale) { + if (rPrecision <= 38) { + return {rPrecision, rScale}; + } else if (rScale < 0) { + return {38, rScale}; + } else { + int32_t minScale = std::min(static_cast(rScale), 6); + int32_t delta = rPrecision - 38; + return {38, std::max(rScale - delta, minScale)}; + } + } }; class Subtraction { public: template - inline static void - apply(R& r, const A& a, const B& b, uint8_t aRescale, uint8_t bRescale) + inline static void apply( + R& r, + const A& a, + const B& b, + uint8_t aRescale, + uint8_t bRescale, + uint8_t /* aPrecision */, + uint8_t /* aScale */, + uint8_t /* bPrecision */, + uint8_t /* bScale */, + uint8_t /* rPrecision */, + uint8_t /* rScale */, + bool* overflow) #if defined(__has_feature) #if __has_feature(__address_sanitizer__) __attribute__((__no_sanitize__("signed-integer-overflow"))) @@ -224,10 +340,13 @@ class Subtraction { b.unscaledValue(), DecimalUtil::kPowersOfTen[bRescale], &bRescaled)) { - VELOX_ARITHMETIC_ERROR( - "Decimal overflow: {} - {}", a.unscaledValue(), b.unscaledValue()); + *overflow = true; + return; + } + auto res = R(aRescaled).minus(R(bRescaled), overflow); + if (!*overflow) { + r = res; } - r = checkedMinus(R(aRescaled), R(bRescaled)); } inline static uint8_t @@ -248,11 +367,83 @@ class Subtraction { class Multiply { public: template - inline static void - apply(R& r, const A& a, const B& b, uint8_t aRescale, uint8_t bRescale) { - r = checkedMultiply( - checkedMultiply(R(a), R(b)), - R(DecimalUtil::kPowersOfTen[aRescale + bRescale])); + inline static void apply( + R& r, + const A& a, + const B& b, + uint8_t aRescale, + uint8_t bRescale, + uint8_t aPrecision, + uint8_t aScale, + uint8_t bPrecision, + uint8_t bScale, + uint8_t rPrecision, + uint8_t rScale, + bool* overflow) { + // derive from Arrow + if (rPrecision < 38) { + auto res = checkedMultiply( + R(a).multiply(R(b), overflow), + R(DecimalUtil::kPowersOfTen[aRescale + bRescale])); + if (!*overflow) { + r = res; + } + } else if (a.unscaledValue() == 0 && b.unscaledValue() == 0) { + // Handle this separately to avoid divide-by-zero errors. + r = R(0); + } else { + auto deltaScale = aScale + bScale - rScale; + if (deltaScale == 0) { + // No scale down + auto res = R(a).multiply(R(b), overflow); + if (!*overflow) { + r = res; + } + } else { + // scale down + // It's possible that the intermediate value does not fit in 128-bits, + // but the final value will (after scaling down). + int32_t total_leading_zeros = + a.countLeadingZeros() + b.countLeadingZeros(); + // This check is quick, but conservative. In some cases it will + // indicate that converting to 256 bits is necessary, when it's not + // actually the case. + if (UNLIKELY(total_leading_zeros <= 128)) { + // needs_int256 + int256_t aLarge = a.unscaledValue(); + int256_t blarge = b.unscaledValue(); + int256_t reslarge = aLarge * blarge; + reslarge = ReduceScaleBy(reslarge, deltaScale); + auto res = R::convert(reslarge, overflow); + if (!*overflow) { + r = res; + } + } else { + if (LIKELY(deltaScale <= 38)) { + // The largest value that result can have here is (2^64 - 1) * (2^63 + // - 1), which is greater than BasicDecimal128::kMaxValue. + auto res = R(a).multiply(R(b), overflow); + VELOX_DCHECK(!*overflow); + // Since deltaScale is greater than zero, result can now be at most + // ((2^64 - 1) * (2^63 - 1)) / 10, which is less than + // BasicDecimal128::kMaxValue, so there cannot be any overflow. + r = res / R(DecimalUtil::kPowersOfTen[deltaScale]); + } else { + // We are multiplying decimal(38, 38) by decimal(38, 38). The result + // should be a + // decimal(38, 37), so delta scale = 38 + 38 - 37 = 39. Since we are + // not in the 256 bit intermediate value case and we are scaling + // down by 39, then we are guaranteed that the result is 0 (even if + // we try to round). The largest possible intermediate result is 38 + // "9"s. If we scale down by 39, the leftmost 9 is now two digits to + // the right of the rightmost "visible" one. The reason why we have + // to handle this case separately is because a scale multiplier with + // a deltaScale 39 does not fit into 128 bit. + r = R(0); + } + } + } + } } inline static uint8_t @@ -265,16 +456,49 @@ class Multiply { const uint8_t aScale, const uint8_t bPrecision, const uint8_t bScale) { - return {std::min(38, aPrecision + bPrecision), aScale + bScale}; + return Addition::adjustPrecisionScale( + aPrecision + bPrecision + 1, aScale + bScale); + } + + private: + // derive from Arrow + inline static int256_t ReduceScaleBy(int256_t in, int32_t reduceBy) { + if (reduceBy == 0) { + // nothing to do. + return in; + } + + int256_t divisor = DecimalUtil::kPowersOfTen[reduceBy]; + DCHECK_GT(divisor, 0); + DCHECK_EQ(divisor % 2, 0); // multiple of 10. + auto result = in / divisor; + auto remainder = in % divisor; + // round up (same as BasicDecimal128::ReduceScaleBy) + if (abs(remainder) >= (divisor >> 1)) { + result += (in > 0 ? 1 : -1); + } + return result; } }; class Divide { public: template - inline static void - apply(R& r, const A& a, const B& b, uint8_t aRescale, uint8_t /*bRescale*/) { - DecimalUtilOp::divideWithRoundUp(r, a, b, false, aRescale, 0); + inline static void apply( + R& r, + const A& a, + const B& b, + uint8_t aRescale, + uint8_t /*bRescale*/, + uint8_t /* aPrecision */, + uint8_t /* aScale */, + uint8_t /* bPrecision */, + uint8_t /* bScale */, + uint8_t /* rPrecision */, + uint8_t /* rScale */, + bool* overflow) { + DecimalUtilOp::divideWithRoundUp( + r, a, b, false, aRescale, 0, overflow); } inline static uint8_t @@ -289,13 +513,7 @@ class Divide { const uint8_t bScale) { auto scale = std::max(6, aScale + bPrecision + 1); auto precision = aPrecision - aScale + bScale + scale; - if (precision > 38) { - int32_t min_scale = std::min(scale, 6); - int32_t delta = precision - 38; - precision = 38; - scale = std::max(scale - delta, min_scale); - } - return {precision, scale}; + return Addition::adjustPrecisionScale(precision, scale); } }; @@ -373,18 +591,19 @@ class Negate { std::vector> decimalMultiplySignature() { - return { - exec::FunctionSignatureBuilder() - .integerVariable("a_precision") - .integerVariable("a_scale") - .integerVariable("b_precision") - .integerVariable("b_scale") - .integerVariable("r_precision", "min(38, a_precision + b_precision)") - .integerVariable("r_scale", "a_scale + b_scale") - .returnType("DECIMAL(r_precision, r_scale)") - .argumentType("DECIMAL(a_precision, a_scale)") - .argumentType("DECIMAL(b_precision, b_scale)") - .build()}; + return {exec::FunctionSignatureBuilder() + .integerVariable("a_precision") + .integerVariable("a_scale") + .integerVariable("b_precision") + .integerVariable("b_scale") + .integerVariable( + "r_precision", "min(38, a_precision + b_precision + 1)") + .integerVariable( + "r_scale", "a_scale") // not same with the result type + .returnType("DECIMAL(r_precision, r_scale)") + .argumentType("DECIMAL(a_precision, a_scale)") + .argumentType("DECIMAL(b_precision, b_scale)") + .build()}; } std::vector> @@ -420,9 +639,10 @@ std::vector> decimalDivideSignature() { "min(37, max(6, a_scale + b_precision + 1))") // if precision is // more than 38, // scale has new - // value, this check - // constrait is not - // same with result + // value, this + // check constrait + // is not same + // with result // type .returnType("DECIMAL(r_precision, r_scale)") .argumentType("DECIMAL(a_precision, a_scale)") @@ -502,38 +722,85 @@ std::shared_ptr createDecimalFunction( UnscaledLongDecimal /*result*/, UnscaledShortDecimal, UnscaledShortDecimal, - Operation>>(aRescale, bRescale, LONG_DECIMAL(rPrecision, rScale)); + Operation>>( + aRescale, + bRescale, + aPrecision, + aScale, + bPrecision, + bScale, + rPrecision, + rScale, + LONG_DECIMAL(rPrecision, rScale)); } else { // Arguments are short decimals and result is a short decimal. return std::make_shared>(aRescale, bRescale, SHORT_DECIMAL(rPrecision, rScale)); + Operation>>( + aRescale, + bRescale, + aPrecision, + aScale, + bPrecision, + bScale, + rPrecision, + rScale, + SHORT_DECIMAL(rPrecision, rScale)); } } else { - // LHS is short decimal and rhs is a long decimal, result is long decimal. + // LHS is short decimal and rhs is a long decimal, result is long + // decimal. return std::make_shared>(aRescale, bRescale, LONG_DECIMAL(rPrecision, rScale)); + Operation>>( + aRescale, + bRescale, + aPrecision, + aScale, + bPrecision, + bScale, + rPrecision, + rScale, + LONG_DECIMAL(rPrecision, rScale)); } } else { if (bType->kind() == TypeKind::SHORT_DECIMAL) { - // LHS is long decimal and rhs is short decimal, result is a long decimal. + // LHS is long decimal and rhs is short decimal, result is a long + // decimal. return std::make_shared>(aRescale, bRescale, LONG_DECIMAL(rPrecision, rScale)); + Operation>>( + aRescale, + bRescale, + aPrecision, + aScale, + bPrecision, + bScale, + rPrecision, + rScale, + LONG_DECIMAL(rPrecision, rScale)); } else { // Arguments and result are all long decimals. return std::make_shared>(aRescale, bRescale, LONG_DECIMAL(rPrecision, rScale)); + Operation>>( + aRescale, + bRescale, + aPrecision, + aScale, + bPrecision, + bScale, + rPrecision, + rScale, + LONG_DECIMAL(rPrecision, rScale)); } } VELOX_UNSUPPORTED(); diff --git a/velox/functions/prestosql/tests/DecimalArithmeticTest.cpp b/velox/functions/prestosql/tests/DecimalArithmeticTest.cpp index eaf4268d9332..c61b53fd9f0f 100644 --- a/velox/functions/prestosql/tests/DecimalArithmeticTest.cpp +++ b/velox/functions/prestosql/tests/DecimalArithmeticTest.cpp @@ -142,12 +142,6 @@ TEST_F(DecimalArithmeticTest, add) { "Decimal overflow: 1 + 99999999999999999999999999999999999999"); } -TEST_F(DecimalArithmeticTest, int128Abs) { - int128_t va = UnscaledLongDecimal::min().unscaledValue(); - int128_t absVal = std::abs(va); -; -} - TEST_F(DecimalArithmeticTest, subtract) { auto shortFlatA = makeShortDecimalFlatVector({1000, 2000}, DECIMAL(18, 3)); // Subtract short and short, returning long. @@ -230,6 +224,36 @@ TEST_F(DecimalArithmeticTest, subtract) { "Decimal overflow: 1 - -99999999999999999999999999999999999999"); } +TEST_F(DecimalArithmeticTest, sparkMultiply) { + // auto shortFlat = makeShortDecimalFlatVector({1000, 2000}, DECIMAL(17, + // 3)); + // // Multiply short and short, returning long. + // testDecimalExpr( + // makeLongDecimalFlatVector({1000000, 4000000}, DECIMAL(35, 6)), + // "multiply(c0, c1)", + // {shortFlat, shortFlat}); + + // auto longFlat = makeLongDecimalFlatVector({1000, 2000}, DECIMAL(21, 3)); + // auto longFlat1 = makeLongDecimalFlatVector({1000, 2000}, DECIMAL(21, 2)); + // // Multiply short and short, returning long. + // testDecimalExpr( + // makeLongDecimalFlatVector({1000000, 4000000}, DECIMAL(38, 5)), + // "multiply(c0, c1)", + // {longFlat, longFlat1}); + + // testDecimalExpr( + // makeLongDecimalFlatVector({1000, 4000}, DECIMAL(38, 7)), + // "multiply(c0, c1)", + // {makeLongDecimalFlatVector({1000, 2000}, DECIMAL(20, 5)), + // makeLongDecimalFlatVector({1000, 2000}, DECIMAL(20, 5))}); + + testDecimalExpr( + makeLongDecimalFlatVector({1000}, DECIMAL(38, 7)), + "multiply(c0, c1)", + {makeShortDecimalFlatVector({1}, DECIMAL(10, 0)), + makeLongDecimalFlatVector({1123210000000000000}, DECIMAL(38, 18))}); +} + TEST_F(DecimalArithmeticTest, multiply) { auto shortFlat = makeShortDecimalFlatVector({1000, 2000}, DECIMAL(17, 3)); // Multiply short and short, returning long. diff --git a/velox/type/DecimalUtil.h b/velox/type/DecimalUtil.h index 644f81bbc8b6..1a08516e3702 100644 --- a/velox/type/DecimalUtil.h +++ b/velox/type/DecimalUtil.h @@ -46,7 +46,8 @@ class DecimalUtil { const int fromScale, const int toPrecision, const int toScale, - bool nullOnOverflow = false) { + bool nullOnOverflow = false, + bool roundUp = true) { int128_t rescaledValue = inputValue.unscaledValue(); auto scaleDifference = toScale - fromScale; bool isOverflow = false; @@ -60,9 +61,10 @@ class DecimalUtil { const auto scalingFactor = DecimalUtil::kPowersOfTen[scaleDifference]; rescaledValue /= scalingFactor; int128_t remainder = inputValue.unscaledValue() % scalingFactor; - if (inputValue.unscaledValue() >= 0 && remainder >= scalingFactor / 2) { + if (roundUp && inputValue.unscaledValue() >= 0 && + remainder >= scalingFactor / 2) { ++rescaledValue; - } else if (remainder <= -scalingFactor / 2) { + } else if (roundUp && remainder <= -scalingFactor / 2) { --rescaledValue; } } @@ -98,6 +100,7 @@ class DecimalUtil { // Multiply decimal with the scale auto unscaled = inputValue * DecimalUtil::kPowersOfTen[toScale]; + bool isOverflow = std::isnan(unscaled); unscaled = std::round(unscaled); @@ -115,7 +118,7 @@ class DecimalUtil { if (rescaledValue < -DecimalUtil::kPowersOfTen[toPrecision] || rescaledValue > DecimalUtil::kPowersOfTen[toPrecision] || isOverflow) { VELOX_USER_FAIL( - "Cannot cast BIGINT '{}' to DECIMAL({},{})", + "Cannot cast DOUBLE '{}' to DECIMAL({},{})", inputValue, toPrecision, toScale); @@ -452,10 +455,10 @@ class DecimalUtil { } template - inline static int numDigits(T number) - { + inline static int numDigits(T number) { int digits = 0; - if (number < 0) digits = 1; // remove this line if '-' counts as a digit + if (number < 0) + digits = 1; // remove this line if '-' counts as a digit while (number) { number /= 10; digits++; diff --git a/velox/type/DecimalUtilOp.h b/velox/type/DecimalUtilOp.h index bdc6d7d41dc1..7b6fa09db1f5 100644 --- a/velox/type/DecimalUtilOp.h +++ b/velox/type/DecimalUtilOp.h @@ -24,11 +24,7 @@ #include "velox/type/UnscaledLongDecimal.h" #include "velox/type/UnscaledShortDecimal.h" -#include - namespace facebook::velox { -using boost::multiprecision::int256_t; -using uint128_t = __uint128_t; class DecimalUtilOp { public: @@ -57,60 +53,12 @@ class DecimalUtilOp { if constexpr (std::is_same_v) { num_occupied = 64 - bits::countLeadingZeros(valueAbs); } else { - uint64_t hi = valueAbs >> 64; - uint64_t lo = static_cast(valueAbs); - num_occupied = (hi == 0) ? 64 - bits::countLeadingZeros(lo) - : 64 - bits::countLeadingZeros(hi); + num_occupied = 128 - num.countLeadingZeros(); } return num_occupied + maxBitsRequiredIncreaseAfterScaling(aRescale); } - inline static int128_t ConvertToInt128(int256_t in) { - int128_t result; - int128_t INT128_MAX = int128_t(int128_t(-1L)) >> 1; - constexpr int256_t UINT128_MASK = std::numeric_limits::max(); - - int256_t in_abs = abs(in); - bool is_negative = in < 0; - - uint128_t unsignResult = (in_abs & UINT128_MASK).convert_to(); - in_abs >>= 128; - - if (in_abs > 0) { - // we've shifted in by 128-bit, so nothing should be left. - VELOX_FAIL("in_abs overflow"); - } else if (unsignResult > INT128_MAX) { - // the high-bit must not be set (signed 128-bit). - VELOX_FAIL("in_abs > int128 max"); - } else { - result = static_cast(unsignResult); - } - return is_negative ? -result : result; - } - - inline static int64_t ConvertToInt64(int256_t in) { - int64_t result; - constexpr int256_t UINT64_MASK = std::numeric_limits::max(); - - int256_t in_abs = abs(in); - bool is_negative = in < 0; - - uint128_t unsignResult = (in_abs & UINT64_MASK).convert_to(); - in_abs >>= 64; - - if (in_abs > 0) { - // we've shifted in by 128-bit, so nothing should be left. - VELOX_FAIL("in_abs overflow"); - } else if (unsignResult > INT64_MAX) { - // the high-bit must not be set (signed 128-bit). - VELOX_FAIL("in_abs > int64 max"); - } else { - result = static_cast(unsignResult); - } - return is_negative ? -result : result; - } - template inline static R divideWithRoundUp( R& r, @@ -118,8 +66,12 @@ class DecimalUtilOp { const B& b, bool noRoundUp, uint8_t aRescale, - uint8_t /*bRescale*/) { - VELOX_CHECK_NE(b, 0, "Division by zero"); + uint8_t /*bRescale*/, + bool* overflow) { + if (b.unscaledValue() == 0) { + *overflow = true; + return R(-1); + } int resultSign = 1; R unsignedDividendRescaled(a); int aSign = 1; @@ -136,10 +88,12 @@ class DecimalUtilOp { bSign = -1; } auto bitsRequiredAfterScaling = maxBitsRequiredAfterScaling(a, aRescale); - if (bitsRequiredAfterScaling <= 127) { - unsignedDividendRescaled = checkedMultiply( - unsignedDividendRescaled, R(DecimalUtil::kPowersOfTen[aRescale])); + unsignedDividendRescaled = unsignedDividendRescaled.multiply( + R(DecimalUtil::kPowersOfTen[aRescale]), overflow); + if (*overflow) { + return R(-1); + } R quotient = unsignedDividendRescaled / unsignedDivisor; R remainder = unsignedDividendRescaled % unsignedDivisor; if (!noRoundUp && remainder * 2 >= unsignedDivisor) { @@ -152,10 +106,8 @@ class DecimalUtilOp { std::is_same_v) { // Derives from Arrow BasicDecimal128 Divide if (aRescale > 38 && bitsRequiredAfterScaling > 255) { - VELOX_FAIL( - "Decimal overflow because rescale {} > 38 and bitsRequiredAfterScaling {} > 255", - aRescale, - bitsRequiredAfterScaling); + *overflow = true; + return R(-1); } int256_t aLarge = a.unscaledValue(); int256_t x_large_scaled_up = aLarge * DecimalUtil::kPowersOfTen[aRescale]; @@ -171,24 +123,135 @@ class DecimalUtilOp { // x -ve and y -ve, result is +ve => (-1 ^ -1) + 1 = 0 + 1 = +1 result_large += (aSign ^ bSign) + 1; } - if constexpr (std::is_same_v) { - int64_t result = ConvertToInt128(result_large); - if (!R::valueInRange(result)) { - VELOX_FAIL("overflow long decimal"); - } - r = UnscaledShortDecimal(result); - return UnscaledShortDecimal(ConvertToInt64(remainder_large)); + + auto result = R::convert(result_large, overflow); + auto remainder = R::convert(remainder_large, overflow); + if (!R::valueInRange(result.unscaledValue())) { + *overflow = true; } else { - int128_t result = ConvertToInt128(result_large); - if (!R::valueInRange(result)) { - VELOX_FAIL("overflow long decimal"); - } - r = UnscaledLongDecimal(result); - return UnscaledLongDecimal(ConvertToInt128(remainder_large)); + r = result; } + return remainder; } else { VELOX_FAIL("Should not reach here in DecimalUtilOp.h"); } } + + // return unscaled value and scale + inline static std::pair splitVarChar( + const StringView& value) { + std::string s = value.str(); + size_t pos = s.find('.'); + if (pos == std::string::npos) { + return {s.substr(0, pos), 0}; + } else { + return { + s.substr(0, pos) + s.substr(pos + 1, s.length()), s.length() - pos - 1}; + } + } + + static int128_t convertStringToInt128( + const std::string& value, + bool& nullOutput) { + // Handling integer target cases + const char* v = value.c_str(); + nullOutput = true; + bool negative = false; + int128_t result = 0; + int index = 0; + int len = value.size(); + if (len == 0) { + return -1; + } + // Setting negative flag + if (v[0] == '-') { + if (len == 1) { + return -1; + } + negative = true; + index = 1; + } + if (negative) { + for (; index < len; index++) { + if (!std::isdigit(v[index])) { + return -1; + } + result = result * 10 - (v[index] - '0'); + // Overflow check + if (result > 0) { + return -1; + } + } + } else { + for (; index < len; index++) { + if (!std::isdigit(v[index])) { + return -1; + } + result = result * 10 + (v[index] - '0'); + // Overflow check + if (result < 0) { + return -1; + } + } + } + // Final result + nullOutput = false; + return result; + } + + template + inline static std::optional rescaleVarchar( + const StringView inputValue, + const int toPrecision, + const int toScale) { + static_assert( + std::is_same_v || + std::is_same_v); + auto [unscaledStr, fromScale] = splitVarChar(inputValue); + uint8_t fromPrecision = unscaledStr.size(); + VELOX_CHECK_LE( + fromPrecision, DecimalType::kMaxPrecision); + if (fromPrecision <= 18) { + int64_t fromUnscaledValue = folly::to(unscaledStr); + return DecimalUtil::rescaleWithRoundUp( + UnscaledShortDecimal(fromUnscaledValue), + fromPrecision, + fromScale, + toPrecision, + toScale, + false, + false); + } else { + bool nullOutput = true; + int128_t decimalValue = convertStringToInt128(unscaledStr, nullOutput); + if (nullOutput) { + VELOX_USER_FAIL( + "Cannot cast StringView '{}' to DECIMAL({},{})", + inputValue, + toPrecision, + toScale); + } + return DecimalUtil::rescaleWithRoundUp( + UnscaledLongDecimal(decimalValue), + fromPrecision, + fromScale, + toPrecision, + toScale, + false, + false); + } + } + + template + inline static std::optional rescaleDouble( + const TInput inputValue, + const int toPrecision, + const int toScale) { + static_assert( + std::is_same_v || + std::is_same_v); + return rescaleVarchar( + velox::to(inputValue), toPrecision, toScale); + } }; } // namespace facebook::velox diff --git a/velox/type/UnscaledLongDecimal.h b/velox/type/UnscaledLongDecimal.h index b3966836e972..5a3d4bff361e 100644 --- a/velox/type/UnscaledLongDecimal.h +++ b/velox/type/UnscaledLongDecimal.h @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include #include #include @@ -26,6 +27,8 @@ namespace facebook::velox { using int128_t = __int128_t; +using boost::multiprecision::int256_t; +using uint128_t = __uint128_t; constexpr int128_t buildInt128(uint64_t hi, uint64_t lo) { // GCC does not allow left shift negative value. @@ -173,6 +176,66 @@ struct UnscaledLongDecimal { memcpy(&ans.unscaledValue_, serializedData, sizeof(int128_t)); return ans; } + + UnscaledLongDecimal plus(const UnscaledLongDecimal& a, bool* overflow) { + int128_t result; + *overflow = + __builtin_add_overflow(unscaledValue_, a.unscaledValue(), &result); + if (UNLIKELY(*overflow || !UnscaledLongDecimal::valueInRange(result))) { + *overflow = true; + return UnscaledLongDecimal(-1); + } + return UnscaledLongDecimal(result); + } + + UnscaledLongDecimal minus(const UnscaledLongDecimal& a, bool* overflow) { + int128_t result; + *overflow = + __builtin_sub_overflow(unscaledValue_, a.unscaledValue(), &result); + if (UNLIKELY(*overflow || !UnscaledLongDecimal::valueInRange(result))) { + *overflow = true; + return UnscaledLongDecimal(-1); + } + return UnscaledLongDecimal(result); + } + + UnscaledLongDecimal multiply(const UnscaledLongDecimal& a, bool* overflow) { + int128_t result; + *overflow = + __builtin_mul_overflow(unscaledValue_, a.unscaledValue(), &result); + if (UNLIKELY(*overflow || !UnscaledLongDecimal::valueInRange(result))) { + *overflow = true; + return UnscaledLongDecimal(-1); + } + return UnscaledLongDecimal(result); + } + + int32_t countLeadingZeros() const { + auto abs = std::abs(unscaledValue_); + return bits::countLeadingZerosUint128(abs); + } + + static inline UnscaledLongDecimal convert(int256_t in, bool* overflow) { + int128_t result; + int128_t INT128_MAX = int128_t(int128_t(-1L)) >> 1; + constexpr int256_t UINT128_MASK = std::numeric_limits::max(); + + int256_t inAbs = abs(in); + bool isNegative = in < 0; + + uint128_t unsignResult = (inAbs & UINT128_MASK).convert_to(); + inAbs >>= 128; + + if (inAbs > 0) { + // we've shifted in by 128-bit, so nothing should be left. + *overflow = true; + } else if (unsignResult > INT128_MAX) { + *overflow = true; + } else { + result = static_cast(unsignResult); + } + return UnscaledLongDecimal(isNegative ? -result : result); + } private: static constexpr int128_t kMin = diff --git a/velox/type/UnscaledShortDecimal.h b/velox/type/UnscaledShortDecimal.h index 581bddee9306..1679550b3d7d 100644 --- a/velox/type/UnscaledShortDecimal.h +++ b/velox/type/UnscaledShortDecimal.h @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include #include #include @@ -23,6 +24,8 @@ #pragma once namespace facebook::velox { +using boost::multiprecision::int256_t; +using uint128_t = __uint128_t; struct UnscaledShortDecimal { public: @@ -122,6 +125,66 @@ struct UnscaledShortDecimal { return *this; } + UnscaledShortDecimal plus(const UnscaledShortDecimal& a, bool* overflow) { + int64_t result; + *overflow = + __builtin_add_overflow(unscaledValue_, a.unscaledValue(), &result); + if (UNLIKELY(*overflow || !UnscaledShortDecimal::valueInRange(result))) { + *overflow = true; + return UnscaledShortDecimal(-1); + } + return UnscaledShortDecimal(result); + } + + UnscaledShortDecimal minus(const UnscaledShortDecimal& a, bool* overflow) { + int64_t result; + *overflow = + __builtin_sub_overflow(unscaledValue_, a.unscaledValue(), &result); + if (UNLIKELY(*overflow || !UnscaledShortDecimal::valueInRange(result))) { + *overflow = true; + return UnscaledShortDecimal(-1); + } + return UnscaledShortDecimal(result); + } + + UnscaledShortDecimal multiply(const UnscaledShortDecimal& a, bool* overflow) { + int64_t result; + *overflow = + __builtin_mul_overflow(unscaledValue_, a.unscaledValue(), &result); + if (UNLIKELY(*overflow || !UnscaledShortDecimal::valueInRange(result))) { + *overflow = true; + return UnscaledShortDecimal(-1); + } + return UnscaledShortDecimal(result); + } + + int32_t countLeadingZeros() const { + auto abs = std::abs(unscaledValue_); + return bits::countLeadingZeros(abs); + } + + static inline UnscaledShortDecimal convert(int256_t in, bool* overflow) { + int64_t result; + constexpr int256_t UINT64_MASK = std::numeric_limits::max(); + + int256_t inAbs = abs(in); + bool isNegative = in < 0; + + uint128_t unsignResult = (inAbs & UINT64_MASK).convert_to(); + inAbs >>= 64; + + if (inAbs > 0) { + // we've shifted in by 128-bit, so nothing should be left. + *overflow = true; + } else if (unsignResult > INT64_MAX) { + // the high-bit must not be set (signed 128-bit). + *overflow = true; + } else { + result = static_cast(unsignResult); + } + return UnscaledShortDecimal(isNegative ? -result : result); + } + private: static constexpr int64_t kMin = -1'000'000'000'000'000'000 + 1; static constexpr int64_t kMax = 1'000'000'000'000'000'000 - 1;