Skip to content

Commit

Permalink
add sparksql function bloom_filter_agg and might_contain
Browse files Browse the repository at this point in the history
Change bit_ size to fix TPCDS performance
  • Loading branch information
jinchengchenghh committed Jan 6, 2023
1 parent 277e0bf commit a02646f
Show file tree
Hide file tree
Showing 14 changed files with 567 additions and 38 deletions.
61 changes: 56 additions & 5 deletions velox/common/base/BloomFilter.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
#include <cstring>
#include <utility>

#include <folly/Hash.h>

#include "velox/common/base/BitUtil.h"
#include "velox/common/base/Exceptions.h"
#include "velox/common/base/IOUtils.h"
#include "velox/type/StringView.h"

namespace facebook::velox {

Expand All @@ -31,9 +34,15 @@ namespace facebook::velox {
// expected entry, we get ~2% false positives. 'hashInput' determines
// if the value added or checked needs to be hashed. If this is false,
// we assume that the input is already a 64 bit hash number.
template <bool hashInput = true>
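// Template parameters:
// - When hashInput = true, InputType can be any type supported by
//   folly::hasher; the value is hashed before it is added or checked.
// - When hashInput = false, InputType must already be a 64-bit hash value
//   (e.g. uint64_t); it is used as is.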
template <class InputType = uint64_t, bool hashInput = true>
class BloomFilter {
public:
BloomFilter(){};
BloomFilter(std::vector<uint64_t> bits) : bits_(bits){};

// Prepares 'this' for use with an expected 'capacity'
// entries. Drops any prior content.
void reset(int32_t capacity) {
Expand All @@ -42,18 +51,60 @@ class BloomFilter {
bits_.resize(std::max<int32_t>(4, bits::nextPowerOfTwo(capacity) / 4));
}

bool isSet() {
return bits_.size() > 0;
}

// Adds 'value'.
void insert(uint64_t value) {
void insert(InputType value) {
set(bits_.data(),
bits_.size(),
hashInput ? folly::hasher<uint64_t>()(value) : value);
hashInput ? folly::hasher<InputType>()(value) : value);
}

bool mayContain(uint64_t value) const {
bool mayContain(InputType value) const {
return test(
bits_.data(),
bits_.size(),
hashInput ? folly::hasher<uint64_t>()(value) : value);
hashInput ? folly::hasher<InputType>()(value) : value);
}

// Combines the two bloomFilter bits_ using bitwise OR.
void merge(BloomFilter& bloomFilter) {
if (bits_.size() == 0) {
bits_ = bloomFilter.bits_;
return;
} else if (bloomFilter.bits_.size() == 0){
VELOX_FAIL("Input bit length should not be 0");
}
VELOX_CHECK_EQ(bits_.size(), bloomFilter.bits_.size());
for (auto i = 0; i < bloomFilter.bits_.size(); i++) {
bits_[i] |= bloomFilter.bits_[i];
}
}

uint32_t serializedSize() {
return 4 /* number of bits */
+ bits_.size() * 8;
}

// Writes the filter into the memory that 'output' views: the number of
// 64-bit words as an int32_t, followed by the words themselves.
//
// 'output' must span at least sizeof(int32_t) + 8 * <number of words> bytes;
// a smaller view raises a Velox error instead of writing past its end.
// NOTE(review): the bytes are written through output.data() with constness
// cast away, so 'output' has to view writable memory owned by the caller.
void serialize(StringView& output) const {
  VELOX_CHECK_GE(
      output.size(),
      sizeof(int32_t) + bits_.size() * sizeof(uint64_t),
      "Output buffer is too small for the serialized bloom filter");
  char* outputBuffer = const_cast<char*>(output.data());
  common::OutputByteStream stream(outputBuffer);
  stream.appendOne(static_cast<int32_t>(bits_.size()));
  for (const auto word : bits_) {
    stream.appendOne(word);
  }
}

static void deserialize(const char* serialized, BloomFilter& output) {
common::InputByteStream stream(serialized);
auto size = stream.read<int32_t>();
output.bits_.resize(size);
auto bitsdata = reinterpret_cast<const uint64_t*>(serialized + stream.offset());
for (auto i = 0; i < size; i++) {
output.bits_[i] = bitsdata[i];
}
}

private:
Expand Down
45 changes: 44 additions & 1 deletion velox/common/base/tests/BloomFilterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ using namespace facebook::velox;

TEST(BloomFilterTest, basic) {
constexpr int32_t kSize = 1024;
BloomFilter bloom;
BloomFilter<int32_t> bloom;
bloom.reset(kSize);
for (auto i = 0; i < kSize; ++i) {
bloom.insert(i);
Expand All @@ -37,3 +37,46 @@ TEST(BloomFilterTest, basic) {
}
EXPECT_GT(2, 100 * numFalsePositives / kSize);
}

// Serializes a filled filter into a string buffer, reads it back into a fresh
// filter, and verifies that every inserted value is still reported and that
// both filters report the same serialized size.
TEST(BloomFilterTest, serialize) {
  constexpr int32_t kSize = 1024;

  BloomFilter<int32_t> original;
  original.reset(kSize);
  for (int32_t value = 0; value != kSize; ++value) {
    original.insert(value);
  }

  // Size the buffer up front so serialize() can write straight into it.
  std::string buffer(original.serializedSize(), '\0');
  StringView view(buffer.data(), buffer.size());
  original.serialize(view);

  BloomFilter<int32_t> restored;
  BloomFilter<int32_t>::deserialize(buffer.data(), restored);
  for (int32_t value = 0; value != kSize; ++value) {
    EXPECT_TRUE(restored.mayContain(value));
  }

  EXPECT_EQ(original.serializedSize(), restored.serializedSize());
}

// Fills two equally sized filters with disjoint value ranges, merges the
// second into the first, and verifies that the merged filter reports every
// value from both ranges while keeping the same serialized size.
TEST(BloomFilterTest, merge) {
  constexpr int32_t kSize = 10;
  constexpr int32_t kTotal = kSize + kSize;

  BloomFilter<int32_t> target;
  target.reset(kSize);
  BloomFilter<int32_t> other;
  other.reset(kSize);

  // 'target' receives [0, kSize), 'other' receives [kSize, 2 * kSize).
  for (int32_t value = 0; value != kSize; ++value) {
    target.insert(value);
    other.insert(value + kSize);
  }

  target.merge(other);

  for (int32_t value = 0; value != kTotal; ++value) {
    EXPECT_TRUE(target.mayContain(value));
  }

  EXPECT_EQ(target.serializedSize(), other.serializedSize());
}
5 changes: 4 additions & 1 deletion velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,12 @@ target_link_libraries(
set_property(TARGET velox_functions_spark PROPERTY JOB_POOL_COMPILE
high_memory_pool)

if(${VELOX_ENABLE_AGGREGATES})
add_subdirectory(aggregates)
endif()

if(${VELOX_BUILD_TESTING})
add_subdirectory(tests)
add_subdirectory(aggregates)
endif()

if(${VELOX_ENABLE_BENCHMARKS})
Expand Down
38 changes: 38 additions & 0 deletions velox/functions/sparksql/MightContain.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/common/base/BloomFilter.h"
#include "velox/expression/VectorFunction.h"
#include "velox/functions/Macros.h"
#include "velox/functions/lib/string/StringImpl.h"

namespace facebook::velox::functions::sparksql {

/// Implements might_contain(serializedBloom, value): probes a serialized bloom
/// filter for a 64-bit value.
///
/// The filter is deserialized as BloomFilter<int64_t, false>, so 'value' is
/// used directly as the hash and is not hashed again.
/// NOTE(review): the filter is rebuilt from its serialized bytes on every
/// call; if the first argument is usually constant, caching the deserialized
/// filter would avoid that repeated work - confirm against callers.
template <typename T>
struct MightContainFunction {
  VELOX_DEFINE_FUNCTION_TYPES(T);

  /// @param result Set to true if the filter may contain 'value', false if it
  ///     does not.
  /// @param serializedBloom Bytes in the layout read by
  ///     BloomFilter::deserialize().
  /// @param value Value to probe for.
  /// @return Always true, i.e. the result is never null.
  FOLLY_ALWAYS_INLINE bool call(
      out_type<bool>& result,
      const arg_type<Varbinary>& serializedBloom,
      const arg_type<int64_t>& value) {
    BloomFilter<int64_t, false> output;
    // deserialize() only reads from the buffer, so point it at the argument
    // bytes directly rather than copying them into a temporary std::string.
    BloomFilter<int64_t, false>::deserialize(serializedBloom.data(), output);
    // The boolean returned by call() is the "result is not null" flag; the
    // answer itself has to be stored in 'result'.
    result = output.mayContain(value);
    return true;
  }
};

} // namespace facebook::velox::functions::sparksql
47 changes: 19 additions & 28 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "velox/functions/sparksql/Hash.h"
#include "velox/functions/sparksql/In.h"
#include "velox/functions/sparksql/LeastGreatest.h"
#include "velox/functions/sparksql/MightContain.h"
#include "velox/functions/sparksql/RegexFunctions.h"
#include "velox/functions/sparksql/RegisterArithmetic.h"
#include "velox/functions/sparksql/RegisterCompare.h"
Expand Down Expand Up @@ -156,22 +157,16 @@ void registerFunctions(const std::string& prefix) {
{prefix + "millisecond"});
registerFunction<MillisecondFunction, int32_t, TimestampWithTimezone>(
{prefix + "millisecond"});
registerFunction<SecondFunction, int32_t, Date>(
{prefix + "second"});
registerFunction<SecondFunction, int32_t, Timestamp>(
{prefix + "second"});
registerFunction<SecondFunction, int32_t, Date>({prefix + "second"});
registerFunction<SecondFunction, int32_t, Timestamp>({prefix + "second"});
registerFunction<SecondFunction, int32_t, TimestampWithTimezone>(
{prefix + "second"});
registerFunction<MinuteFunction, int32_t, Date>(
{prefix + "minute"});
registerFunction<MinuteFunction, int32_t, Timestamp>(
{prefix + "minute"});
registerFunction<MinuteFunction, int32_t, Date>({prefix + "minute"});
registerFunction<MinuteFunction, int32_t, Timestamp>({prefix + "minute"});
registerFunction<MinuteFunction, int32_t, TimestampWithTimezone>(
{prefix + "minute"});
registerFunction<HourFunction, int32_t, Date>(
{prefix + "hour"});
registerFunction<HourFunction, int32_t, Timestamp>(
{prefix + "hour"});
registerFunction<HourFunction, int32_t, Date>({prefix + "hour"});
registerFunction<HourFunction, int32_t, Timestamp>({prefix + "hour"});
registerFunction<HourFunction, int32_t, TimestampWithTimezone>(
{prefix + "hour"});
registerFunction<DayFunction, int32_t, Date>(
Expand All @@ -180,34 +175,26 @@ void registerFunctions(const std::string& prefix) {
{prefix + "day", prefix + "day_of_month"});
registerFunction<DayFunction, int32_t, TimestampWithTimezone>(
{prefix + "day", prefix + "day_of_month"});
registerFunction<DayOfWeekFunction, int32_t, Date>(
{prefix + "day_of_week"});
registerFunction<DayOfWeekFunction, int32_t, Date>({prefix + "day_of_week"});
registerFunction<DayOfWeekFunction, int32_t, Timestamp>(
{prefix + "day_of_week"});
registerFunction<DayOfWeekFunction, int32_t, TimestampWithTimezone>(
{prefix + "day_of_week"});
registerFunction<DayOfYearFunction, int32_t, Date>(
{prefix + "day_of_year"});
registerFunction<DayOfYearFunction, int32_t, Date>({prefix + "day_of_year"});
registerFunction<DayOfYearFunction, int32_t, Timestamp>(
{prefix + "day_of_year"});
registerFunction<DayOfYearFunction, int32_t, TimestampWithTimezone>(
{prefix + "day_of_year"});
registerFunction<MonthFunction, int32_t, Date>(
{prefix + "month"});
registerFunction<MonthFunction, int32_t, Timestamp>(
{prefix + "month"});
registerFunction<MonthFunction, int32_t, Date>({prefix + "month"});
registerFunction<MonthFunction, int32_t, Timestamp>({prefix + "month"});
registerFunction<MonthFunction, int32_t, TimestampWithTimezone>(
{prefix + "month"});
registerFunction<QuarterFunction, int32_t, Date>(
{prefix + "quarter"});
registerFunction<QuarterFunction, int32_t, Timestamp>(
{prefix + "quarter"});
registerFunction<QuarterFunction, int32_t, Date>({prefix + "quarter"});
registerFunction<QuarterFunction, int32_t, Timestamp>({prefix + "quarter"});
registerFunction<QuarterFunction, int32_t, TimestampWithTimezone>(
{prefix + "quarter"});
registerFunction<YearFunction, int32_t, Date>(
{prefix + "year"});
registerFunction<YearFunction, int32_t, Timestamp>(
{prefix + "year"});
registerFunction<YearFunction, int32_t, Date>({prefix + "year"});
registerFunction<YearFunction, int32_t, Timestamp>({prefix + "year"});
registerFunction<YearFunction, int32_t, TimestampWithTimezone>(
{prefix + "year"});
registerFunction<YearOfWeekFunction, int32_t, Date>(
Expand All @@ -216,6 +203,10 @@ void registerFunctions(const std::string& prefix) {
{prefix + "year_of_week"});
registerFunction<YearOfWeekFunction, int32_t, TimestampWithTimezone>(
{prefix + "year_of_week"});

// Register bloom filter function
registerFunction<MightContainFunction, bool, Varbinary, int64_t>(
{prefix + "might_contain"});
}

} // namespace sparksql
Expand Down
Loading

0 comments on commit a02646f

Please sign in to comment.