Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow marking arguments as constant in the function signature #4204

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions velox/exec/tests/FunctionSignatureBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,73 @@ TEST_F(FunctionSignatureBuilderTest, anyInReturn) {
},
"Type 'Any' cannot appear in return type");
}

TEST_F(FunctionSignatureBuilderTest, scalarConstantFlags) {
{
auto signature = FunctionSignatureBuilder()
.returnType("bigint")
.argumentType("double")
.constantArgumentType("boolean")
.argumentType("bigint")
.build();
EXPECT_FALSE(signature->constantArguments().at(0));
EXPECT_TRUE(signature->constantArguments().at(1));
EXPECT_FALSE(signature->constantArguments().at(2));
EXPECT_EQ(
"(double,constant boolean,bigint) -> bigint", signature->toString());
}

{
auto signature = FunctionSignatureBuilder()
.typeVariable("T")
.returnType("bigint")
.argumentType("double")
.constantArgumentType("T")
.argumentType("bigint")
.constantArgumentType("boolean")
.variableArity()
.build();
EXPECT_FALSE(signature->constantArguments().at(0));
EXPECT_TRUE(signature->constantArguments().at(1));
EXPECT_FALSE(signature->constantArguments().at(2));
EXPECT_EQ(
"(double,constant T,bigint,constant boolean...) -> bigint",
signature->toString());
}
}

TEST_F(FunctionSignatureBuilderTest, aggregateConstantFlags) {
{
auto aggSignature = AggregateFunctionSignatureBuilder()
.typeVariable("T")
.returnType("T")
.intermediateType("array(T)")
.argumentType("T")
.constantArgumentType("bigint")
.argumentType("T")
.build();
EXPECT_FALSE(aggSignature->constantArguments().at(0));
EXPECT_TRUE(aggSignature->constantArguments().at(1));
EXPECT_FALSE(aggSignature->constantArguments().at(2));
EXPECT_EQ("(T,constant bigint,T) -> T", aggSignature->toString());
}

{
auto aggSignature = AggregateFunctionSignatureBuilder()
.typeVariable("T")
.returnType("T")
.intermediateType("array(T)")
.argumentType("bigint")
.constantArgumentType("T")
.argumentType("T")
.constantArgumentType("double")
.variableArity()
.build();
EXPECT_FALSE(aggSignature->constantArguments().at(0));
EXPECT_TRUE(aggSignature->constantArguments().at(1));
EXPECT_FALSE(aggSignature->constantArguments().at(2));
EXPECT_EQ(
"(bigint,constant T,T,constant double...) -> T",
aggSignature->toString());
}
}
27 changes: 24 additions & 3 deletions velox/expression/FunctionSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,19 @@ std::string TypeSignature::toString() const {
}

std::string FunctionSignature::toString() const {
std::vector<std::string> arguments;
auto size = argumentTypes_.size();
arguments.reserve(size);
for (auto i = 0; i < size; ++i) {
auto arg = argumentTypes_.at(i).toString();
if (constantArguments_.at(i)) {
arguments.emplace_back("constant " + arg);
} else {
arguments.emplace_back(arg);
}
}
std::ostringstream out;
out << "(" << folly::join(",", argumentTypes_);
out << "(" << folly::join(",", arguments);
if (variableArity_) {
out << "...";
}
Expand Down Expand Up @@ -175,7 +186,8 @@ void validateBaseTypeAndCollectTypeParams(
void validate(
const std::unordered_map<std::string, SignatureVariable>& variables,
const TypeSignature& returnType,
const std::vector<TypeSignature>& argumentTypes) {
const std::vector<TypeSignature>& argumentTypes,
const std::vector<bool>& constantArguments) {
std::unordered_set<std::string> usedVariables;
// Validate the argument types.
for (const auto& arg : argumentTypes) {
Expand All @@ -199,6 +211,11 @@ void validate(
usedVariables.size(),
variables.size(),
"Some integer variables are not used");

VELOX_USER_CHECK_EQ(
argumentTypes.size(),
constantArguments.size(),
"Argument types size is not equal to constant flags");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice check.

}

} // namespace
Expand All @@ -225,12 +242,14 @@ FunctionSignature::FunctionSignature(
std::unordered_map<std::string, SignatureVariable> variables,
TypeSignature returnType,
std::vector<TypeSignature> argumentTypes,
std::vector<bool> constantArguments,
bool variableArity)
: variables_{std::move(variables)},
returnType_{std::move(returnType)},
argumentTypes_{std::move(argumentTypes)},
constantArguments_{std::move(constantArguments)},
variableArity_{variableArity} {
validate(variables_, returnType_, argumentTypes_);
validate(variables_, returnType_, argumentTypes_, constantArguments_);
}

FunctionSignaturePtr FunctionSignatureBuilder::build() {
Expand All @@ -239,6 +258,7 @@ FunctionSignaturePtr FunctionSignatureBuilder::build() {
std::move(variables_),
returnType_.value(),
std::move(argumentTypes_),
std::move(constantArguments_),
variableArity_);
}

Expand All @@ -251,6 +271,7 @@ AggregateFunctionSignatureBuilder::build() {
returnType_.value(),
intermediateType_.value(),
std::move(argumentTypes_),
std::move(constantArguments_),
variableArity_);
}
} // namespace facebook::velox::exec
25 changes: 25 additions & 0 deletions velox/expression/FunctionSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class FunctionSignature {
std::unordered_map<std::string, SignatureVariable> variables,
TypeSignature returnType,
std::vector<TypeSignature> argumentTypes,
std::vector<bool> constantArguments,
bool variableArity);

const TypeSignature& returnType() const {
Expand All @@ -138,6 +139,10 @@ class FunctionSignature {
return argumentTypes_;
}

const std::vector<bool>& constantArguments() const {
return constantArguments_;
}

bool variableArity() const {
return variableArity_;
}
Expand All @@ -161,6 +166,7 @@ class FunctionSignature {
const std::unordered_map<std::string, SignatureVariable> variables_;
const TypeSignature returnType_;
const std::vector<TypeSignature> argumentTypes_;
const std::vector<bool> constantArguments_;
const bool variableArity_;
};

Expand All @@ -173,11 +179,13 @@ class AggregateFunctionSignature : public FunctionSignature {
TypeSignature returnType,
TypeSignature intermediateType,
std::vector<TypeSignature> argumentTypes,
std::vector<bool> constantArguments,
bool variableArity)
: FunctionSignature(
std::move(variables),
std::move(returnType),
std::move(argumentTypes),
std::move(constantArguments),
variableArity),
intermediateType_{std::move(intermediateType)} {}

Expand Down Expand Up @@ -264,6 +272,13 @@ class FunctionSignatureBuilder {

FunctionSignatureBuilder& argumentType(const std::string& type) {
argumentTypes_.emplace_back(parseTypeSignature(type));
constantArguments_.push_back(false);
return *this;
}

FunctionSignatureBuilder& constantArgumentType(const std::string& type) {
argumentTypes_.emplace_back(parseTypeSignature(type));
constantArguments_.push_back(true);
return *this;
}

Expand All @@ -278,6 +293,7 @@ class FunctionSignatureBuilder {
std::unordered_map<std::string, SignatureVariable> variables_;
std::optional<TypeSignature> returnType_;
std::vector<TypeSignature> argumentTypes_;
std::vector<bool> constantArguments_;
bool variableArity_{false};
};

Expand Down Expand Up @@ -325,6 +341,14 @@ class AggregateFunctionSignatureBuilder {

AggregateFunctionSignatureBuilder& argumentType(const std::string& type) {
argumentTypes_.emplace_back(parseTypeSignature(type));
constantArguments_.push_back(false);
return *this;
}

AggregateFunctionSignatureBuilder& constantArgumentType(
const std::string& type) {
argumentTypes_.emplace_back(parseTypeSignature(type));
constantArguments_.push_back(true);
return *this;
}

Expand All @@ -345,6 +369,7 @@ class AggregateFunctionSignatureBuilder {
std::optional<TypeSignature> returnType_;
std::optional<TypeSignature> intermediateType_;
std::vector<TypeSignature> argumentTypes_;
std::vector<bool> constantArguments_;
bool variableArity_{false};
};

Expand Down