Skip to content

Commit

Permalink
[OPPRO-279] Add bloom_filter_agg and might_contain SparkSql function (#…
Browse files Browse the repository at this point in the history
…79)

* add sparksql function bloom_filter_agg and might_contain

Change bit_ size to fix TPCDS performance

* change to stateful function

* optimize MightContain

* change back to spark value

* fix merge bloomfilter

* remove comment
  • Loading branch information
jinchengchenghh authored and zhejiangxiaomai committed Jan 31, 2023
1 parent f9cac51 commit b4b43f3
Show file tree
Hide file tree
Showing 15 changed files with 625 additions and 14 deletions.
63 changes: 57 additions & 6 deletions velox/common/base/BloomFilter.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
#include <cstring>
#include <utility>
#include <vector>

#include <folly/Hash.h>

#include "velox/common/base/BitUtil.h"
#include "velox/common/base/Exceptions.h"
#include "velox/common/base/IOUtils.h"
#include "velox/type/StringView.h"

namespace facebook::velox {

Expand All @@ -31,9 +33,15 @@ namespace facebook::velox {
// expected entry, we get ~2% false positives. 'hashInput' determines
// if the value added or checked needs to be hashed. If this is false,
// we assume that the input is already a 64 bit hash number.
template <bool hashInput = true>
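// Constraints on 'InputType':
// - hashInput = true: the value is hashed with folly::hasher<InputType>, so
//   InputType can be any type supported by folly::hasher.
// - hashInput = false: the value is used directly as the 64-bit hash, so
//   InputType should be uint64_t (or an integer type convertible to it).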
template <class InputType = uint64_t, bool hashInput = true>
class BloomFilter {
public:
BloomFilter(){};
BloomFilter(std::vector<uint64_t> bits) : bits_(bits){};

// Prepares 'this' for use with an expected 'capacity'
// entries. Drops any prior content.
void reset(int32_t capacity) {
Expand All @@ -42,18 +50,61 @@ class BloomFilter {
bits_.resize(std::max<int32_t>(4, bits::nextPowerOfTwo(capacity) / 4));
}

bool isSet() {
return bits_.size() > 0;
}

// Adds 'value'. When 'hashInput' is true the value is first hashed with
// folly::hasher<InputType>; otherwise it is used as the hash as-is.
void insert(InputType value) {
  // 'if constexpr' keeps folly::hasher<InputType> from being instantiated
  // when the input is already a hash.
  if constexpr (hashInput) {
    set(bits_.data(), bits_.size(), folly::hasher<InputType>()(value));
  } else {
    set(bits_.data(), bits_.size(), value);
  }
}

// Returns the result of probing the filter for 'value'. When 'hashInput' is
// true the value is first hashed with folly::hasher<InputType>; otherwise it
// is used as the hash as-is. Must use the same hashing mode as insert().
bool mayContain(InputType value) const {
  // 'if constexpr' keeps folly::hasher<InputType> from being instantiated
  // when the input is already a hash.
  if constexpr (hashInput) {
    return test(bits_.data(), bits_.size(), folly::hasher<InputType>()(value));
  } else {
    return test(bits_.data(), bits_.size(), value);
  }
}

// Combines the two bloomFilter bits_ using bitwise OR.
void merge(BloomFilter& bloomFilter) {
if (bits_.size() == 0) {
bits_ = bloomFilter.bits_;
return;
} else if (bloomFilter.bits_.size() == 0) {
return;
}
VELOX_CHECK_EQ(bits_.size(), bloomFilter.bits_.size());
for (auto i = 0; i < bloomFilter.bits_.size(); i++) {
bits_[i] |= bloomFilter.bits_[i];
}
}

uint32_t serializedSize() {
return 4 /* number of bits */
+ bits_.size() * 8;
}

// Writes the filter into the memory 'output' points at: first the number of
// words as int32_t, then every word of 'bits_' in order.
// No bounds check is performed here, so the buffer behind 'output' must be
// writable and hold at least serializedSize() bytes.
void serialize(StringView& output) const {
  // StringView only exposes a const pointer; the caller owns the underlying
  // buffer, so the constness is cast away to write into it.
  char* outputBuffer = const_cast<char*>(output.data());
  common::OutputByteStream stream(outputBuffer);
  stream.appendOne(static_cast<int32_t>(bits_.size()));
  for (const auto word : bits_) {
    stream.appendOne(word);
  }
}

static void deserialize(const char* serialized, BloomFilter& output) {
common::InputByteStream stream(serialized);
auto size = stream.read<int32_t>();
output.bits_.resize(size);
auto bitsdata =
reinterpret_cast<const uint64_t*>(serialized + stream.offset());
for (auto i = 0; i < size; i++) {
output.bits_[i] = bitsdata[i];
}
}

private:
Expand Down
45 changes: 44 additions & 1 deletion velox/common/base/tests/BloomFilterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ using namespace facebook::velox;

TEST(BloomFilterTest, basic) {
constexpr int32_t kSize = 1024;
BloomFilter bloom;
BloomFilter<int32_t> bloom;
bloom.reset(kSize);
for (auto i = 0; i < kSize; ++i) {
bloom.insert(i);
Expand All @@ -37,3 +37,46 @@ TEST(BloomFilterTest, basic) {
}
EXPECT_GT(2, 100 * numFalsePositives / kSize);
}

// Round-trips a populated filter through serialize()/deserialize() and checks
// that every inserted value is still reported and the size is unchanged.
TEST(BloomFilterTest, serialize) {
  constexpr int32_t kSize = 1024;

  BloomFilter<int32_t> original;
  original.reset(kSize);
  for (int value = 0; value < kSize; ++value) {
    original.insert(value);
  }

  // Zero-filled buffer of exactly the size the filter asks for.
  std::string buffer(original.serializedSize(), '\0');
  StringView view(buffer.data(), buffer.size());
  original.serialize(view);

  BloomFilter<int32_t> restored;
  BloomFilter<int32_t>::deserialize(buffer.data(), restored);

  for (int value = 0; value < kSize; ++value) {
    EXPECT_TRUE(restored.mayContain(value));
  }
  EXPECT_EQ(original.serializedSize(), restored.serializedSize());
}

// Builds two equally sized filters over disjoint value ranges, merges the
// second into the first, and checks the result reports both ranges.
TEST(BloomFilterTest, merge) {
  constexpr int32_t kSize = 10;

  // Holds [0, kSize).
  BloomFilter<int32_t> target;
  target.reset(kSize);
  for (int value = 0; value < kSize; ++value) {
    target.insert(value);
  }

  // Holds [kSize, 2 * kSize).
  BloomFilter<int32_t> other;
  other.reset(kSize);
  for (int value = kSize; value < kSize + kSize; ++value) {
    other.insert(value);
  }

  target.merge(other);

  for (int value = 0; value < kSize + kSize; ++value) {
    EXPECT_TRUE(target.mayContain(value));
  }
  // Merging must not change the size of the filter.
  EXPECT_EQ(target.serializedSize(), other.serializedSize());
}
8 changes: 6 additions & 2 deletions velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ add_library(
Size.cpp
SplitFunctions.cpp
String.cpp
Subscript.cpp)
Subscript.cpp
MightContain.cpp)

target_link_libraries(
velox_functions_spark velox_functions_lib velox_functions_prestosql_impl
Expand All @@ -36,9 +37,12 @@ target_link_libraries(
set_property(TARGET velox_functions_spark PROPERTY JOB_POOL_COMPILE
high_memory_pool)

if(${VELOX_ENABLE_AGGREGATES})
add_subdirectory(aggregates)
endif()

if(${VELOX_BUILD_TESTING})
add_subdirectory(tests)
add_subdirectory(aggregates)
endif()

if(${VELOX_ENABLE_BENCHMARKS})
Expand Down
84 changes: 84 additions & 0 deletions velox/functions/sparksql/MightContain.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/functions/sparksql/MightContain.h"

#include "velox/common/base/BloomFilter.h"
#include "velox/expression/DecodedArgs.h"
#include "velox/vector/FlatVector.h"

#include <glog/logging.h>

namespace facebook::velox::functions::sparksql {
namespace {
class BloomFilterMightContainFunction final : public exec::VectorFunction {
bool isDefaultNullBehavior() const final {
return false;
}

void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args, // Not using const ref so we can reuse args
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& resultRef) const final {
VELOX_CHECK_EQ(args.size(), 2);
context.ensureWritable(rows, BOOLEAN(), resultRef);
auto& result = *resultRef->as<FlatVector<bool>>();
exec::DecodedArgs decodedArgs(rows, args, context);
auto serialized = decodedArgs.at(0);
auto value = decodedArgs.at(1);
if (serialized->isConstantMapping() && serialized->isNullAt(0)) {
rows.applyToSelected([&](int row) { result.setNull(row, true); });
return;
}

if (serialized->isConstantMapping()) {
BloomFilter<int64_t, false> output;
auto serializedBloom = serialized->valueAt<StringView>(0);
BloomFilter<int64_t, false>::deserialize(serializedBloom.data(), output);
rows.applyToSelected([&](int row) {
result.set(row, output.mayContain(value->valueAt<int64_t>(row)));
});
return;
}

rows.applyToSelected([&](int row) {
BloomFilter<int64_t, false> output;
auto serializedBloom = serialized->valueAt<StringView>(row);
BloomFilter<int64_t, false>::deserialize(serializedBloom.data(), output);
result.set(row, output.mayContain(value->valueAt<int64_t>(row)));
});
}
};
} // namespace

// Returns the single supported signature:
//   might_contain(varbinary, bigint) -> boolean
std::vector<std::shared_ptr<exec::FunctionSignature>> mightContainSignatures() {
  exec::FunctionSignatureBuilder builder;
  builder.returnType("boolean");
  builder.argumentType("varbinary");
  builder.argumentType("bigint");

  std::vector<std::shared_ptr<exec::FunctionSignature>> signatures;
  signatures.push_back(builder.build());
  return signatures;
}

std::shared_ptr<exec::VectorFunction> makeMightContain(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs) {
static const auto kHashFunction =
std::make_shared<BloomFilterMightContainFunction>();
return kHashFunction;
}

} // namespace facebook::velox::functions::sparksql
26 changes: 26 additions & 0 deletions velox/functions/sparksql/MightContain.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "velox/expression/VectorFunction.h"

namespace facebook::velox::functions::sparksql {

// Returns the signatures supported by might_contain:
//   might_contain(varbinary, bigint) -> boolean
std::vector<std::shared_ptr<exec::FunctionSignature>> mightContainSignatures();

// Factory for the might_contain vector function, suitable for
// exec::registerStatefulVectorFunction.
std::shared_ptr<exec::VectorFunction> makeMightContain(
    const std::string& name,
    const std::vector<exec::VectorFunctionArg>& inputArgs);

} // namespace facebook::velox::functions::sparksql
8 changes: 6 additions & 2 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "velox/functions/sparksql/Hash.h"
#include "velox/functions/sparksql/In.h"
#include "velox/functions/sparksql/LeastGreatest.h"
#include "velox/functions/sparksql/MightContain.h"
#include "velox/functions/sparksql/RegexFunctions.h"
#include "velox/functions/sparksql/RegisterArithmetic.h"
#include "velox/functions/sparksql/RegisterCompare.h"
Expand Down Expand Up @@ -149,8 +150,9 @@ void registerFunctions(const std::string& prefix) {
exec::registerStatefulVectorFunction(
prefix + "sort_array", sortArraySignatures(), makeSortArray);

registerFunction<YearFunction, int32_t, Timestamp>({"year"});
registerFunction<YearFunction, int32_t, Date>({"year"});
// Register bloom filter function
exec::registerStatefulVectorFunction(
prefix + "might_contain", mightContainSignatures(), makeMightContain);
// Register DateTime functions.
registerFunction<MillisecondFunction, int32_t, Date>(
{prefix + "millisecond"});
Expand Down Expand Up @@ -194,6 +196,8 @@ void registerFunctions(const std::string& prefix) {
registerFunction<QuarterFunction, int32_t, Timestamp>({prefix + "quarter"});
registerFunction<QuarterFunction, int32_t, TimestampWithTimezone>(
{prefix + "quarter"});
registerFunction<YearFunction, int32_t, Date>({prefix + "year"});
registerFunction<YearFunction, int32_t, Timestamp>({prefix + "year"});
registerFunction<YearOfWeekFunction, int32_t, Date>(
{prefix + "year_of_week"});
registerFunction<YearOfWeekFunction, int32_t, Timestamp>(
Expand Down
Loading

0 comments on commit b4b43f3

Please sign in to comment.