Skip to content

Commit

Permalink
WIP: Add support for DECIMAL types to Expression Fuzzer
Browse files Browse the repository at this point in the history
  • Loading branch information
mbasmanova authored and rui-mo committed Apr 11, 2024
1 parent cef85a1 commit ae741ca
Show file tree
Hide file tree
Showing 8 changed files with 297 additions and 13 deletions.
2 changes: 1 addition & 1 deletion velox/expression/SignatureBinder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ TypePtr SignatureBinder::tryResolveType(
const std::unordered_map<std::string, SignatureVariable>& variables,
const std::unordered_map<std::string, TypePtr>& typeVariablesBindings,
std::unordered_map<std::string, int>& integerVariablesBindings) {
const auto baseName = typeSignature.baseName();
const auto& baseName = typeSignature.baseName();

if (variables.count(baseName)) {
auto it = typeVariablesBindings.find(baseName);
Expand Down
255 changes: 244 additions & 11 deletions velox/expression/tests/ExpressionFuzzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,195 @@
namespace facebook::velox::test {

namespace {

int32_t rand(FuzzerGenerator& rng) {
return boost::random::uniform_int_distribution<int32_t>()(rng);
}

class DecimalArgGeneratorBase : public ArgGenerator {
public:
TypePtr generateReturnType(
const exec::FunctionSignature& /*signature*/,
FuzzerGenerator& /*rng*/) override {
auto [p, s] = inputs_.begin()->first;
return DECIMAL(p, s);
}

std::vector<TypePtr> generateArgs(
const exec::FunctionSignature& /*signature*/,
const TypePtr& returnType,
FuzzerGenerator& rng) override {
auto inputs = findInputs(returnType, rng);
if (inputs.a == nullptr || inputs.b == nullptr) {
return {};
}

return {std::move(inputs.a), std::move(inputs.b)};
}

protected:
// Compute result type for all possible pairs of decimal input types. Store
// the results in 'inputs_' maps keyed by return type.
void initialize() {
std::vector<TypePtr> allTypes;
for (auto p = 1; p < 38; ++p) {
for (auto s = 0; s <= p; ++s) {
allTypes.push_back(DECIMAL(p, s));
}
}

for (auto& a : allTypes) {
for (auto& b : allTypes) {
auto [p1, s1] = getDecimalPrecisionScale(*a);
auto [p2, s2] = getDecimalPrecisionScale(*b);

if (auto returnType = toReturnType(p1, s1, p2, s2)) {
inputs_[returnType.value()].push_back({a, b});
}
}
}
}

struct Inputs {
TypePtr a;
TypePtr b;
};

// Return randomly selected pair of input types that produce the specified
// result type.
Inputs findInputs(const TypePtr& returnType, FuzzerGenerator& rng) const {
auto [p, s] = getDecimalPrecisionScale(*returnType);
auto it = inputs_.find({p, s});
if (it == inputs_.end()) {
LOG(ERROR) << "Cannot find input types for " << returnType->toString();
return {nullptr, nullptr};
}

auto index = rand(rng) % it->second.size();
return it->second[index];
}

// Given precisions and scales of the inputs, return precision and scale of
// the result.
virtual std::optional<std::pair<int, int>>
toReturnType(int p1, int s1, int p2, int s2) = 0;

std::map<std::pair<int, int>, std::vector<Inputs>> inputs_;
};

class PlusMinusArgGenerator : public DecimalArgGeneratorBase {
public:
PlusMinusArgGenerator() {
initialize();
}

protected:
std::optional<std::pair<int, int>>
toReturnType(int p1, int s1, int p2, int s2) override {
auto s = std::max(s1, s2);
auto p = std::min(38, std::max(p1 - s1, p2 - s2) + 1 + s);
return {{p, s}};
}
};

class MultiplyArgGenerator : public DecimalArgGeneratorBase {
public:
MultiplyArgGenerator() {
initialize();
}

protected:
std::optional<std::pair<int, int>>
toReturnType(int p1, int s1, int p2, int s2) override {
if (s1 + s2 > 38) {
return std::nullopt;
}

auto p = std::min(38, p1 + p2);
auto s = s1 + s2;
return {{p, s}};
}
};

class DivideArgGenerator : public DecimalArgGeneratorBase {
public:
DivideArgGenerator() {
initialize();
}

protected:
std::optional<std::pair<int, int>>
toReturnType(int p1, int s1, int p2, int s2) override {
if (s1 + s2 > 38) {
return std::nullopt;
}

auto p = std::min(38, p1 + s2 + std::max(0, s2 - s1));
auto s = std::max(s1, s2);
return {{p, s}};
}
};

class FloorAndRoundArgGenerator : public ArgGenerator {
public:
TypePtr generateReturnType(
const exec::FunctionSignature& signature,
FuzzerGenerator& rng) override {
if (signature.argumentTypes().size() == 1) {
auto p = 1 + rand(rng) % 38;
return DECIMAL(p, 0);
} else {
auto p = 2 + rand(rng) % 37;
auto s = rand(rng) % p;
return DECIMAL(p, s);
}
}

std::vector<TypePtr> generateArgs(
const exec::FunctionSignature& signature,
const TypePtr& returnType,
FuzzerGenerator& rng) override {
if (signature.argumentTypes().size() == 1) {
return generateSingleArg(returnType, rng);
} else {
VELOX_CHECK_EQ(2, signature.argumentTypes().size())
return generateTwoArgs(returnType);
}
}

private:
std::vector<TypePtr> generateSingleArg(
const TypePtr& returnType,
FuzzerGenerator& rng) {
auto [p, s] = getDecimalPrecisionScale(*returnType);

// p = p1 - s1 + min(s1, 1)
// s = 0
if (s != 0) {
return {};
}

auto s1 = rand(rng) % (38 - p + 1);
if (s1 == 0) {
return {DECIMAL(p, 0)};
}

return {DECIMAL(p - 1 + s1, s1)};
}

std::vector<TypePtr> generateTwoArgs(const TypePtr& returnType) {
auto [p, s] = getDecimalPrecisionScale(*returnType);

// p = p1 + 1
// s = s1
if (p == 1 || p == s) {
return {};
}

return {DECIMAL(p - 1, s), INTEGER()};
}
};

using exec::SignatureBinder;
using exec::SignatureBinderBase;

Expand Down Expand Up @@ -249,8 +438,8 @@ static const std::unordered_map<
},
{
"cast",
/// TODO: Add supported Cast signatures to CastTypedExpr and expose
/// them to fuzzer instead of hard-coding signatures here.
// TODO: Add supported Cast signatures to CastTypedExpr and expose
// them to fuzzer instead of hard-coding signatures here.
getSignaturesForCast(),
},
};
Expand Down Expand Up @@ -440,9 +629,9 @@ bool isSupportedSignature(
// timestamp with time zone types.
return !(
useTypeName(signature, "opaque") ||
useTypeName(signature, "long_decimal") ||
useTypeName(signature, "short_decimal") ||
useTypeName(signature, "decimal") ||
// useTypeName(signature, "long_decimal") ||
// useTypeName(signature, "short_decimal") ||
// useTypeName(signature, "decimal") ||
useTypeName(signature, "timestamp with time zone") ||
useTypeName(signature, "interval day to second") ||
(enableComplexType && useTypeName(signature, "unknown")));
Expand Down Expand Up @@ -533,6 +722,15 @@ ExpressionFuzzer::ExpressionFuzzer(
VELOX_CHECK(vectorFuzzer, "Vector fuzzer must be provided");
seed(initialSeed);

argGenerators_.emplace("plus", std::make_shared<PlusMinusArgGenerator>());
argGenerators_.emplace("minus", std::make_shared<PlusMinusArgGenerator>());
argGenerators_.emplace("multiply", std::make_shared<MultiplyArgGenerator>());
argGenerators_.emplace("divide", std::make_shared<DivideArgGenerator>());
argGenerators_.emplace(
"floor", std::make_shared<FloorAndRoundArgGenerator>());
argGenerators_.emplace(
"round", std::make_shared<FloorAndRoundArgGenerator>());

appendSpecialForms(signatureMap, options_.specialForms);
filterSignatures(
signatureMap, options_.useOnlyFunctions, options_.skipFunctions);
Expand Down Expand Up @@ -590,11 +788,28 @@ ExpressionFuzzer::ExpressionFuzzer(
continue;
}
} else {
// TODO Remove this code. argTypes are used only to call
// isDeterministic() which can be be figured out without the argTypes.
ArgumentTypeFuzzer typeFuzzer{*signature, localRng};
typeFuzzer.fuzzReturnType();
VELOX_CHECK_EQ(
typeFuzzer.fuzzArgumentTypes(options_.maxNumVarArgs), true);
argTypes = typeFuzzer.argumentTypes();
auto returnType = typeFuzzer.fuzzReturnType();
bool ok = typeFuzzer.fuzzArgumentTypes(options_.maxNumVarArgs);
if (!ok) {
auto it = argGenerators_.find(function.first);
if (it != argGenerators_.end()) {
returnType = it->second->generateReturnType(*signature, localRng);
argTypes =
it->second->generateArgs(*signature, returnType, localRng);
VELOX_CHECK(
!argTypes.empty(),
"Failed to generate arguments for {} with return type {}",
function.first,
returnType->toString());
} else {
VELOX_FAIL("Failed to generate arguments");
}
} else {
argTypes = typeFuzzer.argumentTypes();
}
}
if (!isDeterministic(function.first, argTypes)) {
LOG(WARNING) << "Skipping non-deterministic function: "
Expand Down Expand Up @@ -1168,6 +1383,11 @@ const SignatureTemplate* ExpressionFuzzer::chooseRandomSignatureTemplate(
}
}
if (eligible.empty()) {
if (argGenerators_.find(functionName) != argGenerators_.end()) {
auto idx = rand32(0, signatureTemplates.size() - 1);
return signatureTemplates[idx];
}

return nullptr;
}

Expand Down Expand Up @@ -1223,8 +1443,21 @@ core::TypedExprPtr ExpressionFuzzer::generateExpressionFromSignatureTemplate(

auto chosenSignature = *chosen->signature;
ArgumentTypeFuzzer fuzzer{chosenSignature, returnType, rng_};
VELOX_CHECK_EQ(fuzzer.fuzzArgumentTypes(options_.maxNumVarArgs), true);
auto& argumentTypes = fuzzer.argumentTypes();
bool ok = fuzzer.fuzzArgumentTypes(options_.maxNumVarArgs);

std::vector<TypePtr> argumentTypes;
if (!ok) {
auto it = argGenerators_.find(functionName);
VELOX_CHECK(it != argGenerators_.end());

argumentTypes = it->second->generateArgs(chosenSignature, returnType, rng_);
if (argumentTypes.empty()) {
return nullptr;
}
} else {
argumentTypes = fuzzer.argumentTypes();
}

auto constantArguments = chosenSignature.constantArguments();

// ArgumentFuzzer may generate duplicate arguments if the signature's
Expand Down
21 changes: 21 additions & 0 deletions velox/expression/tests/ExpressionFuzzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,25 @@

namespace facebook::velox::test {

class ArgGenerator {
public:
virtual ~ArgGenerator() = default;

/// Given a signature and a concrete return type return randomly selected
/// valid input types. Returns empty vector if no input types can
/// produce the specified result type.
virtual std::vector<TypePtr> generateArgs(
const exec::FunctionSignature& signature,
const TypePtr& returnType,
FuzzerGenerator& rng) = 0;

/// Returns a randomly generated valid return type for a given signature.
/// TODO Remove. This API should not be needed.
virtual TypePtr generateReturnType(
const exec::FunctionSignature& signature,
FuzzerGenerator& rng) = 0;
};

// A tool that can be used to generate random expressions.
class ExpressionFuzzer {
public:
Expand Down Expand Up @@ -416,6 +435,8 @@ class ExpressionFuzzer {

} state;
friend class ExpressionFuzzerUnitTest;

std::unordered_map<std::string, std::shared_ptr<ArgGenerator>> argGenerators_;
};

} // namespace facebook::velox::test
3 changes: 2 additions & 1 deletion velox/expression/tests/ExpressionVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ ResultOrError ExpressionVerifier::verify(
bool canThrow,
std::vector<int> columnsToWrapInLazy) {
for (int i = 0; i < plans.size(); ++i) {
LOG(INFO) << "Executing expression " << i << " : " << plans[i]->toString();
LOG(INFO) << "Executing expression " << i << " : " << plans[i]->toString()
<< " -> " << plans[i]->type()->toString();
}
logRowVector(rowVector);

Expand Down
2 changes: 2 additions & 0 deletions velox/expression/tests/utils/ArgumentTypeFuzzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ bool ArgumentTypeFuzzer::fuzzArgumentTypes(uint32_t maxVariadicArgs) {
const auto& formalArgs = signature_.argumentTypes();
auto formalArgsCnt = formalArgs.size();

std::unordered_map<std::string, int> integerBindings;

if (returnType_) {
exec::ReverseSignatureBinder binder{signature_, returnType_};
if (!binder.tryBind()) {
Expand Down
1 change: 1 addition & 0 deletions velox/functions/prestosql/DecimalFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ struct DecimalDivideFunction {
auto [rPrecision, rScale] =
computeResultPrecisionScale(aPrecision, aScale, bPrecision, bScale);
aRescale_ = computeRescaleFactor(aScale, bScale, rScale);
VELOX_USER_CHECK_LE(aRescale_, LongDecimalType::kMaxPrecision);
}

template <typename R, typename A, typename B>
Expand Down
Loading

0 comments on commit ae741ca

Please sign in to comment.