From 630f713c2f8003d6802d71889407e5c1f80bbbce Mon Sep 17 00:00:00 2001 From: I <1091761+wx257osn2@users.noreply.github.com> Date: Wed, 31 May 2023 07:41:09 +0900 Subject: [PATCH 1/5] stop using slow across-vector operation twice --- faiss/utils/simdlib_neon.h | 47 ++++++++++++++------------------------ 1 file changed, 17 insertions(+), 30 deletions(-) diff --git a/faiss/utils/simdlib_neon.h b/faiss/utils/simdlib_neon.h index 1dbfa2cd27..57c5e19259 100644 --- a/faiss/utils/simdlib_neon.h +++ b/faiss/utils/simdlib_neon.h @@ -560,14 +560,10 @@ struct simd16uint16 { // Checks whether the other holds exactly the same bytes. bool is_same_as(simd16uint16 other) const { - const bool equal0 = - (vminvq_u16(vceqq_u16(data.val[0], other.data.val[0])) == - 0xffff); - const bool equal1 = - (vminvq_u16(vceqq_u16(data.val[1], other.data.val[1])) == - 0xffff); - - return equal0 && equal1; + const auto equals = detail::simdlib::binary_func(data, other.data) + .call<&vceqq_u16>(); + const auto equal = vandq_u16(equals.val[0], equals.val[1]); + return vminvq_u16(equal) == 0xffffu; } simd16uint16 operator~() const { @@ -870,12 +866,10 @@ struct simd32uint8 { // Checks whether the other holds exactly the same bytes. bool is_same_as(simd32uint8 other) const { - const bool equal0 = - (vminvq_u8(vceqq_u8(data.val[0], other.data.val[0])) == 0xff); - const bool equal1 = - (vminvq_u8(vceqq_u8(data.val[1], other.data.val[1])) == 0xff); - - return equal0 && equal1; + const auto equals = detail::simdlib::binary_func(data, other.data) + .call<&vceqq_u8>(); + const auto equal = vandq_u8(equals.val[0], equals.val[1]); + return vminvq_u8(equal) == 0xffu; } }; @@ -973,14 +967,10 @@ struct simd8uint32 { // Checks whether the other holds exactly the same bytes. 
bool is_same_as(simd8uint32 other) const { - const bool equal0 = - (vminvq_u32(vceqq_u32(data.val[0], other.data.val[0])) == - 0xffffffff); - const bool equal1 = - (vminvq_u32(vceqq_u32(data.val[1], other.data.val[1])) == - 0xffffffff); - - return equal0 && equal1; + const auto equals = detail::simdlib::binary_func(data, other.data) + .call<&vceqq_u32>(); + const auto equal = vandq_u32(equals.val[0], equals.val[1]); + return vminvq_u32(equal) == 0xffffffffu; } void clear() { @@ -1181,14 +1171,11 @@ struct simd8float32 { // Checks whether the other holds exactly the same bytes. bool is_same_as(simd8float32 other) const { - const bool equal0 = - (vminvq_u32(vceqq_f32(data.val[0], other.data.val[0])) == - 0xffffffff); - const bool equal1 = - (vminvq_u32(vceqq_f32(data.val[1], other.data.val[1])) == - 0xffffffff); - - return equal0 && equal1; + const auto equals = + detail::simdlib::binary_func<::uint32x4x2_t>(data, other.data) + .call<&vceqq_f32>(); + const auto equal = vandq_u32(equals.val[0], equals.val[1]); + return vminvq_u32(equal) == 0xffffffffu; } std::string tostring() const { From adf8b2868bcd2b8c7969c461f02ca578c734f487 Mon Sep 17 00:00:00 2001 From: I <1091761+wx257osn2@users.noreply.github.com> Date: Wed, 31 May 2023 07:48:16 +0900 Subject: [PATCH 2/5] unify semantics of operator== as same as simd16uint16 --- faiss/utils/simdlib_neon.h | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/faiss/utils/simdlib_neon.h b/faiss/utils/simdlib_neon.h index 57c5e19259..5f0d111108 100644 --- a/faiss/utils/simdlib_neon.h +++ b/faiss/utils/simdlib_neon.h @@ -954,15 +954,18 @@ struct simd8uint32 { return *this; } - bool operator==(simd8uint32 other) const { - const auto equals = detail::simdlib::binary_func(data, other.data) - .call<&vceqq_u32>(); - const auto equal = vandq_u32(equals.val[0], equals.val[1]); - return vminvq_u32(equal) == 0xffffffff; + simd8uint32 operator==(simd8uint32 other) const { + return 
simd8uint32{detail::simdlib::binary_func(data, other.data) + .call<&vceqq_u32>()}; } - bool operator!=(simd8uint32 other) const { - return !(*this == other); + simd8uint32 operator~() const { + return simd8uint32{ + detail::simdlib::unary_func(data).call<&vmvnq_u32>()}; + } + + simd8uint32 operator!=(simd8uint32 other) const { + return ~(*this == other); } // Checks whether the other holds exactly the same bytes. @@ -1157,16 +1160,14 @@ struct simd8float32 { return *this; } - bool operator==(simd8float32 other) const { - const auto equals = + simd8uint32 operator==(simd8float32 other) const { + return simd8uint32{ detail::simdlib::binary_func<::uint32x4x2_t>(data, other.data) - .call<&vceqq_f32>(); - const auto equal = vandq_u32(equals.val[0], equals.val[1]); - return vminvq_u32(equal) == 0xffffffff; + .call<&vceqq_f32>()}; } - bool operator!=(simd8float32 other) const { - return !(*this == other); + simd8uint32 operator!=(simd8float32 other) const { + return ~(*this == other); } // Checks whether the other holds exactly the same bytes. 
From 1f4429682e001161931ad01c880d4876d05293de Mon Sep 17 00:00:00 2001 From: I <1091761+wx257osn2@users.noreply.github.com> Date: Wed, 31 May 2023 08:19:59 +0900 Subject: [PATCH 3/5] use binary_func --- faiss/utils/simdlib_neon.h | 57 +++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/faiss/utils/simdlib_neon.h b/faiss/utils/simdlib_neon.h index 5f0d111108..8566e4e0b8 100644 --- a/faiss/utils/simdlib_neon.h +++ b/faiss/utils/simdlib_neon.h @@ -685,13 +685,12 @@ inline void cmplt_min_max_fast( simd16uint16& minIndices, simd16uint16& maxValues, simd16uint16& maxIndices) { - const uint16x8x2_t comparison = uint16x8x2_t{ - vcltq_u16(candidateValues.data.val[0], currentValues.data.val[0]), - vcltq_u16(candidateValues.data.val[1], currentValues.data.val[1])}; + const uint16x8x2_t comparison = + detail::simdlib::binary_func( + candidateValues.data, currentValues.data) + .call<&vcltq_u16>(); - minValues.data = uint16x8x2_t{ - vminq_u16(candidateValues.data.val[0], currentValues.data.val[0]), - vminq_u16(candidateValues.data.val[1], currentValues.data.val[1])}; + minValues = min(candidateValues, currentValues); minIndices.data = uint16x8x2_t{ vbslq_u16( comparison.val[0], @@ -702,9 +701,7 @@ inline void cmplt_min_max_fast( candidateIndices.data.val[1], currentIndices.data.val[1])}; - maxValues.data = uint16x8x2_t{ - vmaxq_u16(candidateValues.data.val[0], currentValues.data.val[0]), - vmaxq_u16(candidateValues.data.val[1], currentValues.data.val[1])}; + maxValues = max(candidateValues, currentValues); maxIndices.data = uint16x8x2_t{ vbslq_u16( comparison.val[0], @@ -1046,13 +1043,14 @@ inline void cmplt_min_max_fast( simd8uint32& minIndices, simd8uint32& maxValues, simd8uint32& maxIndices) { - const uint32x4x2_t comparison = uint32x4x2_t{ - vcltq_u32(candidateValues.data.val[0], currentValues.data.val[0]), - vcltq_u32(candidateValues.data.val[1], currentValues.data.val[1])}; - - minValues.data = uint32x4x2_t{ - 
vminq_u32(candidateValues.data.val[0], currentValues.data.val[0]), - vminq_u32(candidateValues.data.val[1], currentValues.data.val[1])}; + const uint32x4x2_t comparison = + detail::simdlib::binary_func( + candidateValues.data, currentValues.data) + .call<&vcltq_u32>(); + + minValues.data = detail::simdlib::binary_func( + candidateValues.data, currentValues.data) + .call<&vminq_u32>(); minIndices.data = uint32x4x2_t{ vbslq_u32( comparison.val[0], @@ -1063,9 +1061,9 @@ inline void cmplt_min_max_fast( candidateIndices.data.val[1], currentIndices.data.val[1])}; - maxValues.data = uint32x4x2_t{ - vmaxq_u32(candidateValues.data.val[0], currentValues.data.val[0]), - vmaxq_u32(candidateValues.data.val[1], currentValues.data.val[1])}; + maxValues.data = detail::simdlib::binary_func( + candidateValues.data, currentValues.data) + .call<&vmaxq_u32>(); maxIndices.data = uint32x4x2_t{ vbslq_u32( comparison.val[0], @@ -1290,13 +1288,14 @@ inline void cmplt_min_max_fast( simd8uint32& minIndices, simd8float32& maxValues, simd8uint32& maxIndices) { - const uint32x4x2_t comparison = uint32x4x2_t{ - vcltq_f32(candidateValues.data.val[0], currentValues.data.val[0]), - vcltq_f32(candidateValues.data.val[1], currentValues.data.val[1])}; - - minValues.data = float32x4x2_t{ - vminq_f32(candidateValues.data.val[0], currentValues.data.val[0]), - vminq_f32(candidateValues.data.val[1], currentValues.data.val[1])}; + const uint32x4x2_t comparison = + detail::simdlib::binary_func<::uint32x4x2_t>( + candidateValues.data, currentValues.data) + .call<&vcltq_f32>(); + + minValues.data = detail::simdlib::binary_func( + candidateValues.data, currentValues.data) + .call<&vminq_f32>(); minIndices.data = uint32x4x2_t{ vbslq_u32( comparison.val[0], @@ -1307,9 +1306,9 @@ inline void cmplt_min_max_fast( candidateIndices.data.val[1], currentIndices.data.val[1])}; - maxValues.data = float32x4x2_t{ - vmaxq_f32(candidateValues.data.val[0], currentValues.data.val[0]), - vmaxq_f32(candidateValues.data.val[1], 
currentValues.data.val[1])}; + maxValues.data = detail::simdlib::binary_func( + candidateValues.data, currentValues.data) + .call<&vmaxq_f32>(); maxIndices.data = uint32x4x2_t{ vbslq_u32( comparison.val[0], From ea0ea3860f04736fd9e9d6fe2bda2f5973adcc53 Mon Sep 17 00:00:00 2001 From: I <1091761+wx257osn2@users.noreply.github.com> Date: Wed, 31 May 2023 13:11:47 +0900 Subject: [PATCH 4/5] is_same_as needs to receive simdlib vector with any element types --- faiss/utils/simdlib_neon.h | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/faiss/utils/simdlib_neon.h b/faiss/utils/simdlib_neon.h index 8566e4e0b8..656a561217 100644 --- a/faiss/utils/simdlib_neon.h +++ b/faiss/utils/simdlib_neon.h @@ -559,9 +559,11 @@ struct simd16uint16 { } // Checks whether the other holds exactly the same bytes. - bool is_same_as(simd16uint16 other) const { - const auto equals = detail::simdlib::binary_func(data, other.data) - .call<&vceqq_u16>(); + template <typename T> + bool is_same_as(T other) const { + const auto o = detail::simdlib::reinterpret_u16(other.data); + const auto equals = detail::simdlib::binary_func(data, o) + .template call<&vceqq_u16>(); const auto equal = vandq_u16(equals.val[0], equals.val[1]); return vminvq_u16(equal) == 0xffffu; } @@ -862,9 +864,11 @@ struct simd32uint8 { } // Checks whether the other holds exactly the same bytes. - bool is_same_as(simd32uint8 other) const { - const auto equals = detail::simdlib::binary_func(data, other.data) - .call<&vceqq_u8>(); + template <typename T> + bool is_same_as(T other) const { + const auto o = detail::simdlib::reinterpret_u8(other.data); + const auto equals = detail::simdlib::binary_func(data, o) + .template call<&vceqq_u8>(); const auto equal = vandq_u8(equals.val[0], equals.val[1]); return vminvq_u8(equal) == 0xffu; } @@ -966,9 +970,11 @@ struct simd8uint32 { } // Checks whether the other holds exactly the same bytes.
- bool is_same_as(simd8uint32 other) const { - const auto equals = detail::simdlib::binary_func(data, other.data) - .call<&vceqq_u32>(); + template <typename T> + bool is_same_as(T other) const { + const auto o = detail::simdlib::reinterpret_u32(other.data); + const auto equals = detail::simdlib::binary_func(data, o) + .template call<&vceqq_u32>(); const auto equal = vandq_u32(equals.val[0], equals.val[1]); return vminvq_u32(equal) == 0xffffffffu; } @@ -1169,10 +1175,12 @@ struct simd8float32 { } // Checks whether the other holds exactly the same bytes. - bool is_same_as(simd8float32 other) const { + template <typename T> + bool is_same_as(T other) const { + const auto o = detail::simdlib::reinterpret_f32(other.data); const auto equals = - detail::simdlib::binary_func<::uint32x4x2_t>(data, other.data) - .call<&vceqq_f32>(); + detail::simdlib::binary_func<::uint32x4x2_t>(data, o) + .template call<&vceqq_f32>(); const auto equal = vandq_u32(equals.val[0], equals.val[1]); return vminvq_u32(equal) == 0xffffffffu; } From c67f521d62a36b0ebd9bc982cace63dcf903b409 Mon Sep 17 00:00:00 2001 From: I <1091761+wx257osn2@users.noreply.github.com> Date: Wed, 31 May 2023 13:11:59 +0900 Subject: [PATCH 5/5] remove unused function --- faiss/utils/simdlib_avx2.h | 6 ------ 1 file changed, 6 deletions(-) diff --git a/faiss/utils/simdlib_avx2.h b/faiss/utils/simdlib_avx2.h index 34d788ccd5..fc51e3ed18 100644 --- a/faiss/utils/simdlib_avx2.h +++ b/faiss/utils/simdlib_avx2.h @@ -202,12 +202,6 @@ struct simd16uint16 : simd256bit { return simd16uint16(_mm256_cmpeq_epi16(lhs.i, rhs.i)); } - bool is_same(simd16uint16 other) const { - const __m256i pcmp = _mm256_cmpeq_epi16(i, other.i); - unsigned bitmask = _mm256_movemask_epi8(pcmp); - return (bitmask == 0xffffffffU); - } - simd16uint16 operator~() const { return simd16uint16(_mm256_xor_si256(i, _mm256_set1_epi32(-1))); }