From 1c31582714d978bc3d14e1082732bb45e3c55f2d Mon Sep 17 00:00:00 2001 From: mayan Date: Mon, 10 Jul 2023 09:57:39 +0000 Subject: [PATCH] Implement Compare functions for Spark --- velox/functions/sparksql/CMakeLists.txt | 1 + velox/functions/sparksql/Comparisons.cpp | 155 +++++++++ velox/functions/sparksql/Comparisons.h | 84 +++++ velox/functions/sparksql/RegisterCompare.cpp | 30 +- velox/functions/sparksql/tests/CMakeLists.txt | 1 + .../functions/sparksql/tests/CompareTests.cpp | 298 ++++++++++++++++++ 6 files changed, 561 insertions(+), 8 deletions(-) create mode 100644 velox/functions/sparksql/Comparisons.cpp create mode 100644 velox/functions/sparksql/tests/CompareTests.cpp diff --git a/velox/functions/sparksql/CMakeLists.txt b/velox/functions/sparksql/CMakeLists.txt index b9ec0498a589..76ffce938e92 100644 --- a/velox/functions/sparksql/CMakeLists.txt +++ b/velox/functions/sparksql/CMakeLists.txt @@ -17,6 +17,7 @@ add_library( ArraySort.cpp Bitwise.cpp CompareFunctionsNullSafe.cpp + Comparisons.cpp Hash.cpp In.cpp LeastGreatest.cpp diff --git a/velox/functions/sparksql/Comparisons.cpp b/velox/functions/sparksql/Comparisons.cpp new file mode 100644 index 000000000000..4f1616917b35 --- /dev/null +++ b/velox/functions/sparksql/Comparisons.cpp @@ -0,0 +1,155 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/sparksql/LeastGreatest.h" + +#include "velox/expression/EvalCtx.h" +#include "velox/expression/Expr.h" +#include "velox/functions/sparksql/Comparisons.h" +#include "velox/type/Type.h" + +namespace facebook::velox::functions::sparksql { +namespace { + +template +class ComparisonFunction final : public exec::VectorFunction { + using T = typename TypeTraits::NativeType; + + bool isDefaultNullBehavior() const override { + return true; + } + + bool supportsFlatNoNullsFastPath() const override { + return true; + } + + void apply( + const SelectivityVector& rows, + std::vector& args, + const TypePtr& outputType, + exec::EvalCtx& context, + VectorPtr& result) const override { + context.ensureWritable(rows, BOOLEAN(), result); + auto* flatResult = result->asFlatVector(); + flatResult->mutableRawValues(); + const Cmp cmp; + if (args[0]->isFlatEncoding() && args[1]->isFlatEncoding()) { + // Fast path for (flat, flat). + auto flatA = args[0]->asUnchecked>(); + auto rawA = flatA->mutableRawValues(); + auto flatB = args[1]->asUnchecked>(); + auto rawB = flatB->mutableRawValues(); + rows.applyToSelected( + [&](vector_size_t i) { flatResult->set(i, cmp(rawA[i], rawB[i])); }); + } else if (args[0]->isConstantEncoding() && args[1]->isFlatEncoding()) { + // Fast path for (const, flat). + auto constant = args[0]->asUnchecked>()->valueAt(0); + auto flatValues = args[1]->asUnchecked>(); + auto rawValues = flatValues->mutableRawValues(); + rows.applyToSelected([&](vector_size_t i) { + flatResult->set(i, cmp(constant, rawValues[i])); + }); + } else if (args[0]->isFlatEncoding() && args[1]->isConstantEncoding()) { + // Fast path for (flat, const). + auto flatValues = args[0]->asUnchecked>(); + auto constant = args[1]->asUnchecked>()->valueAt(0); + auto rawValues = flatValues->mutableRawValues(); + rows.applyToSelected([&](vector_size_t i) { + flatResult->set(i, cmp(rawValues[i], constant)); + }); + } else { + // Fast path if one or more arguments are encoded. + exec::DecodedArgs decodedArgs(rows, args, context); + auto decoded0 = decodedArgs.at(0); + auto decoded1 = decodedArgs.at(1); + rows.applyToSelected([&](vector_size_t i) { + flatResult->set( + i, cmp(decoded0->valueAt(i), decoded1->valueAt(i))); + }); + } + } +}; + +template