Skip to content

Commit

Permalink
Add might_contain SparkSql function
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchengchenghh committed Apr 13, 2023
1 parent 9e5f57b commit df43e1b
Show file tree
Hide file tree
Showing 8 changed files with 198 additions and 1 deletion.
2 changes: 1 addition & 1 deletion velox/common/base/BloomFilter.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class BloomFilter {
void merge(const char* serialized) {
common::InputByteStream stream(serialized);
auto version = stream.read<int8_t>();
VELOX_CHECK_EQ(kBloomFilterV1, version);
VELOX_USER_CHECK_EQ(kBloomFilterV1, version);
auto size = stream.read<int32_t>();
bits_.resize(size);
auto bitsdata =
Expand Down
7 changes: 7 additions & 0 deletions velox/docs/functions/spark/binary.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ Binary Functions
Computes the md5 of x.

.. spark:function:: might_contain(bloomFilter, value) -> boolean
Returns TRUE if ``bloomFilter`` might contain ``value``.

``bloomFilter`` is a VARBINARY computed using ::spark::function::`bloom_filter_agg` aggregate function.
``value`` is a BIGINT.

.. spark:function:: sha1(x) -> varchar
Computes SHA-1 digest of x and convert the result to a hex string.
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ add_library(
In.cpp
LeastGreatest.cpp
Map.cpp
MightContain.cpp
RegexFunctions.cpp
Register.cpp
RegisterArithmetic.cpp
Expand Down
71 changes: 71 additions & 0 deletions velox/functions/sparksql/MightContain.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/functions/sparksql/MightContain.h"

#include "velox/common/base/BloomFilter.h"
#include "velox/common/memory/HashStringAllocator.h"
#include "velox/expression/DecodedArgs.h"
#include "velox/vector/FlatVector.h"

namespace facebook::velox::functions::sparksql {
namespace {
class BloomFilterMightContainFunction final : public exec::VectorFunction {
void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args, // Not using const ref so we can reuse args
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& resultRef) const final {
VELOX_CHECK_EQ(args.size(), 2);
context.ensureWritable(rows, BOOLEAN(), resultRef);
auto& result = *resultRef->as<FlatVector<bool>>();
exec::DecodedArgs decodedArgs(rows, args, context);
auto serialized = decodedArgs.at(0);
auto value = decodedArgs.at(1);

HashStringAllocator allocator{context.pool()};
VELOX_USER_CHECK(serialized->isConstantMapping())
BloomFilter output{StlAllocator<uint64_t>(&allocator)};
try {
output.merge(serialized->valueAt<StringView>(0).str().c_str());
} catch (const std::exception& e) {
rows.applyToSelected(
[&](int row) { context.setError(row, std::current_exception()); });
return;
}

rows.applyToSelected([&](int row) {
auto contain = output.mayContain(
folly::hasher<int64_t>()(value->valueAt<int64_t>(row)));
result.set(row, contain);
});
}
};
} // namespace

std::vector<std::shared_ptr<exec::FunctionSignature>> mightContainSignatures() {
return {exec::FunctionSignatureBuilder()
.returnType("boolean")
.constantArgumentType("varbinary")
.argumentType("bigint")
.build()};
}

std::unique_ptr<exec::VectorFunction> makeMightContain() {
return std::make_unique<BloomFilterMightContainFunction>();
}

} // namespace facebook::velox::functions::sparksql
24 changes: 24 additions & 0 deletions velox/functions/sparksql/MightContain.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/expression/VectorFunction.h"

namespace facebook::velox::functions::sparksql {

std::vector<std::shared_ptr<exec::FunctionSignature>> mightContainSignatures();

std::unique_ptr<exec::VectorFunction> makeMightContain();

} // namespace facebook::velox::functions::sparksql
5 changes: 5 additions & 0 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "velox/functions/sparksql/Hash.h"
#include "velox/functions/sparksql/In.h"
#include "velox/functions/sparksql/LeastGreatest.h"
#include "velox/functions/sparksql/MightContain.h"
#include "velox/functions/sparksql/RegexFunctions.h"
#include "velox/functions/sparksql/RegisterArithmetic.h"
#include "velox/functions/sparksql/RegisterCompare.h"
Expand Down Expand Up @@ -165,6 +166,10 @@ void registerFunctions(const std::string& prefix) {
int64_t,
Varchar,
Varchar>({prefix + "unix_timestamp", prefix + "to_unix_timestamp"});

// Register bloom filter function
exec::registerVectorFunction(
prefix + "might_contain", mightContainSignatures(), makeMightContain());
}

} // namespace sparksql
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ add_executable(
InTest.cpp
LeastGreatestTest.cpp
MapTest.cpp
MightContainTest.cpp
RegexFunctionsTest.cpp
SizeTest.cpp
SortArrayTest.cpp
Expand Down
88 changes: 88 additions & 0 deletions velox/functions/sparksql/tests/MightContainTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/functions/sparksql/MightContain.h"
#include "velox/common/base/BloomFilter.h"
#include "velox/common/memory/HashStringAllocator.h"
#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"

namespace facebook::velox::functions::sparksql::test {
namespace {

using namespace facebook::velox::test;

class MightContainTest : public SparkFunctionBaseTest {
protected:
void testMightContain(
const VectorPtr& bloom,
const VectorPtr& value,
const VectorPtr& expected) {
auto result = evaluate(
"might_contain(cast(c0 as varbinary), c1)",
makeRowVector({bloom, value}));
assertEqualVectors(expected, result);
}

std::string getSerializedBloomFilter() {
constexpr int64_t kSize = 10;
BloomFilter bloom;
bloom.reset(kSize);
for (auto i = 0; i < kSize; ++i) {
bloom.insert(folly::hasher<int64_t>()(i));
}
std::string data;
data.resize(bloom.serializedSize());
bloom.serialize(data.data());
return data;
}
};

TEST_F(MightContainTest, basic) {
auto serialized = getSerializedBloomFilter();
auto bloom = makeConstant<StringView>(StringView(serialized), 10);
auto value =
makeFlatVector<int64_t>(10, [](vector_size_t row) { return row; });
auto expected = makeConstant(true, 10);
testMightContain(bloom, value, expected);

auto valueNotContain = makeFlatVector<int64_t>(
10, [](vector_size_t row) { return row + 123451; });
auto expectedNotContain = makeConstant(false, 10);
testMightContain(bloom, valueNotContain, expectedNotContain);

auto values = makeNullableFlatVector<int64_t>(
{1, 2, 3, 4, 5, std::nullopt, 123451, 23456, 4, 5});
auto expects = makeNullableFlatVector<bool>(
{true, true, true, true, true, std::nullopt, false, false, true, true});
testMightContain(bloom, values, expects);
}

TEST_F(MightContainTest, nullBloom) {
auto bloom = makeConstant<StringView>(std::nullopt, 2);
auto value = makeFlatVector<int64_t>({2, 4});
auto expected = makeNullConstant(TypeKind::BOOLEAN, 2);
testMightContain(bloom, value, expected);
}

TEST_F(MightContainTest, nullValue) {
auto serializedBloom = getSerializedBloomFilter();
auto bloom = makeConstant<StringView>(StringView(serializedBloom), 2);
auto value = makeNullableFlatVector<int64_t>({std::nullopt, 2});
auto expected = makeNullableFlatVector<bool>({std::nullopt, true});
testMightContain(bloom, value, expected);
}
} // namespace
} // namespace facebook::velox::functions::sparksql::test

0 comments on commit df43e1b

Please sign in to comment.