From ab283cc914b4b2a4d3c5b9fa793d8b6256801101 Mon Sep 17 00:00:00 2001
From: Chengcheng Jin
Date: Wed, 23 Nov 2022 15:32:45 +0000
Subject: [PATCH] fix bloomfilter bit_ size

Size the Velox BloomFilter from numBits instead of passing numBits as the
capacity:

* Add BloomFilterAccumulator::init(capacity), which resets the filter only
  when it has not been set yet, and use it at every raw-input call site.
* Re-enable the constant-mapping fast path in the single-group raw-input
  path.
* Make estimatedNumItems and numBits optional. With one argument,
  estimatedNumItems defaults to 1000000; whenever numBits is not given it is
  estimatedNumItems * 8. Both are clamped (4000000 items, 67108864 bits).
* Derive the BloomFilter capacity as numBits / 16.
* Register the one-, two- and three-argument signatures.
---
Reviewer changes relative to the original submission (all line-count
neutral, so hunk headers and the diffstat are unchanged):

* decodeArguments: the two-argument branch asserted args.size() == 3, which
  can never hold in that branch and made the (bigint, bigint) signature fail
  on every call. It now asserts args.size() == 2.
* The three limits are static constexpr (like kMissingArgument) instead of
  per-instance const members; DEFAULT_ESPECTED_NUM_ITEMS is renamed to
  DEFAULT_EXPECTED_NUM_ITEMS.
* VELOX_USER_FAIL(...) is terminated with ';'.
* The int64_t -> int32_t narrowing into capacity_ is an explicit static_cast
  (the value is bounded by MAX_NUM_BITS / 16).

NOTE(review): decodeArguments is called from the raw-input path and clamps
estimatedNumItems_ / numBits_ after setConstantArgument has stored them. The
body of setConstantArgument is not part of this diff; if it verifies that a
later call passes the same value as the stored one, an argument above the
cap would mismatch on the second batch. Please confirm.

NOTE(review): extraction appears to have dropped angle-bracket text (e.g.
"std::vector& args", "value(group)"); context lines are reproduced as
received and may need to be regenerated before the patch applies.

 .../aggregates/BloomFilterAggAggregate.cpp    | 79 +++++++++++++------
 1 file changed, 57 insertions(+), 22 deletions(-)

diff --git a/velox/functions/sparksql/aggregates/BloomFilterAggAggregate.cpp b/velox/functions/sparksql/aggregates/BloomFilterAggAggregate.cpp
index c23ccfaf9679..cc19b0c35d9c 100644
--- a/velox/functions/sparksql/aggregates/BloomFilterAggAggregate.cpp
+++ b/velox/functions/sparksql/aggregates/BloomFilterAggAggregate.cpp
@@ -47,6 +47,12 @@ struct BloomFilterAccumulator {
     bloomFilter.merge(output);
   }
 
+  void init(int32_t capacity) {
+    if (!bloomFilter.isSet()) {
+      bloomFilter.reset(capacity);
+    }
+  }
+
   BloomFilter bloomFilter;
 };
 
@@ -80,9 +86,7 @@ class BloomFilterAggAggregate : public exec::Aggregate {
     VELOX_CHECK(!decodedRaw_.mayHaveNulls());
     rows.applyToSelected([&](vector_size_t row) {
       auto accumulator = value(groups[row]);
-      if (!accumulator->bloomFilter.isSet()) {
-        accumulator->bloomFilter.reset(numBits_);
-      }
+      accumulator->init(capacity_);
       accumulator->bloomFilter.insert(decodedRaw_.valueAt(row));
     });
   }
@@ -111,18 +115,14 @@ class BloomFilterAggAggregate : public exec::Aggregate {
       bool /*mayPushdown*/) override {
     decodeArguments(rows, args);
     auto accumulator = value(group);
-    // if (decodedRaw_.isConstantMapping()) {
-    //   // all values are same, just do for the first
-    //   if (!accumulator->bloomFilter.isSet()) {
-    //     accumulator->bloomFilter.reset(numBits_);
-    //   }
-    //   accumulator->bloomFilter.insert(decodedRaw_.valueAt(0));
-    //   return;
-    // }
+    if (decodedRaw_.isConstantMapping()) {
+      // All values are the same, so inserting the first one is enough.
+      accumulator->init(capacity_);
+      accumulator->bloomFilter.insert(decodedRaw_.valueAt(0));
+      return;
+    }
     rows.applyToSelected([&](vector_size_t row) {
-      if (!accumulator->bloomFilter.isSet()) {
-        accumulator->bloomFilter.reset(numBits_);
-      }
+      accumulator->init(capacity_);
       accumulator->bloomFilter.insert(decodedRaw_.valueAt(row));
     });
   }
@@ -177,16 +177,39 @@ class BloomFilterAggAggregate : public exec::Aggregate {
   }
 
  private:
+  static constexpr int64_t DEFAULT_EXPECTED_NUM_ITEMS = 1000000;
+  static constexpr int64_t MAX_NUM_ITEMS = 4000000;
+  static constexpr int64_t MAX_NUM_BITS = 67108864;
+
   void decodeArguments(
       const SelectivityVector& rows,
       const std::vector& args) {
-    VELOX_CHECK_EQ(args.size(), 3);
-    decodedRaw_.decode(*args[0], rows);
-    DecodedVector decodedEstimatedNumItems(*args[1], rows);
-    DecodedVector decodedNumBits(*args[2], rows);
-    setConstantArgument(
-        "estimatedNumItems", estimatedNumItems_, decodedEstimatedNumItems);
-    setConstantArgument("numBits", numBits_, decodedNumBits);
+    if (args.size() > 0) {
+      decodedRaw_.decode(*args[0], rows);
+      if (args.size() > 1) {
+        DecodedVector decodedEstimatedNumItems(*args[1], rows);
+        setConstantArgument(
+            "estimatedNumItems", estimatedNumItems_, decodedEstimatedNumItems);
+        if (args.size() > 2) {
+          DecodedVector decodedNumBits(*args[2], rows);
+          setConstantArgument("numBits", numBits_, decodedNumBits);
+        } else {
+          VELOX_CHECK_EQ(args.size(), 2);
+          numBits_ = estimatedNumItems_ * 8;
+        }
+      } else {
+        estimatedNumItems_ = DEFAULT_EXPECTED_NUM_ITEMS;
+        numBits_ = estimatedNumItems_ * 8;
+      }
+    } else {
+      VELOX_USER_FAIL("Function args size must be more than 0");
+    }
+    estimatedNumItems_ = std::min(estimatedNumItems_, MAX_NUM_ITEMS);
+    numBits_ = std::min(numBits_, MAX_NUM_BITS);
+    // Velox BloomFilter bit_ size is bits::nextPowerOfTwo(capacity) / 4 and
+    // Spark bit_ size is Math.ceil(numBits / 64.0), so passing numBits_ / 16
+    // as the capacity gives both the same bit_ size.
+    capacity_ = static_cast<int32_t>(numBits_ / 16);
   }
 
   static void
@@ -211,12 +234,13 @@ class BloomFilterAggAggregate : public exec::Aggregate {
     setConstantArgument(name, val, vec.valueAt(0));
   }
 
-  // Reusable instance of DecodedVector for decoding input vectors.
   static constexpr int64_t kMissingArgument = -1;
+  // Reusable instance of DecodedVector for decoding input vectors.
   DecodedVector decodedRaw_;
   DecodedVector decodedIntermediate_;
   int64_t estimatedNumItems_ = kMissingArgument;
   int64_t numBits_ = kMissingArgument;
+  int32_t capacity_ = kMissingArgument;
 };
 
 } // namespace
@@ -226,6 +250,17 @@ bool registerBloomFilterAggAggregate(const std::string& name) {
       exec::AggregateFunctionSignatureBuilder()
           .argumentType("bigint")
           .argumentType("bigint")
+          .argumentType("bigint")
+          .intermediateType("varbinary")
+          .returnType("varbinary")
+          .build(),
+      exec::AggregateFunctionSignatureBuilder()
+          .argumentType("bigint")
+          .argumentType("bigint")
+          .intermediateType("varbinary")
+          .returnType("varbinary")
+          .build(),
+      exec::AggregateFunctionSignatureBuilder()
           .argumentType("bigint")
           .intermediateType("varbinary")
           .returnType("varbinary")