Skip to content

Commit

Permalink
orderCompara
Browse files Browse the repository at this point in the history
  • Loading branch information
duanmeng committed Oct 5, 2023
1 parent fb944bb commit 905275c
Show file tree
Hide file tree
Showing 5 changed files with 275 additions and 10 deletions.
89 changes: 89 additions & 0 deletions velox/exec/tests/FunctionSignatureBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,92 @@ TEST_F(FunctionSignatureBuilderTest, toString) {

ASSERT_EQ("foo(BIGINT, VARCHAR)", toString("foo", {BIGINT(), VARCHAR()}));
}

TEST_F(FunctionSignatureBuilderTest, orderableComparable) {
{
auto signature = FunctionSignatureBuilder()
.typeVariable("T")
.returnType("array(T)")
.argumentType("array(T)")
.build();
ASSERT_FALSE(signature->variables().at("T").orderableTypesOnly());
ASSERT_FALSE(signature->variables().at("T").comparableTypesOnly());
}

{
auto signature = FunctionSignatureBuilder()
.orderableTypeVariable("T")
.returnType("array(T)")
.argumentType("array(T)")
.build();
ASSERT_TRUE(signature->variables().at("T").orderableTypesOnly());
ASSERT_FALSE(signature->variables().at("T").comparableTypesOnly());
}

{
auto signature = FunctionSignatureBuilder()
.comparableTypeVariable("T")
.returnType("array(T)")
.argumentType("array(T)")
.build();
ASSERT_FALSE(signature->variables().at("T").orderableTypesOnly());
ASSERT_TRUE(signature->variables().at("T").comparableTypesOnly());
}

{
auto signature = FunctionSignatureBuilder()
.orderableTypeVariable("T")
.comparableTypeVariable("T")
.returnType("array(T)")
.argumentType("array(T)")
.build();
ASSERT_TRUE(signature->variables().at("T").orderableTypesOnly());
ASSERT_TRUE(signature->variables().at("T").comparableTypesOnly());
}
}

TEST_F(FunctionSignatureBuilderTest, orderableComparableAggregate) {
{
auto signature = exec::AggregateFunctionSignatureBuilder()
.typeVariable("T")
.returnType("T")
.intermediateType("T")
.argumentType("T")
.build();
ASSERT_FALSE(signature->variables().at("T").orderableTypesOnly());
ASSERT_FALSE(signature->variables().at("T").comparableTypesOnly());
}

{
auto signature = exec::AggregateFunctionSignatureBuilder()
.orderableTypeVariable("T")
.returnType("T")
.intermediateType("T")
.argumentType("T")
.build();
ASSERT_TRUE(signature->variables().at("T").orderableTypesOnly());
ASSERT_FALSE(signature->variables().at("T").comparableTypesOnly());
}

{
auto signature = exec::AggregateFunctionSignatureBuilder()
.comparableTypeVariable("T")
.returnType("T")
.intermediateType("T")
.argumentType("T")
.build();
ASSERT_FALSE(signature->variables().at("T").orderableTypesOnly());
ASSERT_TRUE(signature->variables().at("T").comparableTypesOnly());
}

{
auto signature = FunctionSignatureBuilder()
.orderableTypeVariable("T")
.comparableTypeVariable("T")
.returnType("array(T)")
.argumentType("array(T)")
.build();
ASSERT_TRUE(signature->variables().at("T").orderableTypesOnly());
ASSERT_TRUE(signature->variables().at("T").comparableTypesOnly());
}
}
8 changes: 6 additions & 2 deletions velox/expression/FunctionSignature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,15 @@ SignatureVariable::SignatureVariable(
std::string name,
std::optional<std::string> constraint,
ParameterType type,
bool knownTypesOnly)
bool knownTypesOnly,
bool orderableTypesOnly,
bool comaprableTypesOnly)
: name_{std::move(name)},
constraint_(constraint.has_value() ? std::move(constraint.value()) : ""),
type_{type},
knownTypesOnly_(knownTypesOnly) {
knownTypesOnly_(knownTypesOnly),
orderableTypesOnly_(orderableTypesOnly),
comparableTypesOnly_(comaprableTypesOnly) {
VELOX_CHECK(
!knownTypesOnly_ || isTypeParameter(),
"Non-Type variables cannot have the knownTypesOnly constraint");
Expand Down
133 changes: 125 additions & 8 deletions velox/expression/FunctionSignature.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ class SignatureVariable {
std::string name,
std::optional<std::string> constraint,
ParameterType type,
bool knownTypesOnly = false);
bool knownTypesOnly = false,
bool orderableTypesOnly = false,
bool comparableTypesOnly = false);

const std::string& name() const {
return name_;
Expand All @@ -58,6 +60,27 @@ class SignatureVariable {
return knownTypesOnly_;
}

void setKnownTypesOnly(bool knownTypesOnly) {
VELOX_USER_CHECK(isTypeParameter());
knownTypesOnly_ = knownTypesOnly;
}

bool orderableTypesOnly() const {
return orderableTypesOnly_;
}

void setOrderableTypesOnly(bool orderableTypesOnly) {
orderableTypesOnly_ = orderableTypesOnly;
}

bool comparableTypesOnly() const {
return comparableTypesOnly_;
}

void setComparableTypesOnly(bool comparableTypesOnly) {
comparableTypesOnly_ = comparableTypesOnly;
}

bool isTypeParameter() const {
return type_ == ParameterType::kTypeParameter;
}
Expand All @@ -69,7 +92,9 @@ class SignatureVariable {
bool operator==(const SignatureVariable& rhs) const {
return type_ == rhs.type_ && name_ == rhs.name_ &&
constraint_ == rhs.constraint_ &&
knownTypesOnly_ == rhs.knownTypesOnly_;
knownTypesOnly_ == rhs.knownTypesOnly_ &&
orderableTypesOnly_ == rhs.orderableTypesOnly_ &&
comparableTypesOnly_ == rhs.comparableTypesOnly_;
}

private:
Expand All @@ -79,6 +104,8 @@ class SignatureVariable {
// This property only applies to type variables and indicates if the type
// can bind to unknown or not.
bool knownTypesOnly_ = false;
bool orderableTypesOnly_ = false;
bool comparableTypesOnly_ = false;
};

// Base type (e.g. map) and optional parameters (e.g. K, V).
Expand Down Expand Up @@ -259,9 +286,53 @@ class FunctionSignatureBuilder {
}

FunctionSignatureBuilder& knownTypeVariable(const std::string& name) {
addVariable(
variables_,
SignatureVariable(name, "", ParameterType::kTypeParameter, true));
if (variables_.find(name) == variables_.end()) {
addVariable(
variables_,
SignatureVariable(
name,
"",
ParameterType::kTypeParameter,
/*knownTypesOnly*/ true,
/*orderableTypesOnly*/ false,
/*comaprableTypesOnly*/ false));
} else {
variables_.at(name).setKnownTypesOnly(true);
}
return *this;
}

FunctionSignatureBuilder& orderableTypeVariable(const std::string& name) {
if (variables_.find(name) == variables_.end()) {
addVariable(
variables_,
SignatureVariable(
name,
"",
ParameterType::kTypeParameter,
/*knownTypesOnly*/ false,
/*orderableTypesOnly*/ true,
/*comaprableTypesOnly*/ false));
} else {
variables_.at(name).setOrderableTypesOnly(true);
}
return *this;
}

FunctionSignatureBuilder& comparableTypeVariable(const std::string& name) {
if (variables_.find(name) == variables_.end()) {
addVariable(
variables_,
SignatureVariable(
name,
"",
ParameterType::kTypeParameter,
/*knownTypesOnly*/ false,
/*orderableTypesOnly*/ false,
/*comaprableTypesOnly*/ true));
} else {
variables_.at(name).setComparableTypesOnly(true);
}
return *this;
}

Expand Down Expand Up @@ -328,9 +399,55 @@ class AggregateFunctionSignatureBuilder {

AggregateFunctionSignatureBuilder& knownTypeVariable(
const std::string& name) {
addVariable(
variables_,
SignatureVariable(name, "", ParameterType::kTypeParameter, true));
if (variables_.find(name) == variables_.end()) {
addVariable(
variables_,
SignatureVariable(
name,
"",
ParameterType::kTypeParameter,
/*knownTypesOnly*/ true,
/*orderableTypesOnly*/ false,
/*comaprableTypesOnly*/ false));
} else {
variables_.at(name).setKnownTypesOnly(true);
}
return *this;
}

AggregateFunctionSignatureBuilder& orderableTypeVariable(
const std::string& name) {
if (variables_.find(name) == variables_.end()) {
addVariable(
variables_,
SignatureVariable(
name,
"",
ParameterType::kTypeParameter,
/*knownTypesOnly*/ false,
/*orderableTypesOnly*/ true,
/*comaprableTypesOnly*/ false));
} else {
variables_.at(name).setOrderableTypesOnly(true);
}
return *this;
}

AggregateFunctionSignatureBuilder& comparableTypeVariable(
const std::string& name) {
if (variables_.find(name) == variables_.end()) {
addVariable(
variables_,
SignatureVariable(
name,
"",
ParameterType::kTypeParameter,
/*knownTypesOnly*/ false,
/*orderableTypesOnly*/ false,
/*comaprableTypesOnly*/ true));
} else {
variables_.at(name).setComparableTypesOnly(true);
}
return *this;
}

Expand Down
8 changes: 8 additions & 0 deletions velox/expression/SignatureBinder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,14 @@ bool SignatureBinderBase::tryBind(
return false;
}

if (variable.orderableTypesOnly() && !actualType->isOrderable()) {
return false;
}

if (variable.comparableTypesOnly() && !actualType->isComparable()) {
return false;
}

typeVariablesBindings_[baseName] = actualType;
return true;
}
Expand Down
47 changes: 47 additions & 0 deletions velox/expression/tests/SignatureBinderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,53 @@ TEST(SignatureBinderTest, knownOnly) {
}
}

TEST(SignatureBinderTest, orderableComparable) {
auto signature = exec::FunctionSignatureBuilder()
.orderableTypeVariable("T")
.comparableTypeVariable("T")
.returnType("array(T)")
.argumentType("array(T)")
.build();
{
auto actualTypes = std::vector<TypePtr>{ARRAY(BIGINT())};
exec::SignatureBinder binder(*signature, actualTypes);
ASSERT_TRUE(binder.tryBind());
}

{
auto actualTypes = std::vector<TypePtr>{ARRAY(MAP(BIGINT(), BIGINT()))};
exec::SignatureBinder binder(*signature, actualTypes);
ASSERT_FALSE(binder.tryBind());
}

signature = exec::FunctionSignatureBuilder()
.typeVariable("V")
.orderableTypeVariable("T")
.comparableTypeVariable("T")
.returnType("ROW(V)")
.argumentType("ROW(V, T)")
.build();
{
auto actualTypes = std::vector<TypePtr>{ROW({BIGINT(), ARRAY(DOUBLE())})};
exec::SignatureBinder binder(*signature, actualTypes);
ASSERT_TRUE(binder.tryBind());
}

{
auto actualTypes =
std::vector<TypePtr>{ROW({MAP(VARCHAR(), BIGINT()), ARRAY(DOUBLE())})};
exec::SignatureBinder binder(*signature, actualTypes);
ASSERT_TRUE(binder.tryBind());
}

{
auto actualTypes =
std::vector<TypePtr>{ROW({BIGINT(), MAP(VARCHAR(), BIGINT())})};
exec::SignatureBinder binder(*signature, actualTypes);
ASSERT_FALSE(binder.tryBind());
}
}

TEST(SignatureBinderTest, generics) {
// array(T), T -> boolean
{
Expand Down

0 comments on commit 905275c

Please sign in to comment.