Skip to content

Commit

Permalink
[OPPRO-279] Add bloom_filter_agg and might_contain SparkSql function (o…
Browse files Browse the repository at this point in the history
…ap-project#79)

* add sparksql function bloom_filter_agg and might_contain

Change bit_ size to fix TPCDS performance

* change to stateful function

* optimize MightContain

* change back to spark value

* fix merge bloomfilter

* remove comment
  • Loading branch information
jinchengchenghh authored and zhejiangxiaomai committed Apr 14, 2023
1 parent 30e5614 commit d228dd9
Show file tree
Hide file tree
Showing 13 changed files with 364 additions and 14 deletions.
2 changes: 1 addition & 1 deletion velox/common/base/tests/BloomFilterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class BloomFilterTest : public ::testing::Test {};

TEST_F(BloomFilterTest, basic) {
constexpr int32_t kSize = 1024;
BloomFilter bloom;
BloomFilter<int32_t> bloom;
bloom.reset(kSize);
for (auto i = 0; i < kSize; ++i) {
bloom.insert(folly::hasher<int32_t>()(i));
Expand Down
8 changes: 6 additions & 2 deletions velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ add_library(
RegisterCompare.cpp
Size.cpp
SplitFunctions.cpp
String.cpp)
String.cpp
MightContain.cpp)

target_link_libraries(
velox_functions_spark velox_functions_lib velox_functions_prestosql_impl
Expand All @@ -36,9 +37,12 @@ target_link_libraries(
set_property(TARGET velox_functions_spark PROPERTY JOB_POOL_COMPILE
high_memory_pool)

if(${VELOX_ENABLE_AGGREGATES})
add_subdirectory(aggregates)
endif()

if(${VELOX_BUILD_TESTING})
add_subdirectory(tests)
add_subdirectory(aggregates)
add_subdirectory(coverage)
endif()

Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/MightContain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "velox/functions/sparksql/MightContain.h"

#include "velox/common/base/BloomFilter.h"

#include "velox/common/memory/HashStringAllocator.h"
#include "velox/expression/DecodedArgs.h"
#include "velox/vector/FlatVector.h"
Expand Down
3 changes: 2 additions & 1 deletion velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ void registerFunctions(const std::string& prefix) {
int64_t,
Varchar,
Varchar>({prefix + "unix_timestamp", prefix + "to_unix_timestamp"});

// Register DateTime functions.
registerFunction<MillisecondFunction, int32_t, Date>(
{prefix + "millisecond"});
Expand Down Expand Up @@ -215,6 +214,8 @@ void registerFunctions(const std::string& prefix) {
registerFunction<QuarterFunction, int32_t, Timestamp>({prefix + "quarter"});
registerFunction<QuarterFunction, int32_t, TimestampWithTimezone>(
{prefix + "quarter"});
registerFunction<YearFunction, int32_t, Date>({prefix + "year"});
registerFunction<YearFunction, int32_t, Timestamp>({prefix + "year"});
registerFunction<YearOfWeekFunction, int32_t, Date>(
{prefix + "year_of_week"});
registerFunction<YearOfWeekFunction, int32_t, Timestamp>(
Expand Down
272 changes: 272 additions & 0 deletions velox/functions/sparksql/aggregates/BloomFilterAggAggregate.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/functions/sparksql/aggregates/BloomFilterAggAggregate.h"

#include <fmt/format.h>
#include "velox/common/base/BloomFilter.h"
#include "velox/exec/Aggregate.h"
#include "velox/expression/FunctionSignature.h"
#include "velox/vector/FlatVector.h"

namespace facebook::velox::functions::sparksql::aggregates {

namespace {
struct BloomFilterAccumulator {
int32_t serializedSize() {
return bloomFilter.serializedSize();
}

void serialize(StringView& output) {
return bloomFilter.serialize(output);
}

void deserialize(
StringView& serialized,
BloomFilter<int64_t, false>& output) {
BloomFilter<int64_t, false>::deserialize(serialized.data(), output);
}

void mergeWith(StringView& serialized) {
BloomFilter<int64_t, false> output;
deserialize(serialized, output);
bloomFilter.merge(output);
}

void init(int32_t capacity) {
if (!bloomFilter.isSet()) {
bloomFilter.reset(capacity);
}
}

BloomFilter<int64_t, false> bloomFilter;
};

/// Aggregate that folds BIGINT input values into one bloom filter per group
/// and emits the filter in serialized (VARBINARY) form. The same serialized
/// form is used for intermediate and final results.
///
/// Accepted arguments: (value), (value, estimatedNumItems) or
/// (value, estimatedNumItems, numBits). The optional arguments must be
/// positive and constant across all input rows and batches.
class BloomFilterAggAggregate : public exec::Aggregate {
 public:
  explicit BloomFilterAggAggregate(const TypePtr& resultType)
      : Aggregate(resultType) {}

  /// Size of the fixed-width accumulator stored in every group row.
  int32_t accumulatorFixedWidthSize() const override {
    return sizeof(BloomFilterAccumulator);
  }

  /// Initialize each group: calls setAllNulls for the new groups and
  /// placement-constructs an empty accumulator in each group's slot.
  // NOTE(review): no destroy() override is visible in this class. If
  // BloomFilter owns heap memory, the accumulators constructed here are never
  // destructed - confirm against the BloomFilter implementation.
  void initializeNewGroups(
      char** groups,
      folly::Range<const vector_size_t*> indices) override {
    setAllNulls(groups, indices);
    for (auto i : indices) {
      new (groups[i] + offset_) BloomFilterAccumulator();
    }
  }

  /// Inserts the raw value of every selected row into the filter of that
  /// row's group. Fails if the value vector may contain nulls.
  void addRawInput(
      char** groups,
      const SelectivityVector& rows,
      const std::vector<VectorPtr>& args,
      bool /*mayPushdown*/) override {
    // estimatedNumItems is validated but otherwise ignored: only numBits
    // (through capacity_) sizes the Velox bloom filter.
    decodeArguments(rows, args);
    VELOX_CHECK(!decodedRaw_.mayHaveNulls());
    rows.applyToSelected([&](vector_size_t row) {
      auto accumulator = value<BloomFilterAccumulator>(groups[row]);
      accumulator->init(capacity_);
      accumulator->bloomFilter.insert(decodedRaw_.valueAt<int64_t>(row));
    });
  }

  /// Merges the serialized filter of every selected row into the filter of
  /// that row's group. Fails if the input may contain nulls.
  void addIntermediateResults(
      char** groups,
      const SelectivityVector& rows,
      const std::vector<VectorPtr>& args,
      bool /*mayPushdown*/) override {
    VELOX_CHECK_EQ(args.size(), 1);
    decodedIntermediate_.decode(*args[0], rows);
    VELOX_CHECK(!decodedIntermediate_.mayHaveNulls());
    rows.applyToSelected([&](auto row) {
      auto group = groups[row];
      auto tracker = trackRowSize(group);
      auto serialized = decodedIntermediate_.valueAt<StringView>(row);
      auto accumulator = value<BloomFilterAccumulator>(group);
      accumulator->mergeWith(serialized);
    });
  }

  /// Inserts the raw values of the selected rows into the single `group`.
  // NOTE(review): unlike addRawInput, this path does not check
  // decodedRaw_.mayHaveNulls() - confirm callers never pass nulls here.
  void addSingleGroupRawInput(
      char* group,
      const SelectivityVector& rows,
      const std::vector<VectorPtr>& args,
      bool /*mayPushdown*/) override {
    decodeArguments(rows, args);
    auto accumulator = value<BloomFilterAccumulator>(group);
    if (decodedRaw_.isConstantMapping()) {
      // All values are the same, so inserting the first one is enough.
      accumulator->init(capacity_);
      accumulator->bloomFilter.insert(decodedRaw_.valueAt<int64_t>(0));
      return;
    }
    rows.applyToSelected([&](vector_size_t row) {
      // init() is a no-op once the filter is set; keeping it inside the loop
      // leaves the filter unset when no row is selected.
      accumulator->init(capacity_);
      accumulator->bloomFilter.insert(decodedRaw_.valueAt<int64_t>(row));
    });
  }

  /// Merges the serialized filters of the selected rows into the single
  /// `group`. Fails if the input may contain nulls.
  void addSingleGroupIntermediateResults(
      char* group,
      const SelectivityVector& rows,
      const std::vector<VectorPtr>& args,
      bool /*mayPushdown*/) override {
    VELOX_CHECK_EQ(args.size(), 1);
    decodedIntermediate_.decode(*args[0], rows);
    VELOX_CHECK(!decodedIntermediate_.mayHaveNulls());
    auto tracker = trackRowSize(group);
    auto accumulator = value<BloomFilterAccumulator>(group);
    rows.applyToSelected([&](auto row) {
      auto serialized = decodedIntermediate_.valueAt<StringView>(row);
      accumulator->mergeWith(serialized);
    });
  }

  /// Serializes the filter of each group into row i of `result`, which is
  /// accessed as a flat StringView vector.
  void extractValues(char** groups, int32_t numGroups, VectorPtr* result)
      override {
    VELOX_CHECK(result);
    auto flatResult = (*result)->asUnchecked<FlatVector<StringView>>();
    flatResult->resize(numGroups);
    for (vector_size_t i = 0; i < numGroups; ++i) {
      auto group = groups[i];
      VELOX_CHECK_NOT_NULL(group);
      auto accumulator = value<BloomFilterAccumulator>(group);
      auto size = accumulator->serializedSize();
      if (StringView::isInline(size)) {
        StringView serialized(size);
        accumulator->serialize(serialized);
        flatResult->setNoCopy(i, serialized);
      } else {
        // Serialize straight into the tail of the vector's string buffer and
        // grow the buffer's size afterwards, so no extra copy is needed.
        Buffer* buffer = flatResult->getBufferWithSpace(size);
        StringView serialized(buffer->as<char>() + buffer->size(), size);
        accumulator->serialize(serialized);
        buffer->setSize(buffer->size() + size);
        flatResult->setNoCopy(i, serialized);
      }
    }
  }

  /// Intermediate results use the same serialized form as final results.
  void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result)
      override {
    extractValues(groups, numGroups, result);
  }

 private:
  /// Marks an optional argument that has not been seen yet.
  static constexpr int64_t kMissingArgument = -1;
  /// estimatedNumItems used when only the value argument is given.
  static constexpr int64_t kDefaultExpectedNumItems = 1000000;
  /// Upper bounds applied to the user-supplied arguments.
  static constexpr int64_t kMaxNumItems = 4000000;
  static constexpr int64_t kMaxNumBits = 67108864;

  /// Decodes the value argument into decodedRaw_, validates the optional
  /// constant arguments and derives estimatedNumItems_, numBits_ and
  /// capacity_ from them. Called once per input batch.
  void decodeArguments(
      const SelectivityVector& rows,
      const std::vector<VectorPtr>& args) {
    if (args.empty()) {
      VELOX_USER_FAIL("Function args size must be more than 0");
    }
    decodedRaw_.decode(*args[0], rows);

    // The values exactly as supplied by the user are remembered separately
    // from the clamped ones. Comparing a later batch against a clamped value
    // would wrongly reject any argument larger than the maximum as "not
    // constant".
    int64_t requestedNumItems = kDefaultExpectedNumItems;
    if (args.size() > 1) {
      DecodedVector decodedEstimatedNumItems(*args[1], rows);
      setConstantArgument(
          "estimatedNumItems",
          originalEstimatedNumItems_,
          decodedEstimatedNumItems);
      requestedNumItems = originalEstimatedNumItems_;
    }

    int64_t requestedNumBits;
    if (args.size() > 2) {
      DecodedVector decodedNumBits(*args[2], rows);
      setConstantArgument("numBits", originalNumBits_, decodedNumBits);
      requestedNumBits = originalNumBits_;
    } else {
      // numBits defaults to 8 bits per (unclamped) expected item. Saturate at
      // the maximum instead of multiplying, so huge inputs cannot overflow.
      requestedNumBits = requestedNumItems > kMaxNumBits / 8
          ? kMaxNumBits
          : requestedNumItems * 8;
    }

    estimatedNumItems_ = std::min(requestedNumItems, kMaxNumItems);
    numBits_ = std::min(requestedNumBits, kMaxNumBits);
    capacity_ = static_cast<int32_t>(numBits_ / 16);
  }

  /// Stores `newVal` in `val` the first time it is seen and afterwards
  /// requires every later value to be equal to it. `newVal` must be positive.
  static void
  setConstantArgument(const char* name, int64_t& val, int64_t newVal) {
    VELOX_USER_CHECK_GT(newVal, 0, "{} must be positive", name);
    if (val == kMissingArgument) {
      val = newVal;
    } else {
      VELOX_USER_CHECK_EQ(
          newVal, val, "{} argument must be constant for all input rows", name);
    }
  }

  /// Same as above, taking the value from row 0 of a vector that must be a
  /// constant mapping.
  static void setConstantArgument(
      const char* name,
      int64_t& val,
      const DecodedVector& vec) {
    VELOX_CHECK(
        vec.isConstantMapping(),
        "{} argument must be constant for all input rows",
        name);
    setConstantArgument(name, val, vec.valueAt<int64_t>(0));
  }

  // Reusable instances of DecodedVector for decoding input vectors.
  DecodedVector decodedRaw_;
  DecodedVector decodedIntermediate_;
  // User-supplied arguments, unclamped; used only for the constancy check.
  int64_t originalEstimatedNumItems_ = kMissingArgument;
  int64_t originalNumBits_ = kMissingArgument;
  // Effective (clamped) values derived by decodeArguments.
  int64_t estimatedNumItems_ = kMissingArgument;
  int64_t numBits_ = kMissingArgument;
  // Value passed to BloomFilterAccumulator::init; numBits_ / 16.
  int32_t capacity_ = kMissingArgument;
};

} // namespace

bool registerBloomFilterAggAggregate(const std::string& name) {
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
exec::AggregateFunctionSignatureBuilder()
.argumentType("bigint")
.argumentType("bigint")
.argumentType("bigint")
.intermediateType("varbinary")
.returnType("varbinary")
.build(),
exec::AggregateFunctionSignatureBuilder()
.argumentType("bigint")
.argumentType("bigint")
.intermediateType("varbinary")
.returnType("varbinary")
.build(),
exec::AggregateFunctionSignatureBuilder()
.argumentType("bigint")
.intermediateType("varbinary")
.returnType("varbinary")
.build()};

return exec::registerAggregateFunction(
name,
std::move(signatures),
[name](
core::AggregationNode::Step step,
const std::vector<TypePtr>& argTypes,
const TypePtr& resultType) -> std::unique_ptr<exec::Aggregate> {
return std::make_unique<BloomFilterAggAggregate>(resultType);
});
}
} // namespace facebook::velox::functions::sparksql::aggregates
25 changes: 25 additions & 0 deletions velox/functions/sparksql/aggregates/BloomFilterAggAggregate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <string>

namespace facebook::velox::functions::sparksql::aggregates {

/// Registers the bloom-filter aggregate function under `name`. It accepts
/// one to three BIGINT arguments and produces a VARBINARY result.
///
/// @param name Name the aggregate is registered under.
/// @return The result of the underlying aggregate-function registration call.
bool registerBloomFilterAggAggregate(const std::string& name);

} // namespace facebook::velox::functions::sparksql::aggregates
3 changes: 2 additions & 1 deletion velox/functions/sparksql/aggregates/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

add_library(velox_functions_spark_aggregates BitwiseXorAggregate.cpp
LastAggregate.cpp Register.cpp)
LastAggregate.cpp BloomFilterAggAggregate.cpp Register.cpp)

target_link_libraries(velox_functions_spark_aggregates ${FMT} velox_exec
velox_expression_functions velox_aggregates velox_vector)
Expand Down
2 changes: 2 additions & 0 deletions velox/functions/sparksql/aggregates/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
#include "velox/functions/sparksql/aggregates/Register.h"

#include "velox/functions/sparksql/aggregates/BitwiseXorAggregate.h"
#include "velox/functions/sparksql/aggregates/BloomFilterAggAggregate.h"
#include "velox/functions/sparksql/aggregates/LastAggregate.h"

namespace facebook::velox::functions::sparksql::aggregates {

void registerAggregateFunctions(const std::string& prefix) {
aggregates::registerLastAggregate(prefix + "last");
aggregates::registerBitwiseXorAggregate(prefix + "bit_xor");
aggregates::registerBloomFilterAggAggregate(prefix + "bloom_filter_agg");
}
} // namespace facebook::velox::functions::sparksql::aggregates
Loading

0 comments on commit d228dd9

Please sign in to comment.