Skip to content

Commit

Permalink
fix bloomfilter bit_ size
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchengchenghh committed Nov 25, 2022
1 parent 785c043 commit ab283cc
Showing 1 changed file with 57 additions and 22 deletions.
79 changes: 57 additions & 22 deletions velox/functions/sparksql/aggregates/BloomFilterAggAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ struct BloomFilterAccumulator {
bloomFilter.merge(output);
}

// Sizes the bloom filter on first use. When bloomFilter.isSet() already
// reports true the call is a no-op, so it can be invoked before every insert.
//
// @param capacity Value forwarded to bloomFilter.reset().
void init(int32_t capacity) {
  if (bloomFilter.isSet()) {
    return;
  }
  bloomFilter.reset(capacity);
}

BloomFilter<int64_t, false> bloomFilter;
};

Expand Down Expand Up @@ -80,9 +86,7 @@ class BloomFilterAggAggregate : public exec::Aggregate {
VELOX_CHECK(!decodedRaw_.mayHaveNulls());
rows.applyToSelected([&](vector_size_t row) {
auto accumulator = value<BloomFilterAccumulator>(groups[row]);
if (!accumulator->bloomFilter.isSet()) {
accumulator->bloomFilter.reset(numBits_);
}
accumulator->init(capacity_);
accumulator->bloomFilter.insert(decodedRaw_.valueAt<int64_t>(row));
});
}
Expand Down Expand Up @@ -111,18 +115,14 @@ class BloomFilterAggAggregate : public exec::Aggregate {
bool /*mayPushdown*/) override {
decodeArguments(rows, args);
auto accumulator = value<BloomFilterAccumulator>(group);
// if (decodedRaw_.isConstantMapping()) {
// // all values are same, just do for the first
// if (!accumulator->bloomFilter.isSet()) {
// accumulator->bloomFilter.reset(numBits_);
// }
// accumulator->bloomFilter.insert(decodedRaw_.valueAt<int64_t>(0));
// return;
// }
if (decodedRaw_.isConstantMapping()) {
// all values are same, just do for the first
accumulator->init(capacity_);
accumulator->bloomFilter.insert(decodedRaw_.valueAt<int64_t>(0));
return;
}
rows.applyToSelected([&](vector_size_t row) {
if (!accumulator->bloomFilter.isSet()) {
accumulator->bloomFilter.reset(numBits_);
}
accumulator->init(capacity_);
accumulator->bloomFilter.insert(decodedRaw_.valueAt<int64_t>(row));
});
}
Expand Down Expand Up @@ -177,16 +177,39 @@ class BloomFilterAggAggregate : public exec::Aggregate {
}

private:
// Sizing limits used by decodeArguments(). Declared static constexpr, like
// kMissingArgument, so they are class-level constants rather than per-instance
// data members.

// Estimated item count used when the caller passes only the value column.
// (The identifier's spelling is kept as-is because it is referenced by name.)
static constexpr int64_t DEFAULT_ESPECTED_NUM_ITEMS = 1000000;
// Upper bound applied to estimatedNumItems_.
static constexpr int64_t MAX_NUM_ITEMS = 4000000;
// Upper bound applied to numBits_ (2^26 bits).
static constexpr int64_t MAX_NUM_BITS = 67108864;

void decodeArguments(
const SelectivityVector& rows,
const std::vector<VectorPtr>& args) {
VELOX_CHECK_EQ(args.size(), 3);
decodedRaw_.decode(*args[0], rows);
DecodedVector decodedEstimatedNumItems(*args[1], rows);
DecodedVector decodedNumBits(*args[2], rows);
setConstantArgument(
"estimatedNumItems", estimatedNumItems_, decodedEstimatedNumItems);
setConstantArgument("numBits", numBits_, decodedNumBits);
if (args.size() > 0) {
decodedRaw_.decode(*args[0], rows);
if (args.size() > 1) {
DecodedVector decodedEstimatedNumItems(*args[1], rows);
setConstantArgument(
"estimatedNumItems", estimatedNumItems_, decodedEstimatedNumItems);
if (args.size() > 2) {
DecodedVector decodedNumBits(*args[2], rows);
setConstantArgument("numBits", numBits_, decodedNumBits);
} else {
VELOX_CHECK_EQ(args.size(), 3);
numBits_ = estimatedNumItems_ * 8;
}
} else {
estimatedNumItems_ = DEFAULT_ESPECTED_NUM_ITEMS;
numBits_ = estimatedNumItems_ * 8;
}
} else {
VELOX_USER_FAIL("Function args size must be more than 0")
}
estimatedNumItems_ = std::min(estimatedNumItems_, MAX_NUM_ITEMS);
numBits_ = std::min(numBits_, MAX_NUM_BITS);
// velox BloomFilter bit_ size is bits::nextPowerOfTwo(capacity) / 4, and
// spark bit_ size is Math.ceil(numBits / 64.0) so there is equal bit_ size
// using numBits_ / 16
capacity_ = numBits_ / 16;
}

static void
Expand All @@ -211,12 +234,13 @@ class BloomFilterAggAggregate : public exec::Aggregate {
setConstantArgument(name, val, vec.valueAt<int64_t>(0));
}

// Initial value of the argument-derived fields below, before
// decodeArguments() has assigned them.
static constexpr int64_t kMissingArgument = -1;
// Reusable instance of DecodedVector for decoding input vectors.
DecodedVector decodedRaw_;
DecodedVector decodedIntermediate_;
// Estimated item count; decodeArguments() caps it at MAX_NUM_ITEMS.
int64_t estimatedNumItems_ = kMissingArgument;
// Requested number of bits; decodeArguments() caps it at MAX_NUM_BITS.
int64_t numBits_ = kMissingArgument;
// Set to numBits_ / 16 by decodeArguments() and passed to
// BloomFilterAccumulator::init() before inserts.
int32_t capacity_ = kMissingArgument;
};

} // namespace
Expand All @@ -226,6 +250,17 @@ bool registerBloomFilterAggAggregate(const std::string& name) {
exec::AggregateFunctionSignatureBuilder()
.argumentType("bigint")
.argumentType("bigint")
.argumentType("bigint")
.intermediateType("varbinary")
.returnType("varbinary")
.build(),
exec::AggregateFunctionSignatureBuilder()
.argumentType("bigint")
.argumentType("bigint")
.intermediateType("varbinary")
.returnType("varbinary")
.build(),
exec::AggregateFunctionSignatureBuilder()
.argumentType("bigint")
.intermediateType("varbinary")
.returnType("varbinary")
Expand Down

0 comments on commit ab283cc

Please sign in to comment.