Skip to content

Commit

Permalink
Allow marking arguments as constant in the function signature (#4204)
Browse files Browse the repository at this point in the history
Summary:
Details in #4140

Pull Request resolved: #4204

Reviewed By: Yuhta

Differential Revision: D43876547

Pulled By: mbasmanova

fbshipit-source-id: 58e7a4ada5e9edcd5bfad6d2a7e76471fae7a014
  • Loading branch information
duanmeng authored and facebook-github-bot committed Mar 7, 2023
1 parent bb7a20a commit 09bd177
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 3 deletions.
70 changes: 70 additions & 0 deletions velox/exec/tests/FunctionSignatureBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,73 @@ TEST_F(FunctionSignatureBuilderTest, anyInReturn) {
},
"Type 'Any' cannot appear in return type");
}

TEST_F(FunctionSignatureBuilderTest, scalarConstantFlags) {
{
auto signature = FunctionSignatureBuilder()
.returnType("bigint")
.argumentType("double")
.constantArgumentType("boolean")
.argumentType("bigint")
.build();
EXPECT_FALSE(signature->constantArguments().at(0));
EXPECT_TRUE(signature->constantArguments().at(1));
EXPECT_FALSE(signature->constantArguments().at(2));
EXPECT_EQ(
"(double,constant boolean,bigint) -> bigint", signature->toString());
}

{
auto signature = FunctionSignatureBuilder()
.typeVariable("T")
.returnType("bigint")
.argumentType("double")
.constantArgumentType("T")
.argumentType("bigint")
.constantArgumentType("boolean")
.variableArity()
.build();
EXPECT_FALSE(signature->constantArguments().at(0));
EXPECT_TRUE(signature->constantArguments().at(1));
EXPECT_FALSE(signature->constantArguments().at(2));
EXPECT_EQ(
"(double,constant T,bigint,constant boolean...) -> bigint",
signature->toString());
}
}

TEST_F(FunctionSignatureBuilderTest, aggregateConstantFlags) {
{
auto aggSignature = AggregateFunctionSignatureBuilder()
.typeVariable("T")
.returnType("T")
.intermediateType("array(T)")
.argumentType("T")
.constantArgumentType("bigint")
.argumentType("T")
.build();
EXPECT_FALSE(aggSignature->constantArguments().at(0));
EXPECT_TRUE(aggSignature->constantArguments().at(1));
EXPECT_FALSE(aggSignature->constantArguments().at(2));
EXPECT_EQ("(T,constant bigint,T) -> T", aggSignature->toString());
}

{
auto aggSignature = AggregateFunctionSignatureBuilder()
.typeVariable("T")
.returnType("T")
.intermediateType("array(T)")
.argumentType("bigint")
.constantArgumentType("T")
.argumentType("T")
.constantArgumentType("double")
.variableArity()
.build();
EXPECT_FALSE(aggSignature->constantArguments().at(0));
EXPECT_TRUE(aggSignature->constantArguments().at(1));
EXPECT_FALSE(aggSignature->constantArguments().at(2));
EXPECT_EQ(
"(bigint,constant T,T,constant double...) -> T",
aggSignature->toString());
}
}
27 changes: 24 additions & 3 deletions velox/expression/FunctionSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,19 @@ std::string TypeSignature::toString() const {
}

std::string FunctionSignature::toString() const {
std::vector<std::string> arguments;
auto size = argumentTypes_.size();
arguments.reserve(size);
for (auto i = 0; i < size; ++i) {
auto arg = argumentTypes_.at(i).toString();
if (constantArguments_.at(i)) {
arguments.emplace_back("constant " + arg);
} else {
arguments.emplace_back(arg);
}
}
std::ostringstream out;
out << "(" << folly::join(",", argumentTypes_);
out << "(" << folly::join(",", arguments);
if (variableArity_) {
out << "...";
}
Expand Down Expand Up @@ -175,7 +186,8 @@ void validateBaseTypeAndCollectTypeParams(
void validate(
const std::unordered_map<std::string, SignatureVariable>& variables,
const TypeSignature& returnType,
const std::vector<TypeSignature>& argumentTypes) {
const std::vector<TypeSignature>& argumentTypes,
const std::vector<bool>& constantArguments) {
std::unordered_set<std::string> usedVariables;
// Validate the argument types.
for (const auto& arg : argumentTypes) {
Expand All @@ -199,6 +211,11 @@ void validate(
usedVariables.size(),
variables.size(),
"Some integer variables are not used");

VELOX_USER_CHECK_EQ(
argumentTypes.size(),
constantArguments.size(),
"Argument types size is not equal to constant flags");
}

} // namespace
Expand All @@ -225,12 +242,14 @@ FunctionSignature::FunctionSignature(
std::unordered_map<std::string, SignatureVariable> variables,
TypeSignature returnType,
std::vector<TypeSignature> argumentTypes,
std::vector<bool> constantArguments,
bool variableArity)
: variables_{std::move(variables)},
returnType_{std::move(returnType)},
argumentTypes_{std::move(argumentTypes)},
constantArguments_{std::move(constantArguments)},
variableArity_{variableArity} {
validate(variables_, returnType_, argumentTypes_);
validate(variables_, returnType_, argumentTypes_, constantArguments_);
}

FunctionSignaturePtr FunctionSignatureBuilder::build() {
Expand All @@ -239,6 +258,7 @@ FunctionSignaturePtr FunctionSignatureBuilder::build() {
std::move(variables_),
returnType_.value(),
std::move(argumentTypes_),
std::move(constantArguments_),
variableArity_);
}

Expand All @@ -251,6 +271,7 @@ AggregateFunctionSignatureBuilder::build() {
returnType_.value(),
intermediateType_.value(),
std::move(argumentTypes_),
std::move(constantArguments_),
variableArity_);
}
} // namespace facebook::velox::exec
25 changes: 25 additions & 0 deletions velox/expression/FunctionSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class FunctionSignature {
std::unordered_map<std::string, SignatureVariable> variables,
TypeSignature returnType,
std::vector<TypeSignature> argumentTypes,
std::vector<bool> constantArguments,
bool variableArity);

const TypeSignature& returnType() const {
Expand All @@ -138,6 +139,10 @@ class FunctionSignature {
return argumentTypes_;
}

const std::vector<bool>& constantArguments() const {
return constantArguments_;
}

bool variableArity() const {
return variableArity_;
}
Expand All @@ -161,6 +166,7 @@ class FunctionSignature {
const std::unordered_map<std::string, SignatureVariable> variables_;
const TypeSignature returnType_;
const std::vector<TypeSignature> argumentTypes_;
const std::vector<bool> constantArguments_;
const bool variableArity_;
};

Expand All @@ -173,11 +179,13 @@ class AggregateFunctionSignature : public FunctionSignature {
TypeSignature returnType,
TypeSignature intermediateType,
std::vector<TypeSignature> argumentTypes,
std::vector<bool> constantArguments,
bool variableArity)
: FunctionSignature(
std::move(variables),
std::move(returnType),
std::move(argumentTypes),
std::move(constantArguments),
variableArity),
intermediateType_{std::move(intermediateType)} {}

Expand Down Expand Up @@ -264,6 +272,13 @@ class FunctionSignatureBuilder {

FunctionSignatureBuilder& argumentType(const std::string& type) {
argumentTypes_.emplace_back(parseTypeSignature(type));
constantArguments_.push_back(false);
return *this;
}

FunctionSignatureBuilder& constantArgumentType(const std::string& type) {
argumentTypes_.emplace_back(parseTypeSignature(type));
constantArguments_.push_back(true);
return *this;
}

Expand All @@ -278,6 +293,7 @@ class FunctionSignatureBuilder {
std::unordered_map<std::string, SignatureVariable> variables_;
std::optional<TypeSignature> returnType_;
std::vector<TypeSignature> argumentTypes_;
std::vector<bool> constantArguments_;
bool variableArity_{false};
};

Expand Down Expand Up @@ -325,6 +341,14 @@ class AggregateFunctionSignatureBuilder {

AggregateFunctionSignatureBuilder& argumentType(const std::string& type) {
argumentTypes_.emplace_back(parseTypeSignature(type));
constantArguments_.push_back(false);
return *this;
}

AggregateFunctionSignatureBuilder& constantArgumentType(
const std::string& type) {
argumentTypes_.emplace_back(parseTypeSignature(type));
constantArguments_.push_back(true);
return *this;
}

Expand All @@ -345,6 +369,7 @@ class AggregateFunctionSignatureBuilder {
std::optional<TypeSignature> returnType_;
std::optional<TypeSignature> intermediateType_;
std::vector<TypeSignature> argumentTypes_;
std::vector<bool> constantArguments_;
bool variableArity_{false};
};

Expand Down

0 comments on commit 09bd177

Please sign in to comment.