Skip to content

Commit

Permalink
Implement Compare functions for Spark
Browse files Browse the repository at this point in the history
  • Loading branch information
yma11 committed Jul 10, 2023
1 parent 480c9c1 commit 1c31582
Show file tree
Hide file tree
Showing 6 changed files with 561 additions and 8 deletions.
1 change: 1 addition & 0 deletions velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ add_library(
ArraySort.cpp
Bitwise.cpp
CompareFunctionsNullSafe.cpp
Comparisons.cpp
Hash.cpp
In.cpp
LeastGreatest.cpp
Expand Down
155 changes: 155 additions & 0 deletions velox/functions/sparksql/Comparisons.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/functions/sparksql/LeastGreatest.h"

#include "velox/expression/EvalCtx.h"
#include "velox/expression/Expr.h"
#include "velox/functions/sparksql/Comparisons.h"
#include "velox/type/Type.h"

namespace facebook::velox::functions::sparksql {
namespace {

template <typename Cmp, TypeKind kind>
class ComparisonFunction final : public exec::VectorFunction {
using T = typename TypeTraits<kind>::NativeType;

bool isDefaultNullBehavior() const override {
return true;
}

bool supportsFlatNoNullsFastPath() const override {
return true;
}

void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& result) const override {
context.ensureWritable(rows, BOOLEAN(), result);
auto* flatResult = result->asFlatVector<bool>();
flatResult->mutableRawValues<uint64_t>();
const Cmp cmp;
if (args[0]->isFlatEncoding() && args[1]->isFlatEncoding()) {
// Fast path for (flat, flat).
auto flatA = args[0]->asUnchecked<FlatVector<T>>();
auto rawA = flatA->mutableRawValues();
auto flatB = args[1]->asUnchecked<FlatVector<T>>();
auto rawB = flatB->mutableRawValues();
rows.applyToSelected(
[&](vector_size_t i) { flatResult->set(i, cmp(rawA[i], rawB[i])); });
} else if (args[0]->isConstantEncoding() && args[1]->isFlatEncoding()) {
// Fast path for (const, flat).
auto constant = args[0]->asUnchecked<SimpleVector<T>>()->valueAt(0);
auto flatValues = args[1]->asUnchecked<FlatVector<T>>();
auto rawValues = flatValues->mutableRawValues();
rows.applyToSelected([&](vector_size_t i) {
flatResult->set(i, cmp(constant, rawValues[i]));
});
} else if (args[0]->isFlatEncoding() && args[1]->isConstantEncoding()) {
// Fast path for (flat, const).
auto flatValues = args[0]->asUnchecked<FlatVector<T>>();
auto constant = args[1]->asUnchecked<SimpleVector<T>>()->valueAt(0);
auto rawValues = flatValues->mutableRawValues();
rows.applyToSelected([&](vector_size_t i) {
flatResult->set(i, cmp(rawValues[i], constant));
});
} else {
// Fast path if one or more arguments are encoded.
exec::DecodedArgs decodedArgs(rows, args, context);
auto decoded0 = decodedArgs.at(0);
auto decoded1 = decodedArgs.at(1);
rows.applyToSelected([&](vector_size_t i) {
flatResult->set(
i, cmp(decoded0->valueAt<T>(i), decoded1->valueAt<T>(i)));
});
}
}
};

template <template <typename> class Cmp>
std::shared_ptr<exec::VectorFunction> makeImpl(
const std::string& functionName,
const std::vector<exec::VectorFunctionArg>& args) {
VELOX_CHECK_EQ(args.size(), 2);
for (size_t i = 1; i < args.size(); i++) {
VELOX_CHECK(*args[i].type == *args[0].type);
}
switch (args[0].type->kind()) {
#define SCALAR_CASE(kind) \
case TypeKind::kind: \
return std::make_shared<ComparisonFunction< \
Cmp<TypeTraits<TypeKind::kind>::NativeType>, \
TypeKind::kind>>();
SCALAR_CASE(BOOLEAN)
SCALAR_CASE(TINYINT)
SCALAR_CASE(SMALLINT)
SCALAR_CASE(INTEGER)
SCALAR_CASE(BIGINT)
SCALAR_CASE(HUGEINT)
SCALAR_CASE(REAL)
SCALAR_CASE(DOUBLE)
SCALAR_CASE(VARCHAR)
SCALAR_CASE(VARBINARY)
SCALAR_CASE(TIMESTAMP)
#undef SCALAR_CASE
default:
VELOX_NYI(
"{} does not support arguments of type {}",
functionName,
args[0].type->kind());
}
}

} // namespace

std::shared_ptr<exec::VectorFunction> makeEqualTo(
const std::string& functionName,
const std::vector<exec::VectorFunctionArg>& args,
const core::QueryConfig& /*config*/) {
return makeImpl<Equal>(functionName, args);
}

std::shared_ptr<exec::VectorFunction> makeLessThan(
const std::string& functionName,
const std::vector<exec::VectorFunctionArg>& args,
const core::QueryConfig& /*config*/) {
return makeImpl<Less>(functionName, args);
}

std::shared_ptr<exec::VectorFunction> makeGreaterThan(
const std::string& functionName,
const std::vector<exec::VectorFunctionArg>& args,
const core::QueryConfig& /*config*/) {
return makeImpl<Greater>(functionName, args);
}

std::shared_ptr<exec::VectorFunction> makeLessThanOrEqual(
const std::string& functionName,
const std::vector<exec::VectorFunctionArg>& args,
const core::QueryConfig& /*config*/) {
return makeImpl<LessOrEqual>(functionName, args);
}

std::shared_ptr<exec::VectorFunction> makeGreaterThanOrEqual(
const std::string& functionName,
const std::vector<exec::VectorFunctionArg>& args,
const core::QueryConfig& /*config*/) {
return makeImpl<GreaterOrEqual>(functionName, args);
}
} // namespace facebook::velox::functions::sparksql
84 changes: 84 additions & 0 deletions velox/functions/sparksql/Comparisons.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <memory>

#include "velox/expression/VectorFunction.h"

namespace facebook::velox::functions::sparksql {
Expand Down Expand Up @@ -59,4 +63,84 @@ struct Equal {
}
};

template <typename T>
struct LessOrEqual {
constexpr bool operator()(const T& a, const T& b) const {
Less<T> less;
Equal<T> equal;
return less(a, b) || equal(a, b);
}
};

template <typename T>
struct GreaterOrEqual : private Less<T> {
constexpr bool operator()(const T& a, const T& b) const {
Less<T> less;
Equal<T> equal;
return less(b, a) || equal(a, b);
}
};

/// Supported Types:
/// TINYINT
/// SMALLINT
/// INTEGER
/// BIGINT
/// REAL
/// DOUBLE
/// BOOLEAN
/// VARCHAR
/// TIMESTAMP

/// Special cases:
/// NaN in Spark is handled differently from standard floating point semantics.
/// It is considered larger than any other numeric values.

std::shared_ptr<exec::VectorFunction> makeEqualTo(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/);

std::shared_ptr<exec::VectorFunction> makeLessThan(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/);

std::shared_ptr<exec::VectorFunction> makeGreaterThan(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/);

std::shared_ptr<exec::VectorFunction> makeLessThanOrEqual(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/);

std::shared_ptr<exec::VectorFunction> makeGreaterThanOrEqual(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/);

inline std::vector<std::shared_ptr<exec::FunctionSignature>>
comparisonSignatures() {
return {exec::FunctionSignatureBuilder()
.typeVariable("T")
.returnType("boolean")
.argumentType("T")
.argumentType("T")
.build()};
}

template <typename T>
struct BetweenFunction {
template <typename TInput>
FOLLY_ALWAYS_INLINE void call(
bool& result,
const TInput& value,
const TInput& low,
const TInput& high) {
result = value >= low && value <= high;
}
};

} // namespace facebook::velox::functions::sparksql
30 changes: 22 additions & 8 deletions velox/functions/sparksql/RegisterCompare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,28 @@
#include "velox/functions/sparksql/RegisterCompare.h"

#include "velox/functions/lib/RegistrationHelpers.h"
#include "velox/functions/prestosql/Comparisons.h"
#include "velox/functions/sparksql/CompareFunctionsNullSafe.h"
#include "velox/functions/sparksql/Comparisons.h"

namespace facebook::velox::functions::sparksql {

void registerCompareFunctions(const std::string& prefix) {
registerBinaryScalar<EqFunction, bool>({prefix + "equalto"});
registerBinaryScalar<NeqFunction, bool>({prefix + "notequalto"});
registerBinaryScalar<LtFunction, bool>({prefix + "lessthan"});
registerBinaryScalar<GtFunction, bool>({prefix + "greaterthan"});
registerBinaryScalar<LteFunction, bool>({prefix + "lessthanorequal"});
registerBinaryScalar<GteFunction, bool>({prefix + "greaterthanorequal"});

// Register compare functions
exec::registerStatefulVectorFunction(
prefix + "equalto", comparisonSignatures(), makeEqualTo);
exec::registerStatefulVectorFunction(
prefix + "lessthan", comparisonSignatures(), makeLessThan);
exec::registerStatefulVectorFunction(
prefix + "greaterthan", comparisonSignatures(), makeGreaterThan);
exec::registerStatefulVectorFunction(
prefix + "lessthanorequal", comparisonSignatures(), makeLessThanOrEqual);
exec::registerStatefulVectorFunction(
prefix + "greaterthanorequal",
comparisonSignatures(),
makeGreaterThanOrEqual);
// Compare nullsafe functions
exec::registerStatefulVectorFunction(
prefix + "equalnullsafe", equalNullSafeSignatures(), makeEqualNullSafe);
registerFunction<BetweenFunction, bool, int8_t, int8_t, int8_t>(
{prefix + "between"});
registerFunction<BetweenFunction, bool, int16_t, int16_t, int16_t>(
Expand All @@ -40,6 +50,10 @@ void registerCompareFunctions(const std::string& prefix) {
{prefix + "between"});
registerFunction<BetweenFunction, bool, float, float, float>(
{prefix + "between"});
registerFunction<BetweenFunction, bool, int64_t, int64_t, int64_t>(
{prefix + "between"});
registerFunction<BetweenFunction, bool, int128_t, int128_t, int128_t>(
{prefix + "between"});
}

} // namespace facebook::velox::functions::sparksql
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_executable(
ArraySortTest.cpp
BitwiseTest.cpp
CompareNullSafeTests.cpp
CompareTests.cpp
DateTimeFunctionsTest.cpp
ElementAtTest.cpp
HashTest.cpp
Expand Down
Loading

0 comments on commit 1c31582

Please sign in to comment.