Skip to content

Commit

Permalink
[OPPRO-279] Add bloom_filter_agg and might_contain SparkSql function (o…
Browse files Browse the repository at this point in the history
…ap-project#79)

* add sparksql function bloom_filter_agg and might_contain

Change bit_ size to fix TPCDS performance

* change to statefil function

* optimize MightContain

* change back to spark value

* fix merge bloomfilter

* remove comment
  • Loading branch information
jinchengchenghh authored and zhejiangxiaomai committed Feb 27, 2023
1 parent abcd915 commit 178b030
Show file tree
Hide file tree
Showing 14 changed files with 525 additions and 14 deletions.
2 changes: 1 addition & 1 deletion velox/common/base/tests/BloomFilterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class BloomFilterTest : public ::testing::Test {};

TEST_F(BloomFilterTest, basic) {
constexpr int32_t kSize = 1024;
BloomFilter bloom;
BloomFilter<int32_t> bloom;
bloom.reset(kSize);
for (auto i = 0; i < kSize; ++i) {
bloom.insert(folly::hasher<int32_t>()(i));
Expand Down
8 changes: 6 additions & 2 deletions velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ add_library(
Size.cpp
SplitFunctions.cpp
String.cpp
Subscript.cpp)
Subscript.cpp
MightContain.cpp)

target_link_libraries(
velox_functions_spark velox_functions_lib velox_functions_prestosql_impl
Expand All @@ -36,9 +37,12 @@ target_link_libraries(
set_property(TARGET velox_functions_spark PROPERTY JOB_POOL_COMPILE
high_memory_pool)

if(${VELOX_ENABLE_AGGREGATES})
add_subdirectory(aggregates)
endif()

if(${VELOX_BUILD_TESTING})
add_subdirectory(tests)
add_subdirectory(aggregates)
endif()

if(${VELOX_ENABLE_BENCHMARKS})
Expand Down
84 changes: 84 additions & 0 deletions velox/functions/sparksql/MightContain.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/functions/sparksql/MightContain.h"

#include "velox/common/base/BloomFilter.h"
#include "velox/expression/DecodedArgs.h"
#include "velox/vector/FlatVector.h"

#include <glog/logging.h>

namespace facebook::velox::functions::sparksql {
namespace {
class BloomFilterMightContainFunction final : public exec::VectorFunction {
bool isDefaultNullBehavior() const final {
return false;
}

void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args, // Not using const ref so we can reuse args
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& resultRef) const final {
VELOX_CHECK_EQ(args.size(), 2);
context.ensureWritable(rows, BOOLEAN(), resultRef);
auto& result = *resultRef->as<FlatVector<bool>>();
exec::DecodedArgs decodedArgs(rows, args, context);
auto serialized = decodedArgs.at(0);
auto value = decodedArgs.at(1);
if (serialized->isConstantMapping() && serialized->isNullAt(0)) {
rows.applyToSelected([&](int row) { result.setNull(row, true); });
return;
}

if (serialized->isConstantMapping()) {
BloomFilter<int64_t, false> output;
auto serializedBloom = serialized->valueAt<StringView>(0);
BloomFilter<int64_t, false>::deserialize(serializedBloom.data(), output);
rows.applyToSelected([&](int row) {
result.set(row, output.mayContain(value->valueAt<int64_t>(row)));
});
return;
}

rows.applyToSelected([&](int row) {
BloomFilter<int64_t, false> output;
auto serializedBloom = serialized->valueAt<StringView>(row);
BloomFilter<int64_t, false>::deserialize(serializedBloom.data(), output);
result.set(row, output.mayContain(value->valueAt<int64_t>(row)));
});
}
};
} // namespace

std::vector<std::shared_ptr<exec::FunctionSignature>> mightContainSignatures() {
return {exec::FunctionSignatureBuilder()
.returnType("boolean")
.argumentType("varbinary")
.argumentType("bigint")
.build()};
}

std::shared_ptr<exec::VectorFunction> makeMightContain(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs) {
static const auto kHashFunction =
std::make_shared<BloomFilterMightContainFunction>();
return kHashFunction;
}

} // namespace facebook::velox::functions::sparksql
26 changes: 26 additions & 0 deletions velox/functions/sparksql/MightContain.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/expression/VectorFunction.h"

namespace facebook::velox::functions::sparksql {

std::vector<std::shared_ptr<exec::FunctionSignature>> mightContainSignatures();

std::shared_ptr<exec::VectorFunction> makeMightContain(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs);

} // namespace facebook::velox::functions::sparksql
8 changes: 6 additions & 2 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "velox/functions/sparksql/Hash.h"
#include "velox/functions/sparksql/In.h"
#include "velox/functions/sparksql/LeastGreatest.h"
#include "velox/functions/sparksql/MightContain.h"
#include "velox/functions/sparksql/RegexFunctions.h"
#include "velox/functions/sparksql/RegisterArithmetic.h"
#include "velox/functions/sparksql/RegisterCompare.h"
Expand Down Expand Up @@ -151,8 +152,9 @@ void registerFunctions(const std::string& prefix) {
exec::registerStatefulVectorFunction(
prefix + "sort_array", sortArraySignatures(), makeSortArray);

registerFunction<YearFunction, int32_t, Timestamp>({"year"});
registerFunction<YearFunction, int32_t, Date>({"year"});
// Register bloom filter function
exec::registerStatefulVectorFunction(
prefix + "might_contain", mightContainSignatures(), makeMightContain);
// Register DateTime functions.
registerFunction<MillisecondFunction, int32_t, Date>(
{prefix + "millisecond"});
Expand Down Expand Up @@ -196,6 +198,8 @@ void registerFunctions(const std::string& prefix) {
registerFunction<QuarterFunction, int32_t, Timestamp>({prefix + "quarter"});
registerFunction<QuarterFunction, int32_t, TimestampWithTimezone>(
{prefix + "quarter"});
registerFunction<YearFunction, int32_t, Date>({prefix + "year"});
registerFunction<YearFunction, int32_t, Timestamp>({prefix + "year"});
registerFunction<YearOfWeekFunction, int32_t, Date>(
{prefix + "year_of_week"});
registerFunction<YearOfWeekFunction, int32_t, Timestamp>(
Expand Down
Loading

0 comments on commit 178b030

Please sign in to comment.