diff --git a/velox/common/base/BloomFilter.h b/velox/common/base/BloomFilter.h index 936c22934dfe..5474072e7cdc 100644 --- a/velox/common/base/BloomFilter.h +++ b/velox/common/base/BloomFilter.h @@ -22,6 +22,9 @@ #include #include "velox/common/base/BitUtil.h" +#include "velox/common/base/Exceptions.h" +#include "velox/common/base/IOUtils.h" +#include "velox/type/StringView.h" namespace facebook::velox { @@ -31,9 +34,15 @@ namespace facebook::velox { // expected entry, we get ~2% false positives. 'hashInput' determines // if the value added or checked needs to be hashed. If this is false, // we assume that the input is already a 64 bit hash number. -template +// case: +// InputType can be any type supported by folly::hasher when hashInput = true +// InputType can only be uint64_t when hashInput = false +template class BloomFilter { public: + BloomFilter(){}; + BloomFilter(std::vector bits) : bits_(bits){}; + // Prepares 'this' for use with an expected 'capacity' // entries. Drops any prior content. void reset(int32_t capacity) { @@ -42,18 +51,60 @@ class BloomFilter { bits_.resize(std::max(4, bits::nextPowerOfTwo(capacity) / 4)); } + bool isSet() { + return bits_.size() > 0; + } + // Adds 'value'. - void insert(uint64_t value) { + void insert(InputType value) { set(bits_.data(), bits_.size(), - hashInput ? folly::hasher()(value) : value); + hashInput ? folly::hasher()(value) : value); } - bool mayContain(uint64_t value) const { + bool mayContain(InputType value) const { return test( bits_.data(), bits_.size(), - hashInput ? folly::hasher()(value) : value); + hashInput ? folly::hasher()(value) : value); + } + +// Combines the two bloomFilter bits_ using bitwise OR. 
+ void merge(BloomFilter& bloomFilter) { + if (bits_.size() == 0) { + bits_ = bloomFilter.bits_; + return; + } else if (bloomFilter.bits_.size() == 0){ + VELOX_FAIL("Input bit length should not be 0"); + } + VELOX_CHECK_EQ(bits_.size(), bloomFilter.bits_.size()); + for (auto i = 0; i < bloomFilter.bits_.size(); i++) { + bits_[i] |= bloomFilter.bits_[i]; + } + } + + uint32_t serializedSize() { + return 4 /* number of bits */ + + bits_.size() * 8; + } + + void serialize(StringView& output) { + char* outputBuffer = const_cast(output.data()); + common::OutputByteStream stream(outputBuffer); + stream.appendOne((int32_t)bits_.size()); + for (auto bit : bits_) { + stream.appendOne(bit); + } + } + + static void deserialize(const char* serialized, BloomFilter& output) { + common::InputByteStream stream(serialized); + auto size = stream.read(); + output.bits_.resize(size); + auto bitsdata = reinterpret_cast(serialized + stream.offset()); + for (auto i = 0; i < size; i++) { + output.bits_[i] = bitsdata[i]; + } } private: diff --git a/velox/common/base/tests/BloomFilterTest.cpp b/velox/common/base/tests/BloomFilterTest.cpp index 17543db7468b..b5b59fbf8dcd 100644 --- a/velox/common/base/tests/BloomFilterTest.cpp +++ b/velox/common/base/tests/BloomFilterTest.cpp @@ -24,7 +24,7 @@ using namespace facebook::velox; TEST(BloomFilterTest, basic) { constexpr int32_t kSize = 1024; - BloomFilter bloom; + BloomFilter bloom; bloom.reset(kSize); for (auto i = 0; i < kSize; ++i) { bloom.insert(i); @@ -37,3 +37,46 @@ TEST(BloomFilterTest, basic) { } EXPECT_GT(2, 100 * numFalsePositives / kSize); } + +TEST(BloomFilterTest, serialize) { + constexpr int32_t kSize = 1024; + BloomFilter bloom; + bloom.reset(kSize); + for (auto i = 0; i < kSize; ++i) { + bloom.insert(i); + } + std::string data; + data.resize(bloom.serializedSize()); + StringView serialized(data.data(), data.size()); + bloom.serialize(serialized); + BloomFilter deserialized; + BloomFilter::deserialize(data.data(), deserialized); + 
for (auto i = 0; i < kSize; ++i) { + EXPECT_TRUE(deserialized.mayContain(i)); + } + + EXPECT_EQ(bloom.serializedSize(), deserialized.serializedSize()); +} + +TEST(BloomFilterTest, merge) { + constexpr int32_t kSize = 10; + BloomFilter bloom; + bloom.reset(kSize); + for (auto i = 0; i < kSize; ++i) { + bloom.insert(i); + } + + BloomFilter merge; + merge.reset(kSize); + for (auto i = kSize; i < kSize + kSize; i++) { + merge.insert(i); + } + + bloom.merge(merge); + + for (auto i = 0; i < kSize + kSize; ++i) { + EXPECT_TRUE(bloom.mayContain(i)); + } + + EXPECT_EQ(bloom.serializedSize(), merge.serializedSize()); +} diff --git a/velox/functions/sparksql/CMakeLists.txt b/velox/functions/sparksql/CMakeLists.txt index d4f1facc0904..bd9cbb97bac6 100644 --- a/velox/functions/sparksql/CMakeLists.txt +++ b/velox/functions/sparksql/CMakeLists.txt @@ -36,9 +36,12 @@ target_link_libraries( set_property(TARGET velox_functions_spark PROPERTY JOB_POOL_COMPILE high_memory_pool) +if(${VELOX_ENABLE_AGGREGATES}) + add_subdirectory(aggregates) +endif() + if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) - add_subdirectory(aggregates) endif() if(${VELOX_ENABLE_BENCHMARKS}) diff --git a/velox/functions/sparksql/MightContain.h b/velox/functions/sparksql/MightContain.h new file mode 100644 index 000000000000..f5a469e22946 --- /dev/null +++ b/velox/functions/sparksql/MightContain.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/common/base/BloomFilter.h" +#include "velox/expression/VectorFunction.h" +#include "velox/functions/Macros.h" +#include "velox/functions/lib/string/StringImpl.h" + +namespace facebook::velox::functions::sparksql { + +template +struct MightContainFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE bool call( + out_type& result, + const arg_type& serializedBloom, + const arg_type& value) { + BloomFilter output; + BloomFilter::deserialize( + std::string(std::string_view(serializedBloom)).data(), output); + return output.mayContain(value); + } +}; + +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index adf5d3aa1705..a68cd6b64c15 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -28,6 +28,7 @@ #include "velox/functions/sparksql/Hash.h" #include "velox/functions/sparksql/In.h" #include "velox/functions/sparksql/LeastGreatest.h" +#include "velox/functions/sparksql/MightContain.h" #include "velox/functions/sparksql/RegexFunctions.h" #include "velox/functions/sparksql/RegisterArithmetic.h" #include "velox/functions/sparksql/RegisterCompare.h" @@ -156,22 +157,16 @@ void registerFunctions(const std::string& prefix) { {prefix + "millisecond"}); registerFunction( {prefix + "millisecond"}); - registerFunction( - {prefix + "second"}); - registerFunction( - {prefix + "second"}); + registerFunction({prefix + "second"}); + registerFunction({prefix + "second"}); registerFunction( {prefix + "second"}); - registerFunction( - {prefix + "minute"}); - registerFunction( - {prefix + "minute"}); + registerFunction({prefix + "minute"}); + registerFunction({prefix + "minute"}); registerFunction( {prefix + "minute"}); - registerFunction( - {prefix + "hour"}); - registerFunction( - {prefix + "hour"}); + 
registerFunction({prefix + "hour"}); + registerFunction({prefix + "hour"}); registerFunction( {prefix + "hour"}); registerFunction( @@ -180,34 +175,26 @@ void registerFunctions(const std::string& prefix) { {prefix + "day", prefix + "day_of_month"}); registerFunction( {prefix + "day", prefix + "day_of_month"}); - registerFunction( - {prefix + "day_of_week"}); + registerFunction({prefix + "day_of_week"}); registerFunction( {prefix + "day_of_week"}); registerFunction( {prefix + "day_of_week"}); - registerFunction( - {prefix + "day_of_year"}); + registerFunction({prefix + "day_of_year"}); registerFunction( {prefix + "day_of_year"}); registerFunction( {prefix + "day_of_year"}); - registerFunction( - {prefix + "month"}); - registerFunction( - {prefix + "month"}); + registerFunction({prefix + "month"}); + registerFunction({prefix + "month"}); registerFunction( {prefix + "month"}); - registerFunction( - {prefix + "quarter"}); - registerFunction( - {prefix + "quarter"}); + registerFunction({prefix + "quarter"}); + registerFunction({prefix + "quarter"}); registerFunction( {prefix + "quarter"}); - registerFunction( - {prefix + "year"}); - registerFunction( - {prefix + "year"}); + registerFunction({prefix + "year"}); + registerFunction({prefix + "year"}); registerFunction( {prefix + "year"}); registerFunction( @@ -216,6 +203,10 @@ void registerFunctions(const std::string& prefix) { {prefix + "year_of_week"}); registerFunction( {prefix + "year_of_week"}); + + // Register bloom filter function + registerFunction( + {prefix + "might_contain"}); } } // namespace sparksql diff --git a/velox/functions/sparksql/aggregates/BloomFilterAggAggregate.cpp b/velox/functions/sparksql/aggregates/BloomFilterAggAggregate.cpp new file mode 100644 index 000000000000..bb5e53ef75ed --- /dev/null +++ b/velox/functions/sparksql/aggregates/BloomFilterAggAggregate.cpp @@ -0,0 +1,276 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/functions/sparksql/aggregates/BloomFilterAggAggregate.h" + +#include + +#include "velox/common/base/BloomFilter.h" +#include "velox/exec/Aggregate.h" +#include "velox/expression/FunctionSignature.h" +#include "velox/vector/FlatVector.h" + +namespace facebook::velox::functions::sparksql::aggregates { + +namespace { +struct BloomFilterAccumulator { + int32_t serializedSize() { + return bloomFilter.serializedSize(); + } + + void serialize(StringView& output) { + return bloomFilter.serialize(output); + } + + void deserialize( + StringView& serialized, + BloomFilter& output) { + BloomFilter::deserialize(serialized.data(), output); + } + + void mergeWith(StringView& serialized) { + BloomFilter output; + deserialize(serialized, output); + bloomFilter.merge(output); + } + + void init(int32_t capacity) { + if (!bloomFilter.isSet()) { + bloomFilter.reset(capacity); + } + } + + BloomFilter bloomFilter; +}; + +class BloomFilterAggAggregate : public exec::Aggregate { + public: + explicit BloomFilterAggAggregate(const TypePtr& resultType) + : Aggregate(resultType) {} + + int32_t accumulatorFixedWidthSize() const override { + return sizeof(BloomFilterAccumulator); + } + + /// Initialize each group. 
+ void initializeNewGroups( + char** groups, + folly::Range indices) override { + setAllNulls(groups, indices); + for (auto i : indices) { + new (groups[i] + offset_) BloomFilterAccumulator(); + } + } + + void addRawInput( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + // ignore the estimatedNumItems, this config is not used in + // velox bloom filter implementation + decodeArguments(rows, args); + VELOX_CHECK(!decodedRaw_.mayHaveNulls()); + rows.applyToSelected([&](vector_size_t row) { + auto accumulator = value(groups[row]); + accumulator->init(capacity_); + accumulator->bloomFilter.insert(decodedRaw_.valueAt(row)); + }); + } + + void addIntermediateResults( + char** groups, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + VELOX_CHECK_EQ(args.size(), 1); + decodedIntermediate_.decode(*args[0], rows); + VELOX_CHECK(!decodedIntermediate_.mayHaveNulls()); + rows.applyToSelected([&](auto row) { + auto group = groups[row]; + auto tracker = trackRowSize(group); + auto serialized = decodedIntermediate_.valueAt(row); + auto accumulator = value(group); + accumulator->mergeWith(serialized); + }); + } + + void addSingleGroupRawInput( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + decodeArguments(rows, args); + auto accumulator = value(group); + if (decodedRaw_.isConstantMapping()) { + // all values are same, just do for the first + accumulator->init(capacity_); + accumulator->bloomFilter.insert(decodedRaw_.valueAt(0)); + return; + } + rows.applyToSelected([&](vector_size_t row) { + accumulator->init(capacity_); + accumulator->bloomFilter.insert(decodedRaw_.valueAt(row)); + }); + } + + void addSingleGroupIntermediateResults( + char* group, + const SelectivityVector& rows, + const std::vector& args, + bool /*mayPushdown*/) override { + VELOX_CHECK_EQ(args.size(), 1); + 
decodedIntermediate_.decode(*args[0], rows); + VELOX_CHECK(!decodedIntermediate_.mayHaveNulls()); + auto tracker = trackRowSize(group); + rows.applyToSelected([&](auto row) { + auto serialized = decodedIntermediate_.valueAt(row); + auto accumulator = value(group); + accumulator->mergeWith(serialized); + }); + } + + void extractValues(char** groups, int32_t numGroups, VectorPtr* result) + override { + VELOX_CHECK(result); + auto flatResult = (*result)->asUnchecked>(); + flatResult->resize(numGroups); + for (vector_size_t i = 0; i < numGroups; ++i) { + auto group = groups[i]; + VELOX_CHECK_NOT_NULL(group); + auto accumulator = value(group); + auto size = accumulator->serializedSize(); + if (StringView::isInline(size)) { + StringView serialized(size); + accumulator->bloomFilter.serialize(serialized); + flatResult->setNoCopy(i, serialized); + } else { + Buffer* buffer = flatResult->getBufferWithSpace(size); + StringView serialized(buffer->as() + buffer->size(), size); + accumulator->bloomFilter.serialize(serialized); + buffer->setSize(buffer->size() + size); + flatResult->setNoCopy(i, serialized); + } + } + } + + void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result) + override { + extractValues(groups, numGroups, result); + } + + private: + const int64_t DEFAULT_ESPECTED_NUM_ITEMS = 1000000; + const int64_t MAX_NUM_ITEMS = 4000000; + const int64_t MAX_NUM_BITS = 67108864; + + void decodeArguments( + const SelectivityVector& rows, + const std::vector& args) { + if (args.size() > 0) { + decodedRaw_.decode(*args[0], rows); + if (args.size() > 1) { + DecodedVector decodedEstimatedNumItems(*args[1], rows); + setConstantArgument( + "estimatedNumItems", estimatedNumItems_, decodedEstimatedNumItems); + if (args.size() > 2) { + DecodedVector decodedNumBits(*args[2], rows); + setConstantArgument("numBits", numBits_, decodedNumBits); + } else { + VELOX_CHECK_EQ(args.size(), 3); + numBits_ = estimatedNumItems_ * 8; + } + } else { + estimatedNumItems_ = 
DEFAULT_ESPECTED_NUM_ITEMS; + numBits_ = estimatedNumItems_ * 8; + } + } else { + VELOX_USER_FAIL("Function args size must be more than 0") + } + estimatedNumItems_ = std::min(estimatedNumItems_, MAX_NUM_ITEMS); + numBits_ = std::min(numBits_, MAX_NUM_BITS); + // Velox sizes BloomFilter bits_ as bits::nextPowerOfTwo(capacity) / 4, while + // Spark sizes its bits as Math.ceil(numBits / 64.0), so a capacity of + // numBits_ / 16 would make the two sizes equal. + // However, based on TPCDS tests, the divisor used for the Velox BloomFilter is 64. + capacity_ = numBits_ / 64; + } + + static void + setConstantArgument(const char* name, int64_t& val, int64_t newVal) { + VELOX_USER_CHECK_GT(newVal, 0, "{} must be positive", name); + if (val == kMissingArgument) { + val = newVal; + } else { + VELOX_USER_CHECK_EQ( + newVal, val, "{} argument must be constant for all input rows", name); + } + } + + static void setConstantArgument( + const char* name, + int64_t& val, + const DecodedVector& vec) { + VELOX_CHECK( + vec.isConstantMapping(), + "{} argument must be constant for all input rows", + name); + setConstantArgument(name, val, vec.valueAt(0)); + } + + static constexpr int64_t kMissingArgument = -1; + // Reusable instance of DecodedVector for decoding input vectors. 
+ DecodedVector decodedRaw_; + DecodedVector decodedIntermediate_; + int64_t estimatedNumItems_ = kMissingArgument; + int64_t numBits_ = kMissingArgument; + int32_t capacity_ = kMissingArgument; +}; + +} // namespace + +bool registerBloomFilterAggAggregate(const std::string& name) { + std::vector> signatures{ + exec::AggregateFunctionSignatureBuilder() + .argumentType("bigint") + .argumentType("bigint") + .argumentType("bigint") + .intermediateType("varbinary") + .returnType("varbinary") + .build(), + exec::AggregateFunctionSignatureBuilder() + .argumentType("bigint") + .argumentType("bigint") + .intermediateType("varbinary") + .returnType("varbinary") + .build(), + exec::AggregateFunctionSignatureBuilder() + .argumentType("bigint") + .intermediateType("varbinary") + .returnType("varbinary") + .build()}; + + return exec::registerAggregateFunction( + name, + std::move(signatures), + [name]( + core::AggregationNode::Step step, + const std::vector& argTypes, + const TypePtr& resultType) -> std::unique_ptr { + return std::make_unique(resultType); + }); +} +} // namespace facebook::velox::functions::sparksql::aggregates diff --git a/velox/functions/sparksql/aggregates/BloomFilterAggAggregate.h b/velox/functions/sparksql/aggregates/BloomFilterAggAggregate.h new file mode 100644 index 000000000000..c1d53bfca3ac --- /dev/null +++ b/velox/functions/sparksql/aggregates/BloomFilterAggAggregate.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace facebook::velox::functions::sparksql::aggregates { + +bool registerBloomFilterAggAggregate(const std::string& name); + +} // namespace facebook::velox::functions::sparksql::aggregates diff --git a/velox/functions/sparksql/aggregates/CMakeLists.txt b/velox/functions/sparksql/aggregates/CMakeLists.txt index 116923f1c936..13cd2a7b4c3b 100644 --- a/velox/functions/sparksql/aggregates/CMakeLists.txt +++ b/velox/functions/sparksql/aggregates/CMakeLists.txt @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -add_library(velox_functions_spark_aggregates LastAggregate.cpp Register.cpp) +add_library(velox_functions_spark_aggregates LastAggregate.cpp BloomFilterAggAggregate.cpp Register.cpp) target_link_libraries(velox_functions_spark_aggregates ${FMT} velox_exec velox_expression_functions velox_aggregates velox_vector) diff --git a/velox/functions/sparksql/aggregates/Register.cpp b/velox/functions/sparksql/aggregates/Register.cpp index 60bcbaca673f..bdba930d7786 100644 --- a/velox/functions/sparksql/aggregates/Register.cpp +++ b/velox/functions/sparksql/aggregates/Register.cpp @@ -16,11 +16,13 @@ #include "velox/functions/sparksql/aggregates/Register.h" +#include "velox/functions/sparksql/aggregates/BloomFilterAggAggregate.h" #include "velox/functions/sparksql/aggregates/LastAggregate.h" namespace facebook::velox::functions::sparksql::aggregates { void registerAggregateFunctions(const std::string& prefix) { aggregates::registerLastAggregate(prefix + "last"); + aggregates::registerBloomFilterAggAggregate(prefix + "bloom_filter_agg"); } } // namespace facebook::velox::functions::sparksql::aggregates diff --git a/velox/functions/sparksql/aggregates/tests/BloomFilterAggAggregateTest.cpp 
b/velox/functions/sparksql/aggregates/tests/BloomFilterAggAggregateTest.cpp new file mode 100644 index 000000000000..110fec7e1b7b --- /dev/null +++ b/velox/functions/sparksql/aggregates/tests/BloomFilterAggAggregateTest.cpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/functions/prestosql/aggregates/tests/AggregationTestBase.h" +#include "velox/functions/sparksql/aggregates/Register.h" + +namespace facebook::velox::functions::sparksql::aggregates::test { +namespace { +class BloomFilterAggAggregateTest + : public aggregate::test::AggregationTestBase { + public: + BloomFilterAggAggregateTest() { + aggregate::test::AggregationTestBase::SetUp(); + aggregates::registerAggregateFunctions(""); + } +}; +} // namespace + +TEST_F(BloomFilterAggAggregateTest, bloomFilter) { + auto vectors = {makeRowVector({makeFlatVector( + 100, [](vector_size_t row) { return row / 3; })})}; + + auto expected = {makeRowVector({makeFlatVector< + StringView>(1, [](vector_size_t row) { + return "\u0004\u0000\u0000\u0000\u0003\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000\u0000"; + })})}; + + testAggregations(vectors, {}, {"bloom_filter_agg(c0, 5, 10)"}, expected); +} + +} // namespace 
facebook::velox::functions::sparksql::aggregates::test diff --git a/velox/functions/sparksql/aggregates/tests/CMakeLists.txt b/velox/functions/sparksql/aggregates/tests/CMakeLists.txt index 1f7c68bdb48f..448eed1ee8f5 100644 --- a/velox/functions/sparksql/aggregates/tests/CMakeLists.txt +++ b/velox/functions/sparksql/aggregates/tests/CMakeLists.txt @@ -13,7 +13,7 @@ # limitations under the License. add_executable(velox_functions_spark_aggregates_test LastAggregateTest.cpp - Main.cpp) + BloomFilterAggAggregateTest.cpp Main.cpp) add_test(velox_functions_spark_aggregates_test velox_functions_spark_aggregates_test) diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt index b168c3adc374..3b69a3c151dd 100644 --- a/velox/functions/sparksql/tests/CMakeLists.txt +++ b/velox/functions/sparksql/tests/CMakeLists.txt @@ -29,7 +29,8 @@ add_executable( SplitFunctionsTest.cpp StringTest.cpp SubscriptTest.cpp - XxHash64Test.cpp) + XxHash64Test.cpp + MightContainTest.cpp) add_test(velox_functions_spark_test velox_functions_spark_test) diff --git a/velox/functions/sparksql/tests/MightContainTest.cpp b/velox/functions/sparksql/tests/MightContainTest.cpp new file mode 100644 index 000000000000..cf0f38d1ccf1 --- /dev/null +++ b/velox/functions/sparksql/tests/MightContainTest.cpp @@ -0,0 +1,51 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "velox/functions/sparksql/MightContain.h" +#include "velox/common/base/BloomFilter.h" +#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" + +namespace facebook::velox::functions::sparksql::test { +namespace { + +class MightContainTest : public SparkFunctionBaseTest { + protected: + std::optional mightContain( + std::optional bloom, + int64_t value) { + return evaluateOnce( + fmt::format("might_contain(cast(c0 as varbinary), {})", value), bloom); + } +}; + +TEST_F(MightContainTest, common) { + constexpr int64_t kSize = 1024; + BloomFilter bloom; + bloom.reset(kSize); + for (auto i = 0; i < kSize; ++i) { + bloom.insert(i); + } + std::string data; + data.resize(bloom.serializedSize()); + StringView serialized(data.data(), data.size()); + bloom.serialize(serialized); + + for (auto i = 0; i < kSize; ++i) { + EXPECT_TRUE(mightContain(serialized, i)); + } +} +} // namespace +} // namespace facebook::velox::functions::sparksql::test diff --git a/velox/substrait/SubstraitToVeloxExpr.cpp b/velox/substrait/SubstraitToVeloxExpr.cpp index f06fb3523c51..8fff6ff2bdbe 100644 --- a/velox/substrait/SubstraitToVeloxExpr.cpp +++ b/velox/substrait/SubstraitToVeloxExpr.cpp @@ -353,6 +353,9 @@ SubstraitVeloxExprConverter::toVeloxExpr( BaseVector::wrapInConstant(1, 0, literalsToArrayVector(substraitLit)); return std::make_shared(constantVector); } + case ::substrait::Expression_Literal::LiteralTypeCase::kBinary: + return std::make_shared( + variant::binary(substraitLit.binary())); default: VELOX_NYI( "Substrait conversion not supported for type case '{}'", typeCase);