From 25ef68fc4d89cf1d79bcc496c33c62aea29396fc Mon Sep 17 00:00:00 2001 From: Alexandr Guzhva Date: Thu, 23 May 2024 20:52:49 -0400 Subject: [PATCH] bitset: multiple 'and' and 'or' in a single op Signed-off-by: Alexandr Guzhva --- internal/core/src/bitset/bitset.h | 268 ++++++++--- internal/core/src/bitset/common.h | 13 + internal/core/src/bitset/detail/bit_wise.h | 180 ++++--- .../src/bitset/detail/element_vectorized.h | 212 ++++---- .../core/src/bitset/detail/element_wise.h | 388 ++++++++++----- .../core/src/bitset/detail/maybe_vector.h | 91 ++++ .../bitset/detail/platform/arm/neon-decl.h | 120 +++++ .../bitset/detail/platform/arm/neon-impl.h | 146 ++++++ .../src/bitset/detail/platform/arm/neon.h | 24 + .../src/bitset/detail/platform/arm/sve-decl.h | 120 +++++ .../src/bitset/detail/platform/arm/sve-impl.h | 148 +++++- .../core/src/bitset/detail/platform/arm/sve.h | 24 + .../src/bitset/detail/platform/dynamic.cpp | 180 +++++++ .../core/src/bitset/detail/platform/dynamic.h | 192 +++++++- .../bitset/detail/platform/vectorized_ref.h | 72 ++- .../bitset/detail/platform/x86/avx2-decl.h | 120 +++++ .../bitset/detail/platform/x86/avx2-impl.h | 146 ++++++ .../src/bitset/detail/platform/x86/avx2.h | 24 + .../bitset/detail/platform/x86/avx512-decl.h | 120 +++++ .../bitset/detail/platform/x86/avx512-impl.h | 146 ++++++ .../src/bitset/detail/platform/x86/avx512.h | 24 + internal/core/src/bitset/detail/proxy.h | 13 +- internal/core/unittest/test_bitset.cpp | 455 ++++++++++++++++++ 23 files changed, 2871 insertions(+), 355 deletions(-) create mode 100644 internal/core/src/bitset/detail/maybe_vector.h diff --git a/internal/core/src/bitset/bitset.h b/internal/core/src/bitset/bitset.h index 27a659ae14560..0e5f75bff0bd0 100644 --- a/internal/core/src/bitset/bitset.h +++ b/internal/core/src/bitset/bitset.h @@ -23,6 +23,7 @@ #include #include "common.h" +#include "detail/maybe_vector.h" namespace milvus { namespace bitset { @@ -109,7 +110,6 @@ class BitsetBase { public: using policy_type = PolicyT; using data_type = typename policy_type::data_type; - using size_type = typename policy_type::size_type; using proxy_type = typename policy_type::proxy_type; using const_proxy_type = typename policy_type::const_proxy_type; @@ -128,21 +128,21 @@ class BitsetBase { } // Return the number of bits we're working with. - inline size_type + inline size_t size() const { return as_derived().size_impl(); } // Return the number of bytes which is needed to // contain all our bits. - inline size_type + inline size_t size_in_bytes() const { return policy_type::get_required_size_in_bytes(this->size()); } // Return the number of elements which is needed to // contain all our bits. - inline size_type + inline size_t size_in_elements() const { return policy_type::get_required_size_in_elements(this->size()); } @@ -155,19 +155,19 @@ class BitsetBase { // inline proxy_type - operator[](const size_type bit_idx) { + operator[](const size_t bit_idx) { range_checker::lt(bit_idx, this->size()); - const size_type idx_v = bit_idx + this->offset(); + const size_t idx_v = bit_idx + this->offset(); return policy_type::get_proxy(this->data(), idx_v); } // inline bool - operator[](const size_type bit_idx) const { + operator[](const size_t bit_idx) const { range_checker::lt(bit_idx, this->size()); - const size_type idx_v = bit_idx + this->offset(); + const size_t idx_v = bit_idx + this->offset(); const auto proxy = policy_type::get_proxy(this->data(), idx_v); return proxy.operator bool(); } @@ -180,7 +180,7 @@ class BitsetBase { // Set a given bit to a given value. inline void - set(const size_type bit_idx, const bool value = true) { + set(const size_t bit_idx, const bool value = true) { this->operator[](bit_idx) = value; } @@ -192,7 +192,7 @@ class BitsetBase { // Set a given bit to false. inline void - reset(const size_type bit_idx) { + reset(const size_t bit_idx) { this->operator[](bit_idx) = false; } @@ -217,7 +217,7 @@ class BitsetBase { // Inplace and. template inline void - inplace_and(const BitsetBase& other, const size_type size) { + inplace_and(const BitsetBase& other, const size_t size) { range_checker::le(size, this->size()); range_checker::le(size, other.size()); @@ -225,6 +225,74 @@ class BitsetBase { this->data(), other.data(), this->offset(), other.offset(), size); } + template + inline void + inplace_and(const BitsetView* const others, + const size_t n_others, + const size_t size) { + range_checker::le(size, this->size()); + for (size_t i = 0; i < n_others; i++) { + range_checker::le(size, others[i].size()); + } + + // pick buffers + detail::MaybeVector tmp_data(n_others); + detail::MaybeVector tmp_offset(n_others); + + for (size_t i = 0; i < n_others; i++) { + tmp_data[i] = others[i].data(); + tmp_offset[i] = others[i].offset(); + } + + policy_type::op_and_multiple(this->data(), + tmp_data.data(), + this->offset(), + tmp_offset.data(), + n_others, + size); + } + + template + inline void + inplace_and(const BitsetView* const others, + const size_t n_others) { + this->inplace_and(others, n_others, this->size()); + } + + template + inline void + inplace_and(const Bitset* const others, + const size_t n_others, + const size_t size) { + range_checker::le(size, this->size()); + for (size_t i = 0; i < n_others; i++) { + range_checker::le(size, others[i].size()); + } + + // pick buffers + detail::MaybeVector tmp_data(n_others); + detail::MaybeVector tmp_offset(n_others); + + for (size_t i = 0; i < n_others; i++) { + tmp_data[i] = others[i].data(); + tmp_offset[i] = others[i].offset(); + } + + policy_type::op_and_multiple(this->data(), + tmp_data.data(), + this->offset(), + tmp_offset.data(), + n_others, + size); + } + + template + inline void + inplace_and(const Bitset* const others, + const size_t n_others) { + this->inplace_and(others, n_others, this->size()); + } + // Inplace and. A given bitset / bitset view is expected to have the same size. template inline ImplT& @@ -238,7 +306,7 @@ class BitsetBase { // Inplace or. template inline void - inplace_or(const BitsetBase& other, const size_type size) { + inplace_or(const BitsetBase& other, const size_t size) { range_checker::le(size, this->size()); range_checker::le(size, other.size()); @@ -246,6 +314,74 @@ class BitsetBase { this->data(), other.data(), this->offset(), other.offset(), size); } + template + inline void + inplace_or(const BitsetView* const others, + const size_t n_others, + const size_t size) { + range_checker::le(size, this->size()); + for (size_t i = 0; i < n_others; i++) { + range_checker::le(size, others[i].size()); + } + + // pick buffers + detail::MaybeVector tmp_data(n_others); + detail::MaybeVector tmp_offset(n_others); + + for (size_t i = 0; i < n_others; i++) { + tmp_data[i] = others[i].data(); + tmp_offset[i] = others[i].offset(); + } + + policy_type::op_or_multiple(this->data(), + tmp_data.data(), + this->offset(), + tmp_offset.data(), + n_others, + size); + } + + template + inline void + inplace_or(const BitsetView* const others, + const size_t n_others) { + this->inplace_or(others, n_others, this->size()); + } + + template + inline void + inplace_or(const Bitset* const others, + const size_t n_others, + const size_t size) { + range_checker::le(size, this->size()); + for (size_t i = 0; i < n_others; i++) { + range_checker::le(size, others[i].size()); + } + + // pick buffers + detail::MaybeVector tmp_data(n_others); + detail::MaybeVector tmp_offset(n_others); + + for (size_t i = 0; i < n_others; i++) { + tmp_data[i] = others[i].data(); + tmp_offset[i] = others[i].offset(); + } + + policy_type::op_or_multiple(this->data(), + tmp_data.data(), + this->offset(), + tmp_offset.data(), + n_others, + size); + } + + template + inline void + inplace_or(const Bitset* const others, + const size_t n_others) { + this->inplace_or(others, n_others, this->size()); + } + // Inplace or. A given bitset / bitset view is expected to have the same size. template inline ImplT& @@ -264,13 +400,13 @@ class BitsetBase { // inline BitsetView - operator+(const size_type offset) { + operator+(const size_t offset) { return this->view(offset); } // Create a view of a given size from the given position. inline BitsetView - view(const size_type offset, const size_type size) { + view(const size_t offset, const size_t size) { range_checker::le(offset, this->size()); range_checker::le(offset + size, this->size()); @@ -280,7 +416,7 @@ class BitsetBase { // Create a const view of a given size from the given position. inline BitsetView - view(const size_type offset, const size_type size) const { + view(const size_t offset, const size_t size) const { range_checker::le(offset, this->size()); range_checker::le(offset + size, this->size()); @@ -292,7 +428,7 @@ class BitsetBase { // Create a view from the given position, which uses all available size. inline BitsetView - view(const size_type offset) { + view(const size_t offset) { range_checker::le(offset, this->size()); return BitsetView( @@ -301,7 +437,7 @@ class BitsetBase { // Create a const view from the given position, which uses all available size. inline const BitsetView - view(const size_type offset) const { + view(const size_t offset) const { range_checker::le(offset, this->size()); return BitsetView( @@ -323,7 +459,7 @@ class BitsetBase { } // Return the number of bits which are set to true. - inline size_type + inline size_t count() const { return policy_type::op_count( this->data(), this->offset(), this->size()); @@ -354,7 +490,7 @@ class BitsetBase { // Inplace xor. template inline void - inplace_xor(const BitsetBase& other, const size_type size) { + inplace_xor(const BitsetBase& other, const size_t size) { range_checker::le(size, this->size()); range_checker::le(size, other.size()); @@ -375,7 +511,7 @@ class BitsetBase { // Inplace sub. template inline void - inplace_sub(const BitsetBase& other, const size_type size) { + inplace_sub(const BitsetBase& other, const size_t size) { range_checker::le(size, this->size()); range_checker::le(size, other.size()); @@ -394,16 +530,16 @@ class BitsetBase { } // Find the index of the first bit set to true. - inline std::optional + inline std::optional find_first() const { return policy_type::op_find( this->data(), this->offset(), this->size(), 0); } // Find the index of the first bit set to true, starting from a given bit index. - inline std::optional - find_next(const size_type starting_bit_idx) const { - const size_type size_v = this->size(); + inline std::optional + find_next(const size_t starting_bit_idx) const { + const size_t size_v = this->size(); if (starting_bit_idx + 1 >= size_v) { return std::nullopt; } @@ -414,7 +550,7 @@ class BitsetBase { // Read multiple bits starting from a given bit index. inline data_type - read(const size_type starting_bit_idx, const size_type nbits) { + read(const size_t starting_bit_idx, const size_t nbits) { range_checker::le(nbits, sizeof(data_type)); return policy_type::op_read( @@ -423,9 +559,9 @@ class BitsetBase { // Write multiple bits starting from a given bit index. inline void - write(const size_type starting_bit_idx, + write(const size_t starting_bit_idx, const data_type value, - const size_type nbits) { + const size_t nbits) { range_checker::le(nbits, sizeof(data_type)); policy_type::op_write( @@ -437,7 +573,7 @@ class BitsetBase { void inplace_compare_column(const T* const __restrict t, const U* const __restrict u, - const size_type size, + const size_t size, CompareOpType op) { if (op == CompareOpType::EQ) { this->inplace_compare_column(t, u, size); @@ -460,7 +596,7 @@ class BitsetBase { void inplace_compare_column(const T* const __restrict t, const U* const __restrict u, - const size_type size) { + const size_t size) { range_checker::le(size, this->size()); policy_type::template op_compare_column( @@ -471,7 +607,7 @@ class BitsetBase { template void inplace_compare_val(const T* const __restrict t, - const size_type size, + const size_t size, const T& value, CompareOpType op) { if (op == CompareOpType::EQ) { @@ -494,7 +630,7 @@ class BitsetBase { template void inplace_compare_val(const T* const __restrict t, - const size_type size, + const size_t size, const T& value) { range_checker::le(size, this->size()); @@ -508,7 +644,7 @@ class BitsetBase { inplace_within_range_column(const T* const __restrict lower, const T* const __restrict upper, const T* const __restrict values, - const size_type size, + const size_t size, const RangeType op) { if (op == RangeType::IncInc) { this->inplace_within_range_column( @@ -532,7 +668,7 @@ class BitsetBase { inplace_within_range_column(const T* const __restrict lower, const T* const __restrict upper, const T* const __restrict values, - const size_type size) { + const size_t size) { range_checker::le(size, this->size()); policy_type::template op_within_range_column( @@ -545,7 +681,7 @@ class BitsetBase { inplace_within_range_val(const T& lower, const T& upper, const T* const __restrict values, - const size_type size, + const size_t size, const RangeType op) { if (op == RangeType::IncInc) { this->inplace_within_range_val( @@ -569,7 +705,7 @@ class BitsetBase { inplace_within_range_val(const T& lower, const T& upper, const T* const __restrict values, - const size_type size) { + const size_t size) { range_checker::le(size, this->size()); policy_type::template op_within_range_val( @@ -582,7 +718,7 @@ class BitsetBase { inplace_arith_compare(const T* const __restrict src, const ArithHighPrecisionType& right_operand, const ArithHighPrecisionType& value, - const size_type size, + const size_t size, const ArithOpType a_op, const CompareOpType cmp_op) { if (a_op == ArithOpType::Add) { @@ -765,7 +901,7 @@ class BitsetBase { inplace_arith_compare(const T* const __restrict src, const ArithHighPrecisionType& right_operand, const ArithHighPrecisionType& value, - const size_type size) { + const size_t size) { range_checker::le(size, this->size()); policy_type::template op_arith_compare( @@ -775,9 +911,9 @@ class BitsetBase { // // Inplace and. Also, counts the number of active bits. template - inline size_type + inline size_t inplace_and_with_count(const BitsetBase& other, - const size_type size) { + const size_t size) { range_checker::le(size, this->size()); range_checker::le(size, other.size()); @@ -787,9 +923,9 @@ class BitsetBase { // Inplace or. Also, counts the number of inactive bits. template - inline size_type + inline size_t inplace_or_with_count(const BitsetBase& other, - const size_type size) { + const size_t size) { range_checker::le(size, this->size()); range_checker::le(size, other.size()); @@ -799,7 +935,7 @@ class BitsetBase { private: // Return the starting bit offset in our container. - inline size_type + inline size_t offset() const { return as_derived().offset_impl(); } @@ -829,7 +965,6 @@ class BitsetView : public BitsetBase(data)}, Size{size}, Offset{0} { } - BitsetView(void* data, const size_type offset, const size_type size) + BitsetView(void* data, const size_t offset, const size_t size) : Data{reinterpret_cast(data)}, Size{size}, Offset{offset} { } @@ -861,9 +996,9 @@ class BitsetView : public BitsetBase; + // This is the container type. using container_type = ContainerT; // This is how the data is stored. For example, we may operate using @@ -914,11 +1050,11 @@ class Bitset Bitset() { } // Allocate the given number of bits. - Bitset(const size_type size) + Bitset(const size_t size) : Data(get_required_size_in_container_elements(size)), Size{size} { } // Allocate the given number of bits, initialize with a given value. - Bitset(const size_type size, const bool init) + Bitset(const size_t size, const bool init) : Data(get_required_size_in_container_elements(size), init ? data_type(-1) : 0), Size{size} { @@ -964,8 +1100,8 @@ class Bitset // Resize. void - resize(const size_type new_size) { - const size_type new_size_in_container_elements = + resize(const size_t new_size) { + const size_t new_size_in_container_elements = get_required_size_in_container_elements(new_size); Data.resize(new_size_in_container_elements); Size = new_size; @@ -973,8 +1109,8 @@ class Bitset // Resize and initialize new bits with a given value if grown. void - resize(const size_type new_size, const bool init) { - const size_type old_size = this->size(); + resize(const size_t new_size, const bool init) { + const size_t old_size = this->size(); this->resize(new_size); if (new_size > old_size) { @@ -989,11 +1125,11 @@ class Bitset template void append(const BitsetBase& other, - const size_type starting_bit_idx, - const size_type count) { + const size_t starting_bit_idx, + const size_t count) { range_checker::le(starting_bit_idx, other.size()); - const size_type old_size = this->size(); + const size_t old_size = this->size(); this->resize(this->size() + count); policy_type::op_copy(other.data(), @@ -1020,8 +1156,8 @@ class Bitset // Reserve inline void - reserve(const size_type capacity) { - const size_type capacity_in_container_elements = + reserve(const size_t capacity) { + const size_t capacity_in_container_elements = get_required_size_in_container_elements(capacity); Data.reserve(capacity_in_container_elements); } @@ -1048,7 +1184,7 @@ class Bitset // the container container_type Data; // the actual number of bits - size_type Size = 0; + size_t Size = 0; inline data_type* data_impl() { @@ -1058,19 +1194,19 @@ class Bitset data_impl() const { return reinterpret_cast(Data.data()); } - inline size_type + inline size_t size_impl() const { return Size; } - inline size_type + inline size_t offset_impl() const { return 0; } // - static inline size_type + static inline size_t get_required_size_in_container_elements(const size_t size) { - const size_type size_in_bytes = + const size_t size_in_bytes = policy_type::get_required_size_in_bytes(size); return (size_in_bytes + sizeof(container_data_type) - 1) / sizeof(container_data_type); diff --git a/internal/core/src/bitset/common.h b/internal/core/src/bitset/common.h index 662813e91c2bf..1d4bb8186b456 100644 --- a/internal/core/src/bitset/common.h +++ b/internal/core/src/bitset/common.h @@ -27,6 +27,19 @@ namespace bitset { // this option is only somewhat supported // #define BITSET_HEADER_ONLY +// `always inline` hint. +// It is introduced to deal with clang's behavior to reuse +// once generated code. But if it is needed to generate +// different machine code for multiple platforms based on +// a single template, then such a behavior is undesired. +// `always inline` is applied for PolicyT methods. It is fine, +// because they are not used directly and are wrapped +// in BitsetBase methods. So, a compiler may decide whether +// to really inline them, but it forces a compiler to +// generate specialized code for every hardward platform. +// todo: MSVC has its own way to define `always inline`. +#define BITSET_ALWAYS_INLINE __attribute__((always_inline)) + // a supporting utility template inline constexpr bool always_false_v = false; diff --git a/internal/core/src/bitset/detail/bit_wise.h b/internal/core/src/bitset/detail/bit_wise.h index 5e8c1a37914c0..f3d08dc5be5c4 100644 --- a/internal/core/src/bitset/detail/bit_wise.h +++ b/internal/core/src/bitset/detail/bit_wise.h @@ -32,55 +32,53 @@ namespace detail { template struct BitWiseBitsetPolicy { using data_type = ElementT; - constexpr static auto data_bits = sizeof(data_type) * 8; - - using size_type = size_t; + constexpr static size_t data_bits = sizeof(data_type) * 8; using self_type = BitWiseBitsetPolicy; using proxy_type = Proxy; using const_proxy_type = ConstProxy; - static inline size_type + static inline size_t get_element(const size_t idx) { return idx / data_bits; } - static inline size_type + static inline size_t get_shift(const size_t idx) { return idx % data_bits; } - static inline size_type + static inline size_t get_required_size_in_elements(const size_t size) { return (size + data_bits - 1) / data_bits; } - static inline size_type + static inline size_t get_required_size_in_bytes(const size_t size) { return get_required_size_in_elements(size) * sizeof(data_type); } static inline proxy_type - get_proxy(data_type* const __restrict data, const size_type idx) { + get_proxy(data_type* const __restrict data, const size_t idx) { data_type& element = data[get_element(idx)]; - const size_type shift = get_shift(idx); + const size_t shift = get_shift(idx); return proxy_type{element, shift}; } static inline const_proxy_type - get_proxy(const data_type* const __restrict data, const size_type idx) { + get_proxy(const data_type* const __restrict data, const size_t idx) { const data_type& element = data[get_element(idx)]; - const size_type shift = get_shift(idx); + const size_t shift = get_shift(idx); return const_proxy_type{element, shift}; } static inline data_type op_read(const data_type* const data, - const size_type start, - const size_type nbits) { + const size_t start, + const size_t nbits) { data_type value = 0; - for (size_type i = 0; i < nbits; i++) { + for (size_t i = 0; i < nbits; i++) { const auto proxy = get_proxy(data, start + i); value += proxy ? (data_type(1) << i) : 0; } @@ -90,10 +88,10 @@ struct BitWiseBitsetPolicy { static void op_write(data_type* const data, - const size_type start, - const size_type nbits, + const size_t start, + const size_t nbits, const data_type value) { - for (size_type i = 0; i < nbits; i++) { + for (size_t i = 0; i < nbits; i++) { auto proxy = get_proxy(data, start + i); data_type mask = data_type(1) << i; if ((value & mask) == mask) { @@ -105,10 +103,8 @@ struct BitWiseBitsetPolicy { } static inline void - op_flip(data_type* const data, - const size_type start, - const size_type size) { - for (size_type i = 0; i < size; i++) { + op_flip(data_type* const data, const size_t start, const size_t size) { + for (size_t i = 0; i < size; i++) { auto proxy = get_proxy(data, start + i); proxy.flip(); } @@ -122,7 +118,7 @@ struct BitWiseBitsetPolicy { const size_t size) { // todo: check if intersect - for (size_type i = 0; i < size; i++) { + for (size_t i = 0; i < size; i++) { auto proxy_left = get_proxy(left, start_left + i); auto proxy_right = get_proxy(right, start_right + i); @@ -130,6 +126,27 @@ struct BitWiseBitsetPolicy { } } + static inline void + op_and_multiple(data_type* const left, + const data_type* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + for (size_t i = 0; i < size; i++) { + auto proxy_left = get_proxy(left, start_left + i); + + bool value = proxy_left; + for (size_t j = 0; j < n_rights; j++) { + auto proxy_right = get_proxy(rights[j], start_rights[j] + i); + + value &= proxy_right; + } + + proxy_left = value; + } + } + static inline void op_or(data_type* const left, const data_type* const right, @@ -138,7 +155,7 @@ struct BitWiseBitsetPolicy { const size_t size) { // todo: check if intersect - for (size_type i = 0; i < size; i++) { + for (size_t i = 0; i < size; i++) { auto proxy_left = get_proxy(left, start_left + i); auto proxy_right = get_proxy(right, start_right + i); @@ -147,26 +164,43 @@ struct BitWiseBitsetPolicy { } static inline void - op_set(data_type* const data, const size_type start, const size_type size) { - for (size_type i = 0; i < size; i++) { + op_or_multiple(data_type* const left, + const data_type* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + for (size_t i = 0; i < size; i++) { + auto proxy_left = get_proxy(left, start_left + i); + + bool value = proxy_left; + for (size_t j = 0; j < n_rights; j++) { + auto proxy_right = get_proxy(rights[j], start_rights[j] + i); + + value |= proxy_right; + } + + proxy_left = value; + } + } + + static inline void + op_set(data_type* const data, const size_t start, const size_t size) { + for (size_t i = 0; i < size; i++) { get_proxy(data, start + i) = true; } } static inline void - op_reset(data_type* const data, - const size_type start, - const size_type size) { - for (size_type i = 0; i < size; i++) { + op_reset(data_type* const data, const size_t start, const size_t size) { + for (size_t i = 0; i < size; i++) { get_proxy(data, start + i) = false; } } static inline bool - op_all(const data_type* const data, - const size_type start, - const size_type size) { - for (size_type i = 0; i < size; i++) { + op_all(const data_type* const data, const size_t start, const size_t size) { + for (size_t i = 0; i < size; i++) { if (!get_proxy(data, start + i)) { return false; } @@ -177,9 +211,9 @@ struct BitWiseBitsetPolicy { static inline bool op_none(const data_type* const data, - const size_type start, - const size_type size) { - for (size_type i = 0; i < size; i++) { + const size_t start, + const size_t size) { + for (size_t i = 0; i < size; i++) { if (get_proxy(data, start + i)) { return false; } @@ -190,11 +224,11 @@ struct BitWiseBitsetPolicy { static void op_copy(const data_type* const src, - const size_type start_src, + const size_t start_src, data_type* const dst, - const size_type start_dst, - const size_type size) { - for (size_type i = 0; i < size; i++) { + const size_t start_dst, + const size_t size) { + for (size_t i = 0; i < size; i++) { const auto src_p = get_proxy(src, start_src + i); auto dst_p = get_proxy(dst, start_dst + i); dst_p = src_p.operator bool(); @@ -203,22 +237,22 @@ struct BitWiseBitsetPolicy { static void op_fill(data_type* const dst, - const size_type start_dst, - const size_type size, + const size_t start_dst, + const size_t size, const bool value) { - for (size_type i = 0; i < size; i++) { + for (size_t i = 0; i < size; i++) { auto dst_p = get_proxy(dst, start_dst + i); dst_p = value; } } - static inline size_type + static inline size_t op_count(const data_type* const data, - const size_type start, - const size_type size) { - size_type count = 0; + const size_t start, + const size_t size) { + size_t count = 0; - for (size_type i = 0; i < size; i++) { + for (size_t i = 0; i < size; i++) { auto proxy = get_proxy(data, start + i); count += (proxy) ? 1 : 0; } @@ -232,7 +266,7 @@ struct BitWiseBitsetPolicy { const size_t start_left, const size_t start_right, const size_t size) { - for (size_type i = 0; i < size; i++) { + for (size_t i = 0; i < size; i++) { const auto proxy_left = get_proxy(left, start_left + i); const auto proxy_right = get_proxy(right, start_right + i); @@ -252,7 +286,7 @@ struct BitWiseBitsetPolicy { const size_t size) { // todo: check if intersect - for (size_type i = 0; i < size; i++) { + for (size_t i = 0; i < size; i++) { auto proxy_left = get_proxy(left, start_left + i); const auto proxy_right = get_proxy(right, start_right + i); @@ -268,7 +302,7 @@ struct BitWiseBitsetPolicy { const size_t size) { // todo: check if intersect - for (size_type i = 0; i < size; i++) { + for (size_t i = 0; i < size; i++) { auto proxy_left = get_proxy(left, start_left + i); const auto proxy_right = get_proxy(right, start_right + i); @@ -277,12 +311,12 @@ struct BitWiseBitsetPolicy { } // - static inline std::optional + static inline std::optional op_find(const data_type* const data, - const size_type start, - const size_type size, - const size_type starting_idx) { - for (size_type i = starting_idx; i < size; i++) { + const size_t start, + const size_t size, + const size_t starting_idx) { + for (size_t i = starting_idx; i < size; i++) { const auto proxy = get_proxy(data, start + i); if (proxy) { return i; @@ -296,11 +330,11 @@ struct BitWiseBitsetPolicy { template static inline void op_compare_column(data_type* const __restrict data, - const size_type start, + const size_t start, const T* const __restrict t, const U* const __restrict u, - const size_type size) { - for (size_type i = 0; i < size; i++) { + const size_t size) { + for (size_t i = 0; i < size; i++) { get_proxy(data, start + i) = CompareOperator::compare(t[i], u[i]); } @@ -310,11 +344,11 @@ struct BitWiseBitsetPolicy { template static inline void op_compare_val(data_type* const __restrict data, - const size_type start, + const size_t start, const T* const __restrict t, - const size_type size, + const size_t size, const T& value) { - for (size_type i = 0; i < size; i++) { + for (size_t i = 0; i < size; i++) { get_proxy(data, start + i) = CompareOperator::compare(t[i], value); } @@ -323,12 +357,12 @@ struct BitWiseBitsetPolicy { template static inline void op_within_range_column(data_type* const __restrict data, - const size_type start, + const size_t start, const T* const __restrict lower, const T* const __restrict upper, const T* const __restrict values, - const size_type size) { - for (size_type i = 0; i < size; i++) { + const size_t size) { + for (size_t i = 0; i < size; i++) { get_proxy(data, start + i) = RangeOperator::within_range(lower[i], upper[i], values[i]); } @@ -338,12 +372,12 @@ struct BitWiseBitsetPolicy { template static inline void op_within_range_val(data_type* const __restrict data, - const size_type start, + const size_t start, const T& lower, const T& upper, const T* const __restrict values, - const size_type size) { - for (size_type i = 0; i < size; i++) { + const size_t size) { + for (size_t i = 0; i < size; i++) { get_proxy(data, start + i) = RangeOperator::within_range(lower, upper, values[i]); } @@ -353,12 +387,12 @@ struct BitWiseBitsetPolicy { template static inline void op_arith_compare(data_type* const __restrict data, - const size_type start, + const size_t start, const T* const __restrict src, const ArithHighPrecisionType& right_operand, const ArithHighPrecisionType& value, - const size_type size) { - for (size_type i = 0; i < size; i++) { + const size_t size) { + for (size_t i = 0; i < size; i++) { get_proxy(data, start + i) = ArithCompareOperator::compare( src[i], right_operand, value); @@ -375,7 +409,7 @@ struct BitWiseBitsetPolicy { // todo: check if intersect size_t active = 0; - for (size_type i = 0; i < size; i++) { + for (size_t i = 0; i < size; i++) { auto proxy_left = get_proxy(left, start_left + i); auto proxy_right = get_proxy(right, start_right + i); @@ -397,7 +431,7 @@ struct BitWiseBitsetPolicy { // todo: check if intersect size_t inactive = 0; - for (size_type i = 0; i < size; i++) { + for (size_t i = 0; i < size; i++) { auto proxy_left = get_proxy(left, start_left + i); auto proxy_right = get_proxy(right, start_right + i); diff --git a/internal/core/src/bitset/detail/element_vectorized.h b/internal/core/src/bitset/detail/element_vectorized.h index e21aca883bbbc..93668904abedd 100644 --- a/internal/core/src/bitset/detail/element_vectorized.h +++ b/internal/core/src/bitset/detail/element_vectorized.h @@ -32,53 +32,49 @@ namespace detail { template struct VectorizedElementWiseBitsetPolicy { using data_type = ElementT; - constexpr static auto data_bits = sizeof(data_type) * 8; - - using size_type = size_t; + constexpr static size_t data_bits = sizeof(data_type) * 8; using self_type = VectorizedElementWiseBitsetPolicy; using proxy_type = Proxy; using const_proxy_type = ConstProxy; - static inline size_type + static inline size_t get_element(const size_t idx) { return idx / data_bits; } - static inline size_type + static inline size_t get_shift(const size_t idx) { return idx % data_bits; } - static inline size_type + static inline size_t get_required_size_in_elements(const size_t size) { return (size + data_bits - 1) / data_bits; } - static inline size_type + static inline size_t get_required_size_in_bytes(const size_t size) { return get_required_size_in_elements(size) * sizeof(data_type); } static inline proxy_type - get_proxy(data_type* const __restrict data, const size_type idx) { + get_proxy(data_type* const __restrict data, const size_t idx) { data_type& element = data[get_element(idx)]; - const size_type shift = get_shift(idx); + const size_t shift = get_shift(idx); return proxy_type{element, shift}; } static inline const_proxy_type - get_proxy(const data_type* const __restrict data, const size_type idx) { + get_proxy(const data_type* const __restrict data, const size_t idx) { const data_type& element = data[get_element(idx)]; - const size_type shift = get_shift(idx); + const size_t shift = get_shift(idx); return const_proxy_type{element, shift}; } static inline void - op_flip(data_type* const data, - const size_type start, - const size_type size) { + op_flip(data_type* const data, const size_t start, const size_t size) { ElementWiseBitsetPolicy::op_flip(data, start, size); } @@ -88,8 +84,25 @@ struct VectorizedElementWiseBitsetPolicy { const size_t start_left, const size_t start_right, const size_t size) { - ElementWiseBitsetPolicy::op_and( - left, right, start_left, start_right, size); + if (!VectorizedT::template forward_op_and( + left, right, start_left, start_right, size)) { + ElementWiseBitsetPolicy::op_and( + left, right, start_left, start_right, size); + } + } + + static inline void + op_and_multiple(data_type* const left, + const data_type* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + if (!VectorizedT::template forward_op_and_multiple( + left, rights, start_left, start_rights, n_rights, size)) { + ElementWiseBitsetPolicy::op_and_multiple( + left, rights, start_left, start_rights, n_rights, size); + } } static inline void @@ -98,59 +111,72 @@ struct VectorizedElementWiseBitsetPolicy { const size_t start_left, const size_t start_right, const size_t size) { - ElementWiseBitsetPolicy::op_or( - left, right, start_left, start_right, size); + if (!VectorizedT::template forward_op_or( + left, right, start_left, start_right, size)) { + ElementWiseBitsetPolicy::op_or( + left, right, start_left, start_right, size); + } } static inline void - op_set(data_type* const data, const size_type start, const size_type size) { + op_or_multiple(data_type* const left, + const data_type* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + if (!VectorizedT::template forward_op_or_multiple( + left, rights, start_left, start_rights, n_rights, size)) { + ElementWiseBitsetPolicy::op_or_multiple( + left, rights, start_left, start_rights, n_rights, size); + } + } + + static inline void + op_set(data_type* const data, const size_t start, const size_t size) { ElementWiseBitsetPolicy::op_set(data, start, size); } static inline void - op_reset(data_type* const data, - const size_type start, - const size_type size) { + op_reset(data_type* const data, const size_t start, const size_t size) { ElementWiseBitsetPolicy::op_reset(data, start, size); } static inline bool - op_all(const data_type* const data, - const size_type start, - const size_type size) { + op_all(const data_type* const data, const size_t start, const size_t size) { return ElementWiseBitsetPolicy::op_all(data, start, size); } static inline bool op_none(const data_type* const data, - const size_type start, - const size_type size) { + const size_t start, + const size_t size) { return ElementWiseBitsetPolicy::op_none(data, start, size); } static void op_copy(const data_type* const src, - const size_type start_src, + const size_t start_src, data_type* const dst, - const size_type start_dst, - const size_type size) { + const size_t start_dst, + const size_t size) { ElementWiseBitsetPolicy::op_copy( src, start_src, dst, start_dst, size); } - static inline size_type + static inline size_t op_count(const data_type* const data, - const size_type start, - const size_type size) { + const size_t start, + const size_t size) { return ElementWiseBitsetPolicy::op_count(data, start, size); } static inline bool op_eq(const data_type* const left, const data_type* const right, - const size_type start_left, - const size_type start_right, - const size_type size) { + const size_t start_left, + const size_t start_right, + const size_t size) { return ElementWiseBitsetPolicy::op_eq( left, right, start_left, start_right, size); } @@ -161,8 +187,11 @@ struct VectorizedElementWiseBitsetPolicy { const size_t start_left, const size_t start_right, const size_t size) { - ElementWiseBitsetPolicy::op_xor( - left, right, start_left, start_right, size); + if (!VectorizedT::template forward_op_xor( + left, right, start_left, start_right, size)) { + ElementWiseBitsetPolicy::op_xor( + left, right, start_left, start_right, size); + } } static inline void @@ -171,24 +200,27 @@ struct VectorizedElementWiseBitsetPolicy { const size_t start_left, const size_t start_right, const size_t size) { - ElementWiseBitsetPolicy::op_sub( - left, right, start_left, start_right, size); + if (!VectorizedT::template forward_op_sub( + left, right, start_left, start_right, size)) { + ElementWiseBitsetPolicy::op_sub( + left, right, start_left, start_right, size); + } } static void op_fill(data_type* const data, - const size_type start, - const size_type size, + const size_t start, + const size_t size, const bool value) { ElementWiseBitsetPolicy::op_fill(data, start, size, value); } // - static inline std::optional + static inline std::optional op_find(const data_type* const data, - const size_type start, - const size_type size, - const size_type starting_idx) { + const size_t start, + const size_t size, + const size_t starting_idx) { return ElementWiseBitsetPolicy::op_find( data, start, size, starting_idx); } @@ -197,16 +229,16 @@ struct VectorizedElementWiseBitsetPolicy { template static inline void op_compare_column(data_type* const __restrict data, - const size_type start, + const size_t start, const T* const __restrict t, const U* const __restrict u, - const size_type size) { + const size_t size) { op_func( start, size, - [data, t, u](const size_type starting_bit, - const size_type ptr_offset, - const size_type nbits) { + [data, t, u](const size_t starting_bit, + const size_t ptr_offset, + const size_t nbits) { ElementWiseBitsetPolicy:: template op_compare_column(data, starting_bit, @@ -214,9 +246,9 @@ struct VectorizedElementWiseBitsetPolicy { u + ptr_offset, nbits); }, - [data, t, u](const size_type starting_element, - const size_type ptr_offset, - const size_type nbits) { + [data, t, u](const size_t starting_element, + const size_t ptr_offset, + const size_t nbits) { return VectorizedT::template op_compare_column( reinterpret_cast(data + starting_element), t + ptr_offset, @@ -229,23 +261,23 @@ struct VectorizedElementWiseBitsetPolicy { template static inline void op_compare_val(data_type* const __restrict data, - const size_type start, + const size_t start, const T* const __restrict t, - const size_type size, + const size_t size, const T& value) { op_func( start, size, - [data, t, value](const size_type starting_bit, - const size_type ptr_offset, - const size_type nbits) { + [data, t, value](const size_t starting_bit, + const size_t ptr_offset, + const size_t nbits) { ElementWiseBitsetPolicy::template op_compare_val( data, starting_bit, t + ptr_offset, nbits, value); }, - [data, t, value](const size_type starting_element, - const size_type ptr_offset, - const size_type nbits) { + [data, t, value](const size_t starting_element, + const size_t ptr_offset, + const size_t nbits) { return VectorizedT::template op_compare_val( reinterpret_cast(data + starting_element), t + ptr_offset, @@ -258,17 +290,17 @@ struct VectorizedElementWiseBitsetPolicy { template static inline void op_within_range_column(data_type* const __restrict data, - const size_type start, + const size_t start, const T* const __restrict lower, const T* const __restrict upper, const T* const __restrict values, - const size_type size) { + const size_t size) { op_func( start, size, - [data, lower, upper, values](const size_type starting_bit, - const size_type ptr_offset, - const size_type nbits) { + [data, lower, upper, values](const size_t starting_bit, + const size_t ptr_offset, + const size_t nbits) { ElementWiseBitsetPolicy:: template op_within_range_column(data, starting_bit, @@ -277,9 +309,9 @@ struct VectorizedElementWiseBitsetPolicy { values + ptr_offset, nbits); }, - [data, lower, upper, values](const size_type starting_element, - const size_type ptr_offset, - const size_type nbits) { + [data, lower, upper, values](const size_t starting_element, + const size_t ptr_offset, + const size_t nbits) { return VectorizedT::template op_within_range_column( reinterpret_cast(data + starting_element), lower + ptr_offset, @@ -293,17 +325,17 @@ struct VectorizedElementWiseBitsetPolicy { template static inline void op_within_range_val(data_type* const __restrict data, - const size_type start, + const size_t start, const T& lower, const T& upper, const T* const __restrict values, - const size_type size) { + const size_t size) { op_func( start, size, - [data, lower, upper, values](const size_type starting_bit, - const size_type ptr_offset, - const size_type nbits) { + [data, lower, upper, values](const size_t starting_bit, + const size_t ptr_offset, + const size_t nbits) { ElementWiseBitsetPolicy:: template op_within_range_val(data, starting_bit, @@ -312,9 +344,9 @@ struct VectorizedElementWiseBitsetPolicy { values + ptr_offset, nbits); }, - [data, lower, upper, values](const size_type starting_element, - const size_type ptr_offset, - const size_type nbits) { + [data, lower, upper, values](const size_t starting_element, + const size_t ptr_offset, + const size_t nbits) { return VectorizedT::template op_within_range_val( reinterpret_cast(data + starting_element), lower, @@ -328,17 +360,17 @@ struct VectorizedElementWiseBitsetPolicy { template static inline void op_arith_compare(data_type* const __restrict data, - const size_type start, + const size_t start, const T* const __restrict src, const ArithHighPrecisionType& right_operand, const ArithHighPrecisionType& value, - const size_type size) { + const size_t size) { op_func( start, size, - [data, src, right_operand, value](const size_type starting_bit, - const size_type ptr_offset, - const size_type nbits) { + [data, src, right_operand, value](const size_t starting_bit, + const size_t ptr_offset, + const size_t nbits) { ElementWiseBitsetPolicy:: template op_arith_compare(data, starting_bit, @@ -347,9 +379,9 @@ struct VectorizedElementWiseBitsetPolicy { value, nbits); }, - [data, src, right_operand, value](const size_type starting_element, - const size_type ptr_offset, - const size_type nbits) { + [data, src, right_operand, value](const size_t starting_element, + const size_t ptr_offset, + const size_t nbits) { return VectorizedT::template op_arith_compare( reinterpret_cast(data + starting_element), src + ptr_offset, @@ -380,12 +412,12 @@ struct VectorizedElementWiseBitsetPolicy { left, right, start_left, start_right, size); } - // void FuncBaseline(const size_t starting_bit, const size_type ptr_offset, const size_type nbits) - // bool FuncVectorized(const size_type starting_element, const size_type ptr_offset, const size_type nbits) + // void FuncBaseline(const size_t starting_bit, const size_t ptr_offset, const size_t nbits) + // bool FuncVectorized(const size_t starting_element, const size_t ptr_offset, const size_t nbits) template static inline void - op_func(const size_type start, - const size_type size, + op_func(const size_t start, + const size_t size, FuncBaseline func_baseline, FuncVectorized func_vectorized) { if (size == 0) { diff --git a/internal/core/src/bitset/detail/element_wise.h b/internal/core/src/bitset/detail/element_wise.h index 62e49b5a93ae1..91b40692063b8 100644 --- a/internal/core/src/bitset/detail/element_wise.h +++ b/internal/core/src/bitset/detail/element_wise.h @@ -25,6 +25,8 @@ #include "ctz.h" #include "popcount.h" +#include "maybe_vector.h" + namespace milvus { namespace bitset { namespace detail { @@ -33,53 +35,51 @@ namespace detail { template struct ElementWiseBitsetPolicy { using data_type = ElementT; - constexpr static auto data_bits = sizeof(data_type) * 8; - - using size_type = size_t; + constexpr static size_t data_bits = sizeof(data_type) * 8; using self_type = ElementWiseBitsetPolicy; using proxy_type = Proxy; using const_proxy_type = ConstProxy; - static inline size_type + static inline size_t get_element(const size_t idx) { return idx / data_bits; } - static inline size_type + static inline size_t get_shift(const size_t idx) { return idx % data_bits; } - static inline size_type + static inline size_t get_required_size_in_elements(const size_t size) { return (size + data_bits - 1) / data_bits; } - static inline size_type + static inline size_t get_required_size_in_bytes(const size_t size) { return get_required_size_in_elements(size) * sizeof(data_type); } static inline proxy_type - get_proxy(data_type* const __restrict data, const size_type idx) { + get_proxy(data_type* const __restrict data, const size_t idx) { data_type& element = data[get_element(idx)]; - const size_type shift = get_shift(idx); + const size_t shift = get_shift(idx); return proxy_type{element, shift}; } static inline const_proxy_type - get_proxy(const data_type* const __restrict data, const size_type idx) { + get_proxy(const data_type* const __restrict data, const size_t idx) { const data_type& element = data[get_element(idx)]; - const size_type shift = get_shift(idx); + const size_t shift = get_shift(idx); return const_proxy_type{element, shift}; } static inline data_type op_read(const data_type* const data, - const size_type start, - const size_type nbits) { + const size_t start, + const size_t nbits) { if (nbits == 0) { return 0; } @@ -120,8 +120,8 @@ struct ElementWiseBitsetPolicy { static inline void op_write(data_type* const data, - const size_type start, - const size_type nbits, + const size_t start, + const size_t nbits, const data_type value) { if (nbits == 0) { return; @@ -168,9 +168,7 @@ struct ElementWiseBitsetPolicy { } static inline void - op_flip(data_type* const data, - const size_type start, - const size_type size) { + op_flip(data_type* const data, const size_t start, const size_t size) { if (size == 0) { return; } @@ -210,7 +208,7 @@ struct ElementWiseBitsetPolicy { } // process the middle - for (size_type i = start_element; i < end_element; i++) { + for (size_t i = start_element; i < end_element; i++) { data[i] = ~data[i]; } @@ -227,7 +225,7 @@ struct ElementWiseBitsetPolicy { } } - static inline void + static BITSET_ALWAYS_INLINE inline void op_and(data_type* const left, const data_type* const right, const size_t start_left, @@ -243,7 +241,25 @@ struct ElementWiseBitsetPolicy { }); } - static inline void + static BITSET_ALWAYS_INLINE inline void + op_and_multiple(data_type* const left, + const data_type* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + op_func(left, + rights, + start_left, + start_rights, + n_rights, + size, + [](const data_type left_v, const data_type right_v) { + return left_v & right_v; + }); + } + + static BITSET_ALWAYS_INLINE inline void op_or(data_type* const left, const data_type* const right, const size_t start_left, @@ -259,8 +275,26 @@ struct ElementWiseBitsetPolicy { }); } + static BITSET_ALWAYS_INLINE inline void + op_or_multiple(data_type* const left, + const data_type* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + op_func(left, + rights, + start_left, + start_rights, + n_rights, + size, + [](const data_type left_v, const data_type right_v) { + return left_v | right_v; + }); + } + static inline data_type - get_shift_mask_begin(const size_type shift) { + get_shift_mask_begin(const size_t shift) { // 0 -> 0b00000000 // 1 -> 0b00000001 // 2 -> 0b00000011 @@ -272,7 +306,7 @@ struct ElementWiseBitsetPolicy { } static inline data_type - get_shift_mask_end(const size_type shift) { + get_shift_mask_end(const size_t shift) { // 0 -> 0b11111111 // 1 -> 0b11111110 // 2 -> 0b11111100 @@ -280,21 +314,17 @@ struct ElementWiseBitsetPolicy { } static inline void - op_set(data_type* const data, const size_type start, const size_type size) { + op_set(data_type* const data, const size_t start, const size_t size) { op_fill(data, start, size, true); } static inline void - op_reset(data_type* const data, - const size_type start, - const size_type size) { + op_reset(data_type* const data, const size_t start, const size_t size) { op_fill(data, start, size, false); } static inline bool - op_all(const data_type* const data, - const size_type start, - const size_type size) { + op_all(const data_type* const data, const size_t start, const size_t size) { if (size == 0) { return true; } @@ -328,7 +358,7 @@ struct ElementWiseBitsetPolicy { } // process the middle - for (size_type i = start_element; i < end_element; i++) { + for (size_t i = start_element; i < end_element; i++) { if (data[i] != data_type(-1)) { return false; } @@ -350,8 +380,8 @@ struct ElementWiseBitsetPolicy { static inline bool op_none(const data_type* const data, - const size_type start, - const size_type size) { + const size_t start, + const size_t size) { if (size == 0) { return true; } @@ -385,7 +415,7 @@ struct ElementWiseBitsetPolicy { } // process the middle - for (size_type i = start_element; i < end_element; i++) { + for (size_t i = start_element; i < end_element; i++) { if (data[i] != data_type(0)) { return false; } @@ -407,27 +437,27 @@ struct ElementWiseBitsetPolicy { static void op_copy(const data_type* const src, - const size_type start_src, + const size_t start_src, data_type* const dst, - const size_type start_dst, - const size_type size) { + const size_t start_dst, + const size_t size) { if (size == 0) { return; } // process big blocks - const size_type size_b = (size / data_bits) * data_bits; + const size_t size_b = (size / data_bits) * data_bits; if ((start_src % data_bits) == 0) { if ((start_dst % data_bits) == 0) { // plain memcpy - for (size_type i = 0; i < size_b; i += data_bits) { + for (size_t i = 0; i < size_b; i += data_bits) { const data_type src_v = src[(start_src + i) / data_bits]; dst[(start_dst + i) / data_bits] = src_v; } } else { // easier read - for (size_type i = 0; i < size_b; i += data_bits) { + for (size_t i = 0; i < size_b; i += data_bits) { const data_type src_v = src[(start_src + i) / data_bits]; op_write(dst, start_dst + i, data_bits, src_v); } @@ -435,14 +465,14 @@ struct ElementWiseBitsetPolicy { } else { if ((start_dst % data_bits) == 0) { // easier write - for (size_type i = 0; i < size_b; i += data_bits) { + for (size_t i = 0; i < size_b; i += data_bits) { const data_type src_v = op_read(src, start_src + i, data_bits); dst[(start_dst + i) / data_bits] = src_v; } } else { // general case - for (size_type i = 0; i < size_b; i += data_bits) { + for (size_t i = 0; i < size_b; i += data_bits) { const data_type src_v = op_read(src, start_src + i, data_bits); op_write(dst, start_dst + i, data_bits, src_v); @@ -460,8 +490,8 @@ struct ElementWiseBitsetPolicy { static void op_fill(data_type* const data, - const size_type start, - const size_type size, + const size_t start, + const size_t size, const bool value) { if (size == 0) { return; @@ -503,7 +533,7 @@ struct ElementWiseBitsetPolicy { } // process the middle - for (size_type i = start_element; i < end_element; i++) { + for (size_t i = start_element; i < end_element; i++) { data[i] = new_v; } @@ -519,15 +549,15 @@ struct ElementWiseBitsetPolicy { } } - static inline size_type + static inline size_t op_count(const data_type* const data, - const size_type start, - const size_type size) { + const size_t start, + const size_t size) { if (size == 0) { return 0; } - size_type count = 0; + size_t count = 0; auto start_element = get_element(start); const auto end_element = get_element(start + size); @@ -557,7 +587,7 @@ struct ElementWiseBitsetPolicy { } // process the middle - for (size_type i = start_element; i < end_element; i++) { + for (size_t i = start_element; i < end_element; i++) { count += PopCountHelper::count(data[i]); } @@ -576,24 +606,23 @@ struct ElementWiseBitsetPolicy { static inline bool op_eq(const data_type* const left, const data_type* const right, - const size_type start_left, - const size_type start_right, - const size_type size) { + const size_t start_left, + const size_t start_right, + const size_t size) { if (size == 0) { return true; } // process big chunks - const size_type size_b = (size / data_bits) * data_bits; + const size_t size_b = (size / data_bits) * data_bits; if ((start_left % data_bits) == 0) { if ((start_right % data_bits) == 0) { // plain "memcpy" - size_type start_left_idx = start_left / data_bits; - size_type start_right_idx = start_right / data_bits; + size_t start_left_idx = start_left / data_bits; + size_t start_right_idx = start_right / data_bits; - for (size_type i = 0, j = 0; i < size_b; - i += data_bits, j += 1) { + for (size_t i = 0, j = 0; i < size_b; i += data_bits, j += 1) { const data_type left_v = left[start_left_idx + j]; const data_type right_v = right[start_right_idx + j]; if (left_v != right_v) { @@ -602,10 +631,9 @@ struct ElementWiseBitsetPolicy { } } else { // easier left - size_type start_left_idx = start_left / data_bits; + size_t start_left_idx = start_left / data_bits; - for (size_type i = 0, j = 0; i < size_b; - i += data_bits, j += 1) { + for (size_t i = 0, j = 0; i < size_b; i += data_bits, j += 1) { const data_type left_v = left[start_left_idx + j]; const data_type right_v = op_read(right, start_right + i, data_bits); @@ -617,10 +645,9 @@ struct ElementWiseBitsetPolicy { } else { if ((start_right % data_bits) == 0) { // easier right - size_type start_right_idx = start_right / data_bits; + size_t start_right_idx = start_right / data_bits; - for (size_type i = 0, j = 0; i < size_b; - i += data_bits, j += 1) { + for (size_t i = 0, j = 0; i < size_b; i += data_bits, j += 1) { const data_type left_v = op_read(left, start_left + i, data_bits); const data_type right_v = right[start_right_idx + j]; @@ -630,7 +657,7 @@ struct ElementWiseBitsetPolicy { } } else { // general case - for (size_type i = 0; i < size_b; i += data_bits) { + for (size_t i = 0; i < size_b; i += data_bits) { const data_type left_v = op_read(left, start_left + i, data_bits); const data_type right_v = @@ -656,7 +683,7 @@ struct ElementWiseBitsetPolicy { return true; } - static inline void + static BITSET_ALWAYS_INLINE inline void op_xor(data_type* const left, const data_type* const right, const size_t start_left, @@ -672,7 +699,7 @@ struct ElementWiseBitsetPolicy { }); } - static inline void + static BITSET_ALWAYS_INLINE inline void op_sub(data_type* const left, const data_type* const right, const size_t start_left, @@ -689,11 +716,11 @@ struct ElementWiseBitsetPolicy { } // - static inline std::optional + static inline std::optional op_find(const data_type* const data, - const size_type start, - const size_type size, - const size_type starting_idx) { + const size_t start, + const size_t size, + const size_t starting_idx) { if (size == 0) { return std::nullopt; } @@ -705,7 +732,7 @@ struct ElementWiseBitsetPolicy { const auto start_shift = get_shift(start + starting_idx); const auto end_shift = get_shift(start + size); - size_type extra_offset = 0; + size_t extra_offset = 0; // same element? if (start_element == end_element) { @@ -717,7 +744,7 @@ struct ElementWiseBitsetPolicy { const data_type value = existing_v & existing_mask; if (value != 0) { const auto ctz = CtzHelper::ctz(value); - return size_type(ctz) + start_element * data_bits - start; + return size_t(ctz) + start_element * data_bits - start; } else { return std::nullopt; } @@ -732,7 +759,7 @@ struct ElementWiseBitsetPolicy { if (value != 0) { const auto ctz = CtzHelper::ctz(value) + start_element * data_bits - start; - return size_type(ctz); + return size_t(ctz); } start_element += 1; @@ -740,11 +767,11 @@ struct ElementWiseBitsetPolicy { } // process the middle - for (size_type i = start_element; i < end_element; i++) { + for (size_t i = start_element; i < end_element; i++) { const data_type value = data[i]; if (value != 0) { const auto ctz = CtzHelper::ctz(value); - return size_type(ctz) + i * data_bits - start; + return size_t(ctz) + i * data_bits - start; } } @@ -756,7 +783,7 @@ struct ElementWiseBitsetPolicy { const data_type value = existing_v & existing_mask; if (value != 0) { const auto ctz = CtzHelper::ctz(value); - return size_type(ctz) + end_element * data_bits - start; + return size_t(ctz) + end_element * data_bits - start; } } @@ -767,11 +794,11 @@ struct ElementWiseBitsetPolicy { template static inline void op_compare_column(data_type* const __restrict data, - const size_type start, + const size_t start, const T* const __restrict t, const U* const __restrict u, - const size_type size) { - op_func(data, start, size, [t, u](const size_type bit_idx) { + const size_t size) { + op_func(data, start, size, [t, u](const size_t bit_idx) { return CompareOperator::compare(t[bit_idx], u[bit_idx]); }); } @@ -780,11 +807,11 @@ struct ElementWiseBitsetPolicy { template static inline void op_compare_val(data_type* const __restrict data, - const size_type start, + const size_t start, const T* const __restrict t, - const size_type size, + const size_t size, const T& value) { - op_func(data, start, size, [t, value](const size_type bit_idx) { + op_func(data, start, size, [t, value](const size_t bit_idx) { return CompareOperator::compare(t[bit_idx], value); }); } @@ -793,13 +820,13 @@ struct ElementWiseBitsetPolicy { template static inline void op_within_range_column(data_type* const __restrict data, - const size_type start, + const size_t start, const T* const __restrict lower, const T* const __restrict upper, const T* const __restrict values, - const size_type size) { + const size_t size) { op_func( - data, start, size, [lower, upper, values](const size_type bit_idx) { + data, start, size, [lower, upper, values](const size_t bit_idx) { return RangeOperator::within_range( lower[bit_idx], upper[bit_idx], values[bit_idx]); }); @@ -809,13 +836,13 @@ struct ElementWiseBitsetPolicy { template static inline void op_within_range_val(data_type* const __restrict data, - const size_type start, + const size_t start, const T& lower, const T& upper, const T* const __restrict values, - const size_type size) { + const size_t size) { op_func( - data, start, size, [lower, upper, values](const size_type bit_idx) { + data, start, size, [lower, upper, values](const size_t bit_idx) { return RangeOperator::within_range( lower, upper, values[bit_idx]); }); @@ -825,15 +852,15 @@ struct ElementWiseBitsetPolicy { template static inline void op_arith_compare(data_type* const __restrict data, - const size_type start, + const size_t start, const T* const __restrict src, const ArithHighPrecisionType& right_operand, const ArithHighPrecisionType& value, - const size_type size) { + const size_t size) { op_func(data, start, size, - [src, right_operand, value](const size_type bit_idx) { + [src, right_operand, value](const size_t bit_idx) { return ArithCompareOperator::compare( src[bit_idx], right_operand, value); }); @@ -889,7 +916,7 @@ struct ElementWiseBitsetPolicy { // data_type Func(const data_type left_v, const data_type right_v); template - static inline void + static BITSET_ALWAYS_INLINE inline void op_func(data_type* const left, const data_type* const right, const size_t start_left, @@ -901,16 +928,15 @@ struct ElementWiseBitsetPolicy { } // process big blocks - const size_type size_b = (size / data_bits) * data_bits; + const size_t size_b = (size / data_bits) * data_bits; if ((start_left % data_bits) == 0) { if ((start_right % data_bits) == 0) { // plain "memcpy". // A compiler auto-vectorization is expected. - size_type start_left_idx = start_left / data_bits; - size_type start_right_idx = start_right / data_bits; + size_t start_left_idx = start_left / data_bits; + size_t start_right_idx = start_right / data_bits; - for (size_type i = 0, j = 0; i < size_b; - i += data_bits, j += 1) { + for (size_t i = 0, j = 0; i < size_b; i += data_bits, j += 1) { data_type& left_v = left[start_left_idx + j]; const data_type right_v = right[start_right_idx + j]; @@ -919,10 +945,9 @@ struct ElementWiseBitsetPolicy { } } else { // easier read - size_type start_right_idx = start_right / data_bits; + size_t start_right_idx = start_right / data_bits; - for (size_type i = 0, j = 0; i < size_b; - i += data_bits, j += 1) { + for (size_t i = 0, j = 0; i < size_b; i += data_bits, j += 1) { const data_type left_v = op_read(left, start_left + i, data_bits); const data_type right_v = right[start_right_idx + j]; @@ -934,10 +959,9 @@ struct ElementWiseBitsetPolicy { } else { if ((start_right % data_bits) == 0) { // easier write - size_type start_left_idx = start_left / data_bits; + size_t start_left_idx = start_left / data_bits; - for (size_type i = 0, j = 0; i < size_b; - i += data_bits, j += 1) { + for (size_t i = 0, j = 0; i < size_b; i += data_bits, j += 1) { data_type& left_v = left[start_left_idx + j]; const data_type right_v = op_read(right, start_right + i, data_bits); @@ -947,14 +971,14 @@ struct ElementWiseBitsetPolicy { } } else { // general case - for (size_type i = 0; i < size_b; i += data_bits) { + for (size_t i = 0; i < size_b; i += data_bits) { const data_type left_v = op_read(left, start_left + i, data_bits); const data_type right_v = op_read(right, start_right + i, data_bits); const data_type result_v = func(left_v, right_v); - op_write(left, start_right + i, data_bits, result_v); + op_write(left, start_left + i, data_bits, result_v); } } } @@ -971,11 +995,145 @@ struct ElementWiseBitsetPolicy { } } - // bool Func(const size_type bit_idx); + // data_type Func(const data_type left_v, const data_type right_v); template - static inline void + static BITSET_ALWAYS_INLINE inline void + op_func(data_type* const left, + const data_type* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size, + Func func) { + if (size == 0 || n_rights == 0) { + return; + } + + if (n_rights == 1) { + op_func( + left, rights[0], start_left, start_rights[0], size, func); + return; + } + + // process big blocks + const size_t size_b = (size / data_bits) * data_bits; + + // check a specific case + bool all_aligned = true; + for (size_t i = 0; i < n_rights; i++) { + if (start_rights[i] % data_bits != 0) { + all_aligned = false; + break; + } + } + + // all are aligned + if (all_aligned) { + MaybeVector tmp(n_rights); + for (size_t i = 0; i < n_rights; i++) { + tmp[i] = rights[i] + (start_rights[i] / data_bits); + } + + // plain "memcpy". + // A compiler auto-vectorization is expected. + const size_t start_left_idx = start_left / data_bits; + data_type* left_ptr = left + start_left_idx; + + auto unrolled = [left_ptr, &tmp, func, size_b](const size_t count) { + for (size_t i = 0, j = 0; i < size_b; i += data_bits, j += 1) { + data_type& left_v = left_ptr[j]; + data_type value = left_v; + + for (size_t k = 0; k < count; k++) { + const data_type right_v = tmp[k][j]; + + value = func(value, right_v); + } + + left_v = value; + } + }; + + switch (n_rights) { + // case 1: unrolled(1); break; + case 2: + unrolled(2); + break; + case 3: + unrolled(3); + break; + case 4: + unrolled(4); + break; + case 5: + unrolled(5); + break; + case 6: + unrolled(6); + break; + case 7: + unrolled(7); + break; + case 8: + unrolled(8); + break; + default: { + for (size_t i = 0, j = 0; i < size_b; + i += data_bits, j += 1) { + data_type& left_v = left_ptr[j]; + data_type value = left_v; + + for (size_t k = 0; k < n_rights; k++) { + const data_type right_v = tmp[k][j]; + + value = func(value, right_v); + } + + left_v = value; + } + } + } + + } else { + // general case. Unoptimized. + for (size_t i = 0; i < size_b; i += data_bits) { + const data_type left_v = + op_read(left, start_left + i, data_bits); + + data_type value = left_v; + for (size_t k = 0; k < n_rights; k++) { + const data_type right_v = + op_read(rights[k], start_rights[k] + i, data_bits); + + value = func(value, right_v); + } + + op_write(left, start_left + i, data_bits, value); + } + } + + // process leftovers + if (size_b != size) { + const data_type left_v = + op_read(left, start_left + size_b, size - size_b); + + data_type value = left_v; + for (size_t k = 0; k < n_rights; k++) { + const data_type right_v = + op_read(rights[k], start_rights[k] + size_b, size - size_b); + + value = func(value, right_v); + } + + op_write(left, start_left + size_b, size - size_b, value); + } + } + + // bool Func(const size_t bit_idx); + template + static BITSET_ALWAYS_INLINE inline void op_func(data_type* const __restrict data, - const size_type start, + const size_t start, const size_t size, Func func) { if (size == 0) { @@ -990,7 +1148,7 @@ struct ElementWiseBitsetPolicy { if (start_element == end_element) { data_type bits = 0; - for (size_type j = 0; j < size; j++) { + for (size_t j = 0; j < size; j++) { const bool bit = func(j); // // a curious example where the compiler does not optimize the code properly // bits |= (bit ? (data_type(1) << j) : 0); @@ -1008,10 +1166,10 @@ struct ElementWiseBitsetPolicy { // process the first element if (start_shift != 0) { - const size_type n_bits = data_bits - start_shift; + const size_t n_bits = data_bits - start_shift; data_type bits = 0; - for (size_type j = 0; j < n_bits; j++) { + for (size_t j = 0; j < n_bits; j++) { const bool bit = func(j); bits |= (data_type(bit ? 1 : 0) << j); } @@ -1025,9 +1183,9 @@ struct ElementWiseBitsetPolicy { // process the middle { - for (size_type i = start_element; i < end_element; i++) { + for (size_t i = start_element; i < end_element; i++) { data_type bits = 0; - for (size_type j = 0; j < data_bits; j++) { + for (size_t j = 0; j < data_bits; j++) { const bool bit = func(ptr_offset + j); bits |= (data_type(bit ? 1 : 0) << j); } @@ -1040,7 +1198,7 @@ struct ElementWiseBitsetPolicy { // process the last element if (end_shift != 0) { data_type bits = 0; - for (size_type j = 0; j < end_shift; j++) { + for (size_t j = 0; j < end_shift; j++) { const bool bit = func(ptr_offset + j); bits |= (data_type(bit ? 1 : 0) << j); } diff --git a/internal/core/src/bitset/detail/maybe_vector.h b/internal/core/src/bitset/detail/maybe_vector.h new file mode 100644 index 0000000000000..b9770971f20d4 --- /dev/null +++ b/internal/core/src/bitset/detail/maybe_vector.h @@ -0,0 +1,91 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace milvus { +namespace bitset { +namespace detail { + +// A structure that allocates an array of elements. +// No ownership is implied. +// If the number of elements is small, +// then an allocation will be done on the stack. +// If the number of elements is large, +// then an allocation will be done on the heap. +template +struct MaybeVector { + public: + static_assert(std::is_scalar_v); + + static constexpr size_t num_array_elements = 64; + + std::unique_ptr maybe_memory; + std::array maybe_array; + + MaybeVector(const size_t n_elements) { + m_size = n_elements; + + if (n_elements < num_array_elements) { + m_data = maybe_array.data(); + } else { + maybe_memory = std::make_unique(m_size); + m_data = maybe_memory.get(); + } + } + + MaybeVector(const MaybeVector&) = delete; + MaybeVector(MaybeVector&&) = delete; + MaybeVector& + operator=(const MaybeVector&) = delete; + MaybeVector& + operator=(MaybeVector&&) = delete; + + inline size_t + size() const { + return m_size; + } + inline T* + data() { + return m_data; + } + inline const T* + data() const { + return m_data; + } + + inline T& + operator[](const size_t idx) { + return m_data[idx]; + } + inline const T& + operator[](const size_t idx) const { + return m_data[idx]; + } + + private: + size_t m_size = 0; + + T* m_data = nullptr; +}; + +} // namespace detail +} // namespace bitset +} // namespace milvus diff --git a/internal/core/src/bitset/detail/platform/arm/neon-decl.h b/internal/core/src/bitset/detail/platform/arm/neon-decl.h index c92bb37c0fc45..ca0dd20fef734 100644 --- a/internal/core/src/bitset/detail/platform/arm/neon-decl.h +++ b/internal/core/src/bitset/detail/platform/arm/neon-decl.h @@ -39,6 +39,11 @@ namespace neon { FUNC(float); \ FUNC(double); +// a facility to run through all acceptable forward types +#define ALL_FORWARD_TYPES_1(FUNC) \ + FUNC(uint8_t); \ + FUNC(uint64_t); + /////////////////////////////////////////////////////////////////////////// // the default implementation does nothing @@ -192,7 +197,122 @@ ALL_DATATYPES_1(DECLARE_PARTIAL_OP_ARITH_COMPARE) /////////////////////////////////////////////////////////////////////////// +// forward ops +template +struct ForwardOpsImpl { + static inline bool + op_and(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } + + static inline bool + op_and_multiple(ElementT* const left, + const ElementT* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + return false; + } + + static inline bool + op_or(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } + + static inline bool + op_or_multiple(ElementT* const left, + const ElementT* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + return false; + } + + static inline bool + op_xor(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } + + static inline bool + op_sub(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } +}; + +#define DECLARE_PARTIAL_FORWARD_OPS(ELEMENTTYPE) \ + template <> \ + struct ForwardOpsImpl { \ + static bool \ + op_and(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size); \ + \ + static bool \ + op_and_multiple(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const* const rights, \ + const size_t start_left, \ + const size_t* const __restrict start_rights, \ + const size_t n_rights, \ + const size_t size); \ + \ + static bool \ + op_or(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size); \ + \ + static bool \ + op_or_multiple(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const* const rights, \ + const size_t start_left, \ + const size_t* const __restrict start_rights, \ + const size_t n_rights, \ + const size_t size); \ + \ + static bool \ + op_sub(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size); \ + \ + static bool \ + op_xor(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size); \ + }; + +ALL_FORWARD_TYPES_1(DECLARE_PARTIAL_FORWARD_OPS) + +#undef DECLARE_PARTIAL_FORWARD_OPS + +/////////////////////////////////////////////////////////////////////////// + #undef ALL_DATATYPES_1 +#undef ALL_FORWARD_TYPES_1 } // namespace neon } // namespace arm diff --git a/internal/core/src/bitset/detail/platform/arm/neon-impl.h b/internal/core/src/bitset/detail/platform/arm/neon-impl.h index 0547665d9f6ce..b8423272dc048 100644 --- a/internal/core/src/bitset/detail/platform/arm/neon-impl.h +++ b/internal/core/src/bitset/detail/platform/arm/neon-impl.h @@ -28,6 +28,7 @@ #include "neon-decl.h" #include "bitset/common.h" +#include "bitset/detail/element_wise.h" namespace milvus { namespace bitset { @@ -1810,6 +1811,151 @@ OpArithCompareImpl::op_arith_compare( } } +/////////////////////////////////////////////////////////////////////////// +// forward ops + +// +bool +ForwardOpsImpl::op_and(uint8_t* const left, + const uint8_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_and( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_and_multiple( + uint8_t* const left, + const uint8_t* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + ElementWiseBitsetPolicy::op_and_multiple( + left, rights, start_left, start_rights, n_rights, size); + return true; +} + +bool +ForwardOpsImpl::op_or(uint8_t* const left, + const uint8_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_or( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_or_multiple( + uint8_t* const left, + const uint8_t* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + ElementWiseBitsetPolicy::op_or_multiple( + left, rights, start_left, start_rights, n_rights, size); + return true; +} + +bool +ForwardOpsImpl::op_xor(uint8_t* const left, + const uint8_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_xor( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_sub(uint8_t* const left, + const uint8_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_sub( + left, right, start_left, start_right, size); + return true; +} + +// +bool +ForwardOpsImpl::op_and(uint64_t* const left, + const uint64_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_and( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_and_multiple( + uint64_t* const left, + const uint64_t* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + ElementWiseBitsetPolicy::op_and_multiple( + left, rights, start_left, start_rights, n_rights, size); + return true; +} + +bool +ForwardOpsImpl::op_or(uint64_t* const left, + const uint64_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_or( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_or_multiple( + uint64_t* const left, + const uint64_t* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + ElementWiseBitsetPolicy::op_or_multiple( + left, rights, start_left, start_rights, n_rights, size); + return true; +} + +bool +ForwardOpsImpl::op_xor(uint64_t* const left, + const uint64_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_xor( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_sub(uint64_t* const left, + const uint64_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_sub( + left, right, start_left, start_right, size); + return true; +} + /////////////////////////////////////////////////////////////////////////// } // namespace neon diff --git a/internal/core/src/bitset/detail/platform/arm/neon.h b/internal/core/src/bitset/detail/platform/arm/neon.h index 004547506e405..4fdd6688b672e 100644 --- a/internal/core/src/bitset/detail/platform/arm/neon.h +++ b/internal/core/src/bitset/detail/platform/arm/neon.h @@ -55,6 +55,30 @@ struct VectorizedNeon { template static constexpr inline auto op_arith_compare = neon::OpArithCompareImpl::op_arith_compare; + + template + static constexpr inline auto forward_op_and = + neon::ForwardOpsImpl::op_and; + + template + static constexpr inline auto forward_op_and_multiple = + neon::ForwardOpsImpl::op_and_multiple; + + template + static constexpr inline auto forward_op_or = + neon::ForwardOpsImpl::op_or; + + template + static constexpr inline auto forward_op_or_multiple = + neon::ForwardOpsImpl::op_or_multiple; + + template + static constexpr inline auto forward_op_xor = + neon::ForwardOpsImpl::op_xor; + + template + static constexpr inline auto forward_op_sub = + neon::ForwardOpsImpl::op_sub; }; } // namespace arm diff --git a/internal/core/src/bitset/detail/platform/arm/sve-decl.h b/internal/core/src/bitset/detail/platform/arm/sve-decl.h index f563041e15054..b1a346d1e6d65 100644 --- a/internal/core/src/bitset/detail/platform/arm/sve-decl.h +++ b/internal/core/src/bitset/detail/platform/arm/sve-decl.h @@ -39,6 +39,11 @@ namespace sve { FUNC(float); \ FUNC(double); +// a facility to run through all acceptable forward types +#define ALL_FORWARD_TYPES_1(FUNC) \ + FUNC(uint8_t); \ + FUNC(uint64_t); + /////////////////////////////////////////////////////////////////////////// // the default implementation does nothing @@ -192,7 +197,122 @@ ALL_DATATYPES_1(DECLARE_PARTIAL_OP_ARITH_COMPARE) /////////////////////////////////////////////////////////////////////////// +// forward ops +template +struct ForwardOpsImpl { + static inline bool + op_and(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } + + static inline bool + op_and_multiple(ElementT* const left, + const ElementT* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + return false; + } + + static inline bool + op_or(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } + + static inline bool + op_or_multiple(ElementT* const left, + const ElementT* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + return false; + } + + static inline bool + op_xor(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } + + static inline bool + op_sub(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } +}; + +#define DECLARE_PARTIAL_FORWARD_OPS(ELEMENTTYPE) \ + template <> \ + struct ForwardOpsImpl { \ + static bool \ + op_and(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size); \ + \ + static bool \ + op_and_multiple(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const* const rights, \ + const size_t start_left, \ + const size_t* const __restrict start_rights, \ + const size_t n_rights, \ + const size_t size); \ + \ + static bool \ + op_or(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size); \ + \ + static bool \ + op_or_multiple(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const* const rights, \ + const size_t start_left, \ + const size_t* const __restrict start_rights, \ + const size_t n_rights, \ + const size_t size); \ + \ + static bool \ + op_sub(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size); \ + \ + static bool \ + op_xor(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size); \ + }; + +ALL_FORWARD_TYPES_1(DECLARE_PARTIAL_FORWARD_OPS) + +#undef DECLARE_PARTIAL_FORWARD_OPS + +/////////////////////////////////////////////////////////////////////////// + #undef ALL_DATATYPES_1 +#undef ALL_FORWARD_TYPES_1 } // namespace sve } // namespace arm diff --git a/internal/core/src/bitset/detail/platform/arm/sve-impl.h b/internal/core/src/bitset/detail/platform/arm/sve-impl.h index dfc84f2824d8a..c5cd456659445 100644 --- a/internal/core/src/bitset/detail/platform/arm/sve-impl.h +++ b/internal/core/src/bitset/detail/platform/arm/sve-impl.h @@ -28,8 +28,7 @@ #include "sve-decl.h" #include "bitset/common.h" - -// #include +#include "bitset/detail/element_wise.h" namespace milvus { namespace bitset { @@ -1623,6 +1622,151 @@ OpArithCompareImpl::op_arith_compare( } } +/////////////////////////////////////////////////////////////////////////// +// forward ops + +// +bool +ForwardOpsImpl::op_and(uint8_t* const left, + const uint8_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_and( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_and_multiple( + uint8_t* const left, + const uint8_t* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + ElementWiseBitsetPolicy::op_and_multiple( + left, rights, start_left, start_rights, n_rights, size); + return true; +} + +bool +ForwardOpsImpl::op_or(uint8_t* const left, + const uint8_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_or( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_or_multiple( + uint8_t* const left, + const uint8_t* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + ElementWiseBitsetPolicy::op_or_multiple( + left, rights, start_left, start_rights, n_rights, size); + return true; +} + +bool +ForwardOpsImpl::op_xor(uint8_t* const left, + const uint8_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_xor( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_sub(uint8_t* const left, + const uint8_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_sub( + left, right, start_left, start_right, size); + return true; +} + +// +bool +ForwardOpsImpl::op_and(uint64_t* const left, + const uint64_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_and( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_and_multiple( + uint64_t* const left, + const uint64_t* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + ElementWiseBitsetPolicy::op_and_multiple( + left, rights, start_left, start_rights, n_rights, size); + return true; +} + +bool +ForwardOpsImpl::op_or(uint64_t* const left, + const uint64_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_or( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_or_multiple( + uint64_t* const left, + const uint64_t* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + ElementWiseBitsetPolicy::op_or_multiple( + left, rights, start_left, start_rights, n_rights, size); + return true; +} + +bool +ForwardOpsImpl::op_xor(uint64_t* const left, + const uint64_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_xor( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_sub(uint64_t* const left, + const uint64_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_sub( + left, right, start_left, start_right, size); + return true; +} + /////////////////////////////////////////////////////////////////////////// } // namespace sve diff --git a/internal/core/src/bitset/detail/platform/arm/sve.h b/internal/core/src/bitset/detail/platform/arm/sve.h index 615431373dcf2..7144d2ed094ee 100644 --- a/internal/core/src/bitset/detail/platform/arm/sve.h +++ b/internal/core/src/bitset/detail/platform/arm/sve.h @@ -55,6 +55,30 @@ struct VectorizedSve { template static constexpr inline auto op_arith_compare = sve::OpArithCompareImpl::op_arith_compare; + + template + static constexpr inline auto forward_op_and = + sve::ForwardOpsImpl::op_and; + + template + static constexpr inline auto forward_op_and_multiple = + sve::ForwardOpsImpl::op_and_multiple; + + template + static constexpr inline auto forward_op_or = + sve::ForwardOpsImpl::op_or; + + template + static constexpr inline auto forward_op_or_multiple = + sve::ForwardOpsImpl::op_or_multiple; + + template + static constexpr inline auto forward_op_xor = + sve::ForwardOpsImpl::op_xor; + + template + static constexpr inline auto forward_op_sub = + sve::ForwardOpsImpl::op_sub; }; } // namespace arm diff --git a/internal/core/src/bitset/detail/platform/dynamic.cpp b/internal/core/src/bitset/detail/platform/dynamic.cpp index 8341dede55de5..1e78428ff998b 100644 --- a/internal/core/src/bitset/detail/platform/dynamic.cpp +++ b/internal/core/src/bitset/detail/platform/dynamic.cpp @@ -88,6 +88,11 @@ using namespace milvus::bitset::detail::arm; FUNC(__VA_ARGS__, Mod, LT); \ FUNC(__VA_ARGS__, Mod, NE); +// a facility to run through all possible forward ElementT +#define ALL_FORWARD_OPS(FUNC) \ + FUNC(uint8_t); \ + FUNC(uint64_t); + // namespace milvus { namespace bitset { @@ -235,6 +240,7 @@ ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_COLUMN_IMPL, float) ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_COLUMN_IMPL, double) #undef DISPATCH_OP_WITHIN_RANGE_COLUMN_IMPL + } // namespace dynamic ///////////////////////////////////////////////////////////////////////////// @@ -282,6 +288,8 @@ ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_VAL_IMPL, int64_t) ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_VAL_IMPL, float) ALL_RANGE_OPS(DISPATCH_OP_WITHIN_RANGE_VAL_IMPL, double) +#undef DISPATCH_OP_WITHIN_RANGE_VAL_IMPL + } // namespace dynamic ///////////////////////////////////////////////////////////////////////////// @@ -332,6 +340,108 @@ ALL_ARITH_CMP_OPS(DISPATCH_OP_ARITH_COMPARE, int64_t) ALL_ARITH_CMP_OPS(DISPATCH_OP_ARITH_COMPARE, float) ALL_ARITH_CMP_OPS(DISPATCH_OP_ARITH_COMPARE, double) +#undef DISPATCH_OP_ARITH_COMPARE + +} // namespace dynamic + +///////////////////////////////////////////////////////////////////////////// +// forward_ops + +template +using ForwardOpsOp2 = bool (*)(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size); + +template +using ForwardOpsOpMultiple2 = + bool (*)(ElementT* const left, + const ElementT* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size); + +#define DECLARE_FORWARD_OPS_OP2(ELEMENTTYPE) \ + ForwardOpsOp2 forward_op_and_##ELEMENTTYPE = \ + VectorizedRef::template forward_op_and; \ + ForwardOpsOpMultiple2 forward_op_and_multiple_##ELEMENTTYPE = \ + VectorizedRef::template forward_op_and_multiple; \ + ForwardOpsOp2 forward_op_or_##ELEMENTTYPE = \ + VectorizedRef::template forward_op_or; \ + ForwardOpsOpMultiple2 forward_op_or_multiple_##ELEMENTTYPE = \ + VectorizedRef::template forward_op_or_multiple; \ + ForwardOpsOp2 forward_op_xor_##ELEMENTTYPE = \ + VectorizedRef::template forward_op_xor; \ + ForwardOpsOp2 forward_op_sub_##ELEMENTTYPE = \ + VectorizedRef::template forward_op_sub; + +ALL_FORWARD_OPS(DECLARE_FORWARD_OPS_OP2) + +#undef DECLARE_FORWARD_OPS_OP2 + +// +namespace dynamic { + +#define DISPATCH_FORWARD_OPS_OP_AND(ELEMENTTYPE) \ + bool ForwardOpsImpl::op_and(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size) { \ + return forward_op_and_##ELEMENTTYPE( \ + left, right, start_left, start_right, size); \ + } \ + bool ForwardOpsImpl::op_and_multiple( \ + ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const* const rights, \ + const size_t start_left, \ + const size_t* const __restrict start_rights, \ + const size_t n_rights, \ + const size_t size) { \ + return forward_op_and_multiple_##ELEMENTTYPE( \ + left, rights, start_left, start_rights, n_rights, size); \ + } \ + bool ForwardOpsImpl::op_or(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size) { \ + return forward_op_or_##ELEMENTTYPE( \ + left, right, start_left, start_right, size); \ + } \ + bool ForwardOpsImpl::op_or_multiple( \ + ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const* const rights, \ + const size_t start_left, \ + const size_t* const __restrict start_rights, \ + const size_t n_rights, \ + const size_t size) { \ + return forward_op_or_multiple_##ELEMENTTYPE( \ + left, rights, start_left, start_rights, n_rights, size); \ + } \ + bool ForwardOpsImpl::op_xor(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size) { \ + return forward_op_xor_##ELEMENTTYPE( \ + left, right, start_left, start_right, size); \ + } \ + bool ForwardOpsImpl::op_sub(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size) { \ + return forward_op_sub_##ELEMENTTYPE( \ + left, right, start_left, start_right, size); \ + } + +ALL_FORWARD_OPS(DISPATCH_FORWARD_OPS_OP_AND) + +#undef DISPATCH_FORWARD_OPS_OP_AND + } // namespace dynamic } // namespace detail @@ -402,11 +512,28 @@ init_dynamic_hook() { ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX512, float) ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX512, double) +#define SET_FORWARD_OPS_AVX512(ELEMENTTYPE) \ + forward_op_and_##ELEMENTTYPE = \ + VectorizedAvx512::template forward_op_and; \ + forward_op_and_multiple_##ELEMENTTYPE = \ + VectorizedAvx512::template forward_op_and_multiple; \ + forward_op_or_##ELEMENTTYPE = \ + VectorizedAvx512::template forward_op_or; \ + forward_op_or_multiple_##ELEMENTTYPE = \ + VectorizedAvx512::template forward_op_or_multiple; \ + forward_op_xor_##ELEMENTTYPE = \ + VectorizedAvx512::template forward_op_xor; \ + forward_op_sub_##ELEMENTTYPE = \ + VectorizedAvx512::template forward_op_sub; + + ALL_FORWARD_OPS(SET_FORWARD_OPS_AVX512) + #undef SET_OP_COMPARE_COLUMN_AVX512 #undef SET_OP_COMPARE_VAL_AVX512 #undef SET_OP_WITHIN_RANGE_COLUMN_AVX512 #undef SET_OP_WITHIN_RANGE_VAL_AVX512 #undef SET_ARITH_COMPARE_AVX512 +#undef SET_FORWARD_OPS_AVX512 return; } @@ -467,11 +594,28 @@ init_dynamic_hook() { ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX2, float) ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_AVX2, double) +#define SET_FORWARD_OPS_AVX2(ELEMENTTYPE) \ + forward_op_and_##ELEMENTTYPE = \ + VectorizedAvx2::template forward_op_and; \ + forward_op_and_multiple_##ELEMENTTYPE = \ + VectorizedAvx2::template forward_op_and_multiple; \ + forward_op_or_##ELEMENTTYPE = \ + VectorizedAvx2::template forward_op_or; \ + forward_op_or_multiple_##ELEMENTTYPE = \ + VectorizedAvx2::template forward_op_or_multiple; \ + forward_op_xor_##ELEMENTTYPE = \ + VectorizedAvx2::template forward_op_xor; \ + forward_op_sub_##ELEMENTTYPE = \ + VectorizedAvx2::template forward_op_sub; + + ALL_FORWARD_OPS(SET_FORWARD_OPS_AVX2) + #undef SET_OP_COMPARE_COLUMN_AVX2 #undef SET_OP_COMPARE_VAL_AVX2 #undef SET_OP_WITHIN_RANGE_COLUMN_AVX2 #undef SET_OP_WITHIN_RANGE_VAL_AVX2 #undef SET_ARITH_COMPARE_AVX2 +#undef SET_FORWARD_OPS_AVX2 return; } @@ -535,15 +679,33 @@ init_dynamic_hook() { ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_SVE, float) ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_SVE, double) +#define SET_FORWARD_OPS_SVE(ELEMENTTYPE) \ + forward_op_and_##ELEMENTTYPE = \ + VectorizedSve::template forward_op_and; \ + forward_op_and_multiple_##ELEMENTTYPE = \ + VectorizedSve::template forward_op_and_multiple; \ + forward_op_or_##ELEMENTTYPE = \ + VectorizedSve::template forward_op_or; \ + forward_op_or_multiple_##ELEMENTTYPE = \ + VectorizedSve::template forward_op_or_multiple; \ + forward_op_xor_##ELEMENTTYPE = \ + VectorizedSve::template forward_op_xor; \ + forward_op_sub_##ELEMENTTYPE = \ + VectorizedSve::template forward_op_sub; + + ALL_FORWARD_OPS(SET_FORWARD_OPS_SVE) + #undef SET_OP_COMPARE_COLUMN_SVE #undef SET_OP_COMPARE_VAL_SVE #undef SET_OP_WITHIN_RANGE_COLUMN_SVE #undef SET_OP_WITHIN_RANGE_VAL_SVE #undef SET_ARITH_COMPARE_SVE +#undef SET_FORWARD_OPS_SVE return; } #endif + // neon ? { #define SET_OP_COMPARE_COLUMN_NEON(TTYPE, UTYPE, OP) \ @@ -600,11 +762,28 @@ init_dynamic_hook() { ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_NEON, float) ALL_ARITH_CMP_OPS(SET_ARITH_COMPARE_NEON, double) +#define SET_FORWARD_OPS_NEON(ELEMENTTYPE) \ + forward_op_and_##ELEMENTTYPE = \ + VectorizedNeon::template forward_op_and; \ + forward_op_and_multiple_##ELEMENTTYPE = \ + VectorizedNeon::template forward_op_and_multiple; \ + forward_op_or_##ELEMENTTYPE = \ + VectorizedNeon::template forward_op_or; \ + forward_op_or_multiple_##ELEMENTTYPE = \ + VectorizedNeon::template forward_op_or_multiple; \ + forward_op_xor_##ELEMENTTYPE = \ + VectorizedNeon::template forward_op_xor; \ + forward_op_sub_##ELEMENTTYPE = \ + VectorizedNeon::template forward_op_sub; + + ALL_FORWARD_OPS(SET_FORWARD_OPS_NEON) + #undef SET_OP_COMPARE_COLUMN_NEON #undef SET_OP_COMPARE_VAL_NEON #undef SET_OP_WITHIN_RANGE_COLUMN_NEON #undef SET_OP_WITHIN_RANGE_VAL_NEON #undef SET_ARITH_COMPARE_NEON +#undef SET_FORWARD_OPS_NEON return; } @@ -616,6 +795,7 @@ init_dynamic_hook() { #undef ALL_COMPARE_OPS #undef ALL_RANGE_OPS #undef ALL_ARITH_CMP_OPS +#undef ALL_FORWARD_OPS // static int init_dynamic_ = []() { diff --git a/internal/core/src/bitset/detail/platform/dynamic.h b/internal/core/src/bitset/detail/platform/dynamic.h index 3a050a5e83aac..6d638392bb2ce 100644 --- a/internal/core/src/bitset/detail/platform/dynamic.h +++ b/internal/core/src/bitset/detail/platform/dynamic.h @@ -37,6 +37,11 @@ namespace dynamic { FUNC(float); \ FUNC(double); +// a facility to run through all acceptable forward types +#define ALL_FORWARD_TYPES_1(FUNC) \ + FUNC(uint8_t); \ + FUNC(uint64_t); + /////////////////////////////////////////////////////////////////////////// // the default implementation template @@ -176,11 +181,125 @@ struct OpArithCompareImpl { ALL_DATATYPES_1(DECLARE_PARTIAL_OP_ARITH_COMPARE) -// +#undef DECLARE_PARTIAL_OP_ARITH_COMPARE + +/////////////////////////////////////////////////////////////////////////// +// the default implementation +template +struct ForwardOpsImpl { + static inline bool + op_and(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } + + static inline bool + op_and_multiple(ElementT* const left, + const ElementT* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + return false; + } + + static inline bool + op_or(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } + + static inline bool + op_or_multiple(ElementT* const left, + const ElementT* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + return false; + } + + static inline bool + op_xor(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } + + static inline bool + op_sub(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } +}; + +#define DECLARE_PARTIAL_FORWARD_OPS(ELEMENTTYPE) \ + template <> \ + struct ForwardOpsImpl { \ + static bool \ + op_and(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size); \ + \ + static bool \ + op_and_multiple(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const* const rights, \ + const size_t start_left, \ + const size_t* const __restrict start_rights, \ + const size_t n_rights, \ + const size_t size); \ + \ + static bool \ + op_or(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size); \ + \ + static bool \ + op_or_multiple(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const* const rights, \ + const size_t start_left, \ + const size_t* const __restrict start_rights, \ + const size_t n_rights, \ + const size_t size); \ + \ + static bool \ + op_sub(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size); \ + \ + static bool \ + op_xor(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size); \ + }; + +ALL_FORWARD_TYPES_1(DECLARE_PARTIAL_FORWARD_OPS) + +#undef DECLARE_PARTIAL_FORWARD_OPS /////////////////////////////////////////////////////////////////////////// #undef ALL_DATATYPES_1 +#undef ALL_FORWARD_TYPES_1 } // namespace dynamic @@ -248,6 +367,77 @@ struct VectorizedDynamic { return dynamic::OpArithCompareImpl::op_arith_compare( bitmask, src, right_operand, value, size); } + + // The following functions just forward parameters to the reference code, + // generated for a particular platform. + + template + static inline bool + forward_op_and(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return dynamic::ForwardOpsImpl::op_and( + left, right, start_left, start_right, size); + } + + template + static inline bool + forward_op_and_multiple(ElementT* const left, + const ElementT* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + return dynamic::ForwardOpsImpl::op_and_multiple( + left, rights, start_left, start_rights, n_rights, size); + } + + template + static inline bool + forward_op_or(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return dynamic::ForwardOpsImpl::op_or( + left, right, start_left, start_right, size); + } + + template + static inline bool + forward_op_or_multiple(ElementT* const left, + const ElementT* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + return dynamic::ForwardOpsImpl::op_or_multiple( + left, rights, start_left, start_rights, n_rights, size); + } + + template + static inline bool + forward_op_xor(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return dynamic::ForwardOpsImpl::op_xor( + left, right, start_left, start_right, size); + } + + template + static inline bool + forward_op_sub(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return dynamic::ForwardOpsImpl::op_sub( + left, right, start_left, start_right, size); + } }; } // namespace detail diff --git a/internal/core/src/bitset/detail/platform/vectorized_ref.h b/internal/core/src/bitset/detail/platform/vectorized_ref.h index 20da65406f1f1..39041f84f7c24 100644 --- a/internal/core/src/bitset/detail/platform/vectorized_ref.h +++ b/internal/core/src/bitset/detail/platform/vectorized_ref.h @@ -27,9 +27,13 @@ namespace bitset { namespace detail { // The default reference vectorizer. -// Its every function returns a boolean value whether a vectorized implementation +// Functions return a boolean value whether a vectorized implementation // exists and was invoked. If not, then the caller code will use a default // non-vectorized implementation. +// Certain functions just forward the parameters to the platform code. Basically, +// sometimes compiler can do a good job on its own, we just need to make sure +// that it uses available appropriate hardware instructions. No specialized +// implementation is used under the hood. // The default vectorizer provides no vectorized implementation, forcing the // caller to use a defaut non-vectorized implementation every time. struct VectorizedRef { @@ -88,6 +92,72 @@ struct VectorizedRef { const size_t size) { return false; } + + // The following functions just forward parameters to the reference code, + // generated for a particular platform. + // The reference 'platform' is just a default platform. + + template + static inline bool + forward_op_and(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } + + template + static inline bool + forward_op_and_multiple(ElementT* const left, + const ElementT* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + return false; + } + + template + static inline bool + forward_op_or(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } + + template + static inline bool + forward_op_or_multiple(ElementT* const left, + const ElementT* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + return false; + } + + template + static inline bool + forward_op_xor(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } + + template + static inline bool + forward_op_sub(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } }; } // namespace detail diff --git a/internal/core/src/bitset/detail/platform/x86/avx2-decl.h b/internal/core/src/bitset/detail/platform/x86/avx2-decl.h index cdac2b9713f31..97811eef6d0c9 100644 --- a/internal/core/src/bitset/detail/platform/x86/avx2-decl.h +++ b/internal/core/src/bitset/detail/platform/x86/avx2-decl.h @@ -39,6 +39,11 @@ namespace avx2 { FUNC(float); \ FUNC(double); +// a facility to run through all acceptable forward types +#define ALL_FORWARD_TYPES_1(FUNC) \ + FUNC(uint8_t); \ + FUNC(uint64_t); + /////////////////////////////////////////////////////////////////////////// // the default implementation does nothing @@ -192,7 +197,122 @@ ALL_DATATYPES_1(DECLARE_PARTIAL_OP_ARITH_COMPARE) /////////////////////////////////////////////////////////////////////////// +// forward ops +template +struct ForwardOpsImpl { + static inline bool + op_and(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } + + static inline bool + op_and_multiple(ElementT* const left, + const ElementT* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + return false; + } + + static inline bool + op_or(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } + + static inline bool + op_or_multiple(ElementT* const left, + const ElementT* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + return false; + } + + static inline bool + op_xor(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } + + static inline bool + op_sub(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } +}; + +#define DECLARE_PARTIAL_FORWARD_OPS(ELEMENTTYPE) \ + template <> \ + struct ForwardOpsImpl { \ + static bool \ + op_and(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size); \ + \ + static bool \ + op_and_multiple(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const* const rights, \ + const size_t start_left, \ + const size_t* const __restrict start_rights, \ + const size_t n_rights, \ + const size_t size); \ + \ + static bool \ + op_or(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size); \ + \ + static bool \ + op_or_multiple(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const* const rights, \ + const size_t start_left, \ + const size_t* const __restrict start_rights, \ + const size_t n_rights, \ + const size_t size); \ + \ + static bool \ + op_sub(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size); \ + \ + static bool \ + op_xor(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size); \ + }; + +ALL_FORWARD_TYPES_1(DECLARE_PARTIAL_FORWARD_OPS) + +#undef DECLARE_PARTIAL_FORWARD_OPS + +/////////////////////////////////////////////////////////////////////////// + #undef ALL_DATATYPES_1 +#undef ALL_FORWARD_TYPES_1 } // namespace avx2 } // namespace x86 diff --git a/internal/core/src/bitset/detail/platform/x86/avx2-impl.h b/internal/core/src/bitset/detail/platform/x86/avx2-impl.h index 3b74749d2a637..51af01047a379 100644 --- a/internal/core/src/bitset/detail/platform/x86/avx2-impl.h +++ b/internal/core/src/bitset/detail/platform/x86/avx2-impl.h @@ -28,6 +28,7 @@ #include "avx2-decl.h" #include "bitset/common.h" +#include "bitset/detail/element_wise.h" #include "common.h" namespace milvus { @@ -1649,6 +1650,151 @@ OpArithCompareImpl::op_arith_compare( } } +/////////////////////////////////////////////////////////////////////////// +// forward ops + +// +bool +ForwardOpsImpl::op_and(uint8_t* const left, + const uint8_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_and( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_and_multiple( + uint8_t* const left, + const uint8_t* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + ElementWiseBitsetPolicy::op_and_multiple( + left, rights, start_left, start_rights, n_rights, size); + return true; +} + +bool +ForwardOpsImpl::op_or(uint8_t* const left, + const uint8_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_or( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_or_multiple( + uint8_t* const left, + const uint8_t* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + ElementWiseBitsetPolicy::op_or_multiple( + left, rights, start_left, start_rights, n_rights, size); + return true; +} + +bool +ForwardOpsImpl::op_xor(uint8_t* const left, + const uint8_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_xor( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_sub(uint8_t* const left, + const uint8_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_sub( + left, right, start_left, start_right, size); + return true; +} + +// +bool +ForwardOpsImpl::op_and(uint64_t* const left, + const uint64_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_and( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_and_multiple( + uint64_t* const left, + const uint64_t* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + ElementWiseBitsetPolicy::op_and_multiple( + left, rights, start_left, start_rights, n_rights, size); + return true; +} + +bool +ForwardOpsImpl::op_or(uint64_t* const left, + const uint64_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_or( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_or_multiple( + uint64_t* const left, + const uint64_t* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + ElementWiseBitsetPolicy::op_or_multiple( + left, rights, start_left, start_rights, n_rights, size); + return true; +} + +bool +ForwardOpsImpl::op_xor(uint64_t* const left, + const uint64_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_xor( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_sub(uint64_t* const left, + const uint64_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_sub( + left, right, start_left, start_right, size); + return true; +} + /////////////////////////////////////////////////////////////////////////// } // namespace avx2 diff --git a/internal/core/src/bitset/detail/platform/x86/avx2.h b/internal/core/src/bitset/detail/platform/x86/avx2.h index 711b9f2b8f513..8cd398eba7151 100644 --- a/internal/core/src/bitset/detail/platform/x86/avx2.h +++ b/internal/core/src/bitset/detail/platform/x86/avx2.h @@ -55,6 +55,30 @@ struct VectorizedAvx2 { template static constexpr inline auto op_arith_compare = avx2::OpArithCompareImpl::op_arith_compare; + + template + static constexpr inline auto forward_op_and = + avx2::ForwardOpsImpl::op_and; + + template + static constexpr inline auto forward_op_and_multiple = + avx2::ForwardOpsImpl::op_and_multiple; + + template + static constexpr inline auto forward_op_or = + avx2::ForwardOpsImpl::op_or; + + template + static constexpr inline auto forward_op_or_multiple = + avx2::ForwardOpsImpl::op_or_multiple; + + template + static constexpr inline auto forward_op_xor = + avx2::ForwardOpsImpl::op_xor; + + template + static constexpr inline auto forward_op_sub = + avx2::ForwardOpsImpl::op_sub; }; } // namespace x86 diff --git a/internal/core/src/bitset/detail/platform/x86/avx512-decl.h b/internal/core/src/bitset/detail/platform/x86/avx512-decl.h index 3ad5173cda370..df3a5f110cda0 100644 --- a/internal/core/src/bitset/detail/platform/x86/avx512-decl.h +++ b/internal/core/src/bitset/detail/platform/x86/avx512-decl.h @@ -39,6 +39,11 @@ namespace avx512 { FUNC(float); \ FUNC(double); +// a facility to run through all acceptable forward types +#define ALL_FORWARD_TYPES_1(FUNC) \ + FUNC(uint8_t); \ + FUNC(uint64_t); + /////////////////////////////////////////////////////////////////////////// // the default implementation does nothing @@ -192,7 +197,122 @@ ALL_DATATYPES_1(DECLARE_PARTIAL_OP_ARITH_COMPARE) /////////////////////////////////////////////////////////////////////////// +// forward ops +template +struct ForwardOpsImpl { + static inline bool + op_and(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } + + static inline bool + op_and_multiple(ElementT* const left, + const ElementT* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + return false; + } + + static inline bool + op_or(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } + + static inline bool + op_or_multiple(ElementT* const left, + const ElementT* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + return false; + } + + static inline bool + op_xor(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } + + static inline bool + op_sub(ElementT* const left, + const ElementT* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + return false; + } +}; + +#define DECLARE_PARTIAL_FORWARD_OPS(ELEMENTTYPE) \ + template <> \ + struct ForwardOpsImpl { \ + static bool \ + op_and(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size); \ + \ + static bool \ + op_and_multiple(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const* const rights, \ + const size_t start_left, \ + const size_t* const __restrict start_rights, \ + const size_t n_rights, \ + const size_t size); \ + \ + static bool \ + op_or(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size); \ + \ + static bool \ + op_or_multiple(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const* const rights, \ + const size_t start_left, \ + const size_t* const __restrict start_rights, \ + const size_t n_rights, \ + const size_t size); \ + \ + static bool \ + op_sub(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size); \ + \ + static bool \ + op_xor(ELEMENTTYPE* const left, \ + const ELEMENTTYPE* const right, \ + const size_t start_left, \ + const size_t start_right, \ + const size_t size); \ + }; + +ALL_FORWARD_TYPES_1(DECLARE_PARTIAL_FORWARD_OPS) + +#undef DECLARE_PARTIAL_FORWARD_OPS + +/////////////////////////////////////////////////////////////////////////// + #undef ALL_DATATYPES_1 +#undef ALL_FORWARD_TYPES_1 } // namespace avx512 } // namespace x86 diff --git a/internal/core/src/bitset/detail/platform/x86/avx512-impl.h b/internal/core/src/bitset/detail/platform/x86/avx512-impl.h index b460d257ecda6..ca78585e1b12b 100644 --- a/internal/core/src/bitset/detail/platform/x86/avx512-impl.h +++ b/internal/core/src/bitset/detail/platform/x86/avx512-impl.h @@ -28,6 +28,7 @@ #include "avx512-decl.h" #include "bitset/common.h" +#include "bitset/detail/element_wise.h" #include "common.h" namespace milvus { @@ -1451,6 +1452,151 @@ OpArithCompareImpl::op_arith_compare( } } +/////////////////////////////////////////////////////////////////////////// +// forward ops + +// +bool +ForwardOpsImpl::op_and(uint8_t* const left, + const uint8_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_and( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_and_multiple( + uint8_t* const left, + const uint8_t* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + ElementWiseBitsetPolicy::op_and_multiple( + left, rights, start_left, start_rights, n_rights, size); + return true; +} + +bool +ForwardOpsImpl::op_or(uint8_t* const left, + const uint8_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_or( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_or_multiple( + uint8_t* const left, + const uint8_t* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + ElementWiseBitsetPolicy::op_or_multiple( + left, rights, start_left, start_rights, n_rights, size); + return true; +} + +bool +ForwardOpsImpl::op_xor(uint8_t* const left, + const uint8_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_xor( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_sub(uint8_t* const left, + const uint8_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_sub( + left, right, start_left, start_right, size); + return true; +} + +// +bool +ForwardOpsImpl::op_and(uint64_t* const left, + const uint64_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_and( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_and_multiple( + uint64_t* const left, + const uint64_t* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + ElementWiseBitsetPolicy::op_and_multiple( + left, rights, start_left, start_rights, n_rights, size); + return true; +} + +bool +ForwardOpsImpl::op_or(uint64_t* const left, + const uint64_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_or( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_or_multiple( + uint64_t* const left, + const uint64_t* const* const rights, + const size_t start_left, + const size_t* const __restrict start_rights, + const size_t n_rights, + const size_t size) { + ElementWiseBitsetPolicy::op_or_multiple( + left, rights, start_left, start_rights, n_rights, size); + return true; +} + +bool +ForwardOpsImpl::op_xor(uint64_t* const left, + const uint64_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_xor( + left, right, start_left, start_right, size); + return true; +} + +bool +ForwardOpsImpl::op_sub(uint64_t* const left, + const uint64_t* const right, + const size_t start_left, + const size_t start_right, + const size_t size) { + ElementWiseBitsetPolicy::op_sub( + left, right, start_left, start_right, size); + return true; +} + /////////////////////////////////////////////////////////////////////////// } // namespace avx512 diff --git a/internal/core/src/bitset/detail/platform/x86/avx512.h b/internal/core/src/bitset/detail/platform/x86/avx512.h index 2582efd7c3800..836e65a771ae7 100644 --- a/internal/core/src/bitset/detail/platform/x86/avx512.h +++ b/internal/core/src/bitset/detail/platform/x86/avx512.h @@ -55,6 +55,30 @@ struct VectorizedAvx512 { template static constexpr inline auto op_arith_compare = avx512::OpArithCompareImpl::op_arith_compare; + + template + static constexpr inline auto forward_op_and = + avx512::ForwardOpsImpl::op_and; + + template + static constexpr inline auto forward_op_and_multiple = + avx512::ForwardOpsImpl::op_and_multiple; + + template + static constexpr inline auto forward_op_or = + avx512::ForwardOpsImpl::op_or; + + template + static constexpr inline auto forward_op_or_multiple = + avx512::ForwardOpsImpl::op_or_multiple; + + template + static constexpr inline auto forward_op_xor = + avx512::ForwardOpsImpl::op_xor; + + template + static constexpr inline auto forward_op_sub = + avx512::ForwardOpsImpl::op_sub; }; } // namespace x86 diff --git a/internal/core/src/bitset/detail/proxy.h b/internal/core/src/bitset/detail/proxy.h index efcdc0994e571..82e0922a55ad6 100644 --- a/internal/core/src/bitset/detail/proxy.h +++ b/internal/core/src/bitset/detail/proxy.h @@ -23,19 +23,19 @@ namespace detail { template struct ConstProxy { using policy_type = PolicyT; - using size_type = typename policy_type::size_type; using data_type = typename policy_type::data_type; using self_type = ConstProxy; const data_type& element; data_type mask; - inline ConstProxy(const data_type& _element, const size_type _shift) + inline ConstProxy(const data_type& _element, const size_t _shift) : element{_element} { mask = (data_type(1) << _shift); } - inline operator bool() const { + inline + operator bool() const { return ((element & mask) != 0); } inline bool @@ -47,19 +47,18 @@ struct ConstProxy { template struct Proxy { using policy_type = PolicyT; - using size_type = typename policy_type::size_type; using data_type = typename policy_type::data_type; using self_type = Proxy; data_type& element; data_type mask; - inline Proxy(data_type& _element, const size_type _shift) - : element{_element} { + inline Proxy(data_type& _element, const size_t _shift) : element{_element} { mask = (data_type(1) << _shift); } - inline operator bool() const { + inline + operator bool() const { return ((element & mask) != 0); } inline bool diff --git a/internal/core/unittest/test_bitset.cpp b/internal/core/unittest/test_bitset.cpp index a5f93a9f83c85..dd21d05781bef 100644 --- a/internal/core/unittest/test_bitset.cpp +++ b/internal/core/unittest/test_bitset.cpp @@ -257,6 +257,22 @@ using Ttypes1 = ::testing::Types< #endif >; +// combinations to run +using Ttypes0 = ::testing::Types< +#if FULL_TESTS == 1 + std::tuple, +#endif + + std::tuple + +#if FULL_TESTS == 1 + , + std::tuple, + + std::tuple +#endif + >; + ////////////////////////////////////////////////////////////////////////////////////////// struct StopWatch { @@ -1660,6 +1676,445 @@ TEST(CountElement, f) { ////////////////////////////////////////////////////////////////////////////////////////// +enum class TestInplaceOp { AND, OR, XOR, SUB }; + +// +template +void +TestInplaceOpImpl(BitsetT& bitset, BitsetT& bitset_2, const TestInplaceOp op) { + const size_t n = bitset.size(); + const size_t max_v = 3; + + std::default_random_engine rng(123); + std::uniform_int_distribution u(0, max_v); + + // populate first bitset + std::vector ref_bitset(n, false); + for (size_t i = 0; i < n; i++) { + bool enabled = (u(rng) == 0); + + ref_bitset[i] = enabled; + bitset[i] = enabled; + } + + // populate second bitset + std::vector ref_bitset_2(n, false); + for (size_t i = 0; i < n; i++) { + bool enabled = (u(rng) == 0); + + ref_bitset_2[i] = enabled; + bitset_2[i] = enabled; + } + + // evaluate + StopWatch sw; + if (op == TestInplaceOp::AND) { + bitset.inplace_and(bitset_2, n); + } else if (op == TestInplaceOp::OR) { + bitset.inplace_or(bitset_2, n); + } else if (op == TestInplaceOp::XOR) { + bitset.inplace_xor(bitset_2, n); + } else if (op == TestInplaceOp::SUB) { + bitset.inplace_sub(bitset_2, n); + } else { + ASSERT_TRUE(false) << "Not implemented"; + } + + if (print_timing) { + printf("elapsed %f\n", sw.elapsed()); + } + + // validate + for (size_t i = 0; i < n; i++) { + if (op == TestInplaceOp::AND) { + ASSERT_EQ(bitset[i], ref_bitset[i] & ref_bitset_2[i]); + } else if (op == TestInplaceOp::OR) { + ASSERT_EQ(bitset[i], ref_bitset[i] | ref_bitset_2[i]); + } else if (op == TestInplaceOp::XOR) { + ASSERT_EQ(bitset[i], ref_bitset[i] ^ ref_bitset_2[i]); + } else if (op == TestInplaceOp::SUB) { + ASSERT_EQ(bitset[i], ref_bitset[i] & (~ref_bitset_2[i])); + } else { + ASSERT_TRUE(false) << "Not implemented"; + } + } +} + +template +void +TestInplaceOpImpl() { + const auto inplace_ops = {TestInplaceOp::AND, + TestInplaceOp::OR, + TestInplaceOp::XOR, + TestInplaceOp::SUB}; + + for (const size_t n : typical_sizes) { + for (const auto op : inplace_ops) { + BitsetT bitset(n); + bitset.reset(); + BitsetT bitset_2(n); + bitset_2.reset(); + + if (print_log) { + printf("Testing bitset, n=%zd, op=%zd\n", n, (size_t)op); + } + + TestInplaceOpImpl(bitset, bitset_2, op); + + for (const size_t offset : typical_offsets) { + if (offset >= n) { + continue; + } + + bitset.reset(); + auto view = bitset.view(offset); + bitset_2.reset(); + auto view_2 = bitset_2.view(offset); + + if (print_log) { + printf("Testing bitset view, n=%zd, offset=%zd, op=%zd\n", + n, + offset, + (size_t)op); + } + + TestInplaceOpImpl(view, view_2, op); + } + } + } +} + +// +template +class InplaceOpSuite : public ::testing::Test {}; + +TYPED_TEST_SUITE_P(InplaceOpSuite); + +TYPED_TEST_P(InplaceOpSuite, BitWise) { + using impl_traits = RefImplTraits, + std::tuple_element_t<1, TypeParam>>; + TestInplaceOpImpl(); +} + +TYPED_TEST_P(InplaceOpSuite, ElementWise) { + using impl_traits = ElementImplTraits, + std::tuple_element_t<1, TypeParam>>; + TestInplaceOpImpl(); +} + +TYPED_TEST_P(InplaceOpSuite, Avx2) { +#if defined(__x86_64__) + using namespace milvus::bitset::detail::x86; + + if (cpu_support_avx2()) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<1, TypeParam>, + milvus::bitset::detail::x86::VectorizedAvx2>; + TestInplaceOpImpl(); + } +#endif +} + +TYPED_TEST_P(InplaceOpSuite, Avx512) { +#if defined(__x86_64__) + using namespace milvus::bitset::detail::x86; + + if (cpu_support_avx512()) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<1, TypeParam>, + milvus::bitset::detail::x86::VectorizedAvx512>; + TestInplaceOpImpl(); + } +#endif +} + +TYPED_TEST_P(InplaceOpSuite, Neon) { +#if defined(__aarch64__) + using namespace milvus::bitset::detail::arm; + + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<1, TypeParam>, + milvus::bitset::detail::arm::VectorizedNeon>; + TestInplaceOpImpl(); +#endif +} + +TYPED_TEST_P(InplaceOpSuite, Sve) { +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) + using namespace milvus::bitset::detail::arm; + + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<1, TypeParam>, + milvus::bitset::detail::arm::VectorizedSve>; + TestInplaceOpImpl(); +#endif +} + +TYPED_TEST_P(InplaceOpSuite, Dynamic) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<1, TypeParam>, + milvus::bitset::detail::VectorizedDynamic>; + TestInplaceOpImpl(); +} + +TYPED_TEST_P(InplaceOpSuite, VecRef) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<1, TypeParam>, + milvus::bitset::detail::VectorizedRef>; + TestInplaceOpImpl(); +} + +// +REGISTER_TYPED_TEST_SUITE_P(InplaceOpSuite, + BitWise, + ElementWise, + Avx2, + Avx512, + Neon, + Sve, + Dynamic, + VecRef); + +INSTANTIATE_TYPED_TEST_SUITE_P(InplaceOpTest, InplaceOpSuite, Ttypes0); + +////////////////////////////////////////////////////////////////////////////////////////// + +// +template +void +TestInplaceOpMultipleImpl(BitsetT& bitset, + std::vector& bitset_others, + const TestInplaceOp op) { + const size_t n = bitset.size(); + const size_t n_others = bitset_others.size(); + const size_t max_v = 3; + + std::default_random_engine rng(123); + std::uniform_int_distribution u(0, max_v); + + // populate first bitset + std::vector ref_bitset(n, false); + for (size_t i = 0; i < n; i++) { + bool enabled = (u(rng) == 0); + + ref_bitset[i] = enabled; + bitset[i] = enabled; + } + + // populate others + std::vector> ref_others; + for (size_t j = 0; j < n_others; j++) { + std::vector ref_other(n, false); + for (size_t i = 0; i < n; i++) { + bool enabled = (u(rng) == 0); + + ref_other[i] = enabled; + bitset_others[j][i] = enabled; + } + + ref_others.push_back(std::move(ref_other)); + } + + // evaluate + StopWatch sw; + if (op == TestInplaceOp::AND) { + bitset.inplace_and(bitset_others.data(), n_others, n); + } else if (op == TestInplaceOp::OR) { + bitset.inplace_or(bitset_others.data(), n_others, n); + } else { + ASSERT_TRUE(false) << "Not implemented"; + } + + if (print_timing) { + printf("elapsed %f\n", sw.elapsed()); + } + + // validate + for (size_t i = 0; i < n; i++) { + if (op == TestInplaceOp::AND) { + bool b = ref_bitset[i]; + for (size_t j = 0; j < n_others; j++) { + b &= ref_others[j][i]; + } + ASSERT_EQ(bitset[i], b); + } else if (op == TestInplaceOp::OR) { + bool b = ref_bitset[i]; + for (size_t j = 0; j < n_others; j++) { + b |= ref_others[j][i]; + } + ASSERT_EQ(bitset[i], b); + } else { + ASSERT_TRUE(false) << "Not implemented"; + } + } +} + +template +void +TestInplaceOpMultipleImpl() { + const auto inplace_ops = {TestInplaceOp::AND, TestInplaceOp::OR}; + + for (const size_t n : typical_sizes) { + for (const size_t n_ngb : {1, 2, 4, 8}) { + for (const auto op : inplace_ops) { + BitsetT bitset(n); + bitset.reset(); + + std::vector bitset_others; + for (size_t i = 0; i < n_ngb; i++) { + BitsetT bitset_other(n); + bitset_other.reset(); + + bitset_others.push_back(std::move(bitset_other)); + } + + if (print_log) { + printf("Testing bitset, n=%zd, op=%zd\n", n, (size_t)op); + } + + TestInplaceOpMultipleImpl(bitset, bitset_others, op); + + for (const size_t offset : typical_offsets) { + if (offset >= n) { + continue; + } + + bitset.reset(); + auto view = bitset.view(offset); + + std::vector view_others; + for (size_t i = 0; i < n_ngb; i++) { + bitset_others[i].reset(); + auto view_other = bitset_others[i].view(offset); + + view_others.push_back(std::move(view_other)); + } + + if (print_log) { + printf( + "Testing bitset view, n=%zd, offset=%zd, op=%zd\n", + n, + offset, + (size_t)op); + } + + TestInplaceOpMultipleImpl( + view, view_others, op); + } + } + } + } +} + +// +template +class InplaceOpMultipleSuite : public ::testing::Test {}; + +TYPED_TEST_SUITE_P(InplaceOpMultipleSuite); + +TYPED_TEST_P(InplaceOpMultipleSuite, BitWise) { + using impl_traits = RefImplTraits, + std::tuple_element_t<1, TypeParam>>; + TestInplaceOpMultipleImpl(); +} + +TYPED_TEST_P(InplaceOpMultipleSuite, ElementWise) { + using impl_traits = ElementImplTraits, + std::tuple_element_t<1, TypeParam>>; + TestInplaceOpMultipleImpl(); +} + +TYPED_TEST_P(InplaceOpMultipleSuite, Avx2) { +#if defined(__x86_64__) + using namespace milvus::bitset::detail::x86; + + if (cpu_support_avx2()) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<1, TypeParam>, + milvus::bitset::detail::x86::VectorizedAvx2>; + TestInplaceOpMultipleImpl(); + } +#endif +} + +TYPED_TEST_P(InplaceOpMultipleSuite, Avx512) { +#if defined(__x86_64__) + using namespace milvus::bitset::detail::x86; + + if (cpu_support_avx512()) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<1, TypeParam>, + milvus::bitset::detail::x86::VectorizedAvx512>; + TestInplaceOpMultipleImpl(); + } +#endif +} + +TYPED_TEST_P(InplaceOpMultipleSuite, Neon) { +#if defined(__aarch64__) + using namespace milvus::bitset::detail::arm; + + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<1, TypeParam>, + milvus::bitset::detail::arm::VectorizedNeon>; + TestInplaceOpMultipleImpl(); +#endif +} + +TYPED_TEST_P(InplaceOpMultipleSuite, Sve) { +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) + using namespace milvus::bitset::detail::arm; + + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<1, TypeParam>, + milvus::bitset::detail::arm::VectorizedSve>; + TestInplaceOpMultipleImpl(); +#endif +} + +TYPED_TEST_P(InplaceOpMultipleSuite, Dynamic) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<1, TypeParam>, + milvus::bitset::detail::VectorizedDynamic>; + TestInplaceOpMultipleImpl(); +} + +TYPED_TEST_P(InplaceOpMultipleSuite, VecRef) { + using impl_traits = + VectorizedImplTraits, + std::tuple_element_t<1, TypeParam>, + milvus::bitset::detail::VectorizedRef>; + TestInplaceOpMultipleImpl(); +} + +// +REGISTER_TYPED_TEST_SUITE_P(InplaceOpMultipleSuite, + BitWise, + ElementWise, + Avx2, + Avx512, + Neon, + Sve, + Dynamic, + VecRef); + +INSTANTIATE_TYPED_TEST_SUITE_P(InplaceOpMultipleTest, + InplaceOpMultipleSuite, + Ttypes0); + +////////////////////////////////////////////////////////////////////////////////////////// + int main(int argc, char* argv[]) { ::testing::InitGoogleTest(&argc, argv);