From 5ced5d2e0c2c6afbf13b2337abc9e383f7183d89 Mon Sep 17 00:00:00 2001 From: Alex Guteniev Date: Wed, 10 Apr 2024 01:26:10 +0300 Subject: [PATCH] vectorize `find_first_of` for long needle (#4557) --- stl/src/vector_algorithms.cpp | 142 ++++++++++++++---- .../VSO_0000000_vector_algorithms/test.cpp | 2 +- 2 files changed, 116 insertions(+), 28 deletions(-) diff --git a/stl/src/vector_algorithms.cpp b/stl/src/vector_algorithms.cpp index 8f63bee1db..0eb3c4cf24 100644 --- a/stl/src/vector_algorithms.cpp +++ b/stl/src/vector_algorithms.cpp @@ -1997,50 +1997,138 @@ namespace { const void* __stdcall __std_find_first_of_trivial_impl( const void* _First1, const void* const _Last1, const void* const _First2, const void* const _Last2) noexcept { #ifndef _M_ARM64EC - const size_t _Needle_length = _Byte_length(_First2, _Last2); - - if (_Use_sse42() && _Needle_length <= 16) { + if (_Use_sse42()) { constexpr int _Op = (sizeof(_Ty) == 1 ? _SIDD_UBYTE_OPS : _SIDD_UWORD_OPS) | _SIDD_CMP_EQUAL_ANY | _SIDD_LEAST_SIGNIFICANT; constexpr int _Part_size_el = sizeof(_Ty) == 1 ? 16 : 8; + const size_t _Needle_length = _Byte_length(_First2, _Last2); + + if (_Needle_length <= 16) { + // Special handling of small needle + // The generic branch could also handle it but with slightly worse performance + + const int _Needle_length_el = static_cast(_Needle_length / sizeof(_Ty)); + + alignas(16) uint8_t _Tmp1[16]; + memcpy(_Tmp1, _First2, _Needle_length); + const __m128i _Needle = _mm_load_si128(reinterpret_cast(_Tmp1)); + + const size_t _Haystack_length = _Byte_length(_First1, _Last1); + const void* _Stop_at = _First1; + _Advance_bytes(_Stop_at, _Haystack_length & ~size_t{0xF}); - const int _Needle_length_el = static_cast(_Needle_length / sizeof(_Ty)); + while (_First1 != _Stop_at) { + const __m128i _Haystack_part = _mm_loadu_si128(static_cast(_First1)); + if (_mm_cmpestrc(_Needle, _Needle_length_el, _Haystack_part, _Part_size_el, _Op)) { + const int _Pos = _mm_cmpestri(_Needle, _Needle_length_el, _Haystack_part, _Part_size_el, _Op); + _Advance_bytes(_First1, _Pos * sizeof(_Ty)); + return _First1; + } - alignas(16) uint8_t _Tmp1[16]; - memcpy(_Tmp1, _First2, _Needle_length); - const __m128i _Needle = _mm_load_si128(reinterpret_cast(_Tmp1)); + _Advance_bytes(_First1, 16); + } - const size_t _Haystack_length = _Byte_length(_First1, _Last1); - const void* _Stop_at = _First1; - _Advance_bytes(_Stop_at, _Haystack_length & ~size_t{0xF}); + const size_t _Last_part_size = _Haystack_length & 0xF; + const int _Last_part_size_el = static_cast(_Last_part_size / sizeof(_Ty)); - while (_First1 != _Stop_at) { - const __m128i _Haystack_part = _mm_loadu_si128(static_cast(_First1)); + alignas(16) uint8_t _Tmp2[16]; + memcpy(_Tmp2, _First1, _Last_part_size); + const __m128i _Haystack_part = _mm_load_si128(reinterpret_cast(_Tmp2)); - if (_mm_cmpestrc(_Needle, _Needle_length_el, _Haystack_part, _Part_size_el, _Op)) { - const int _Pos = _mm_cmpestri(_Needle, _Needle_length_el, _Haystack_part, _Part_size_el, _Op); + if (_mm_cmpestrc(_Needle, _Needle_length_el, _Haystack_part, _Last_part_size_el, _Op)) { + const int _Pos = _mm_cmpestri(_Needle, _Needle_length_el, _Haystack_part, _Last_part_size_el, _Op); _Advance_bytes(_First1, _Pos * sizeof(_Ty)); return _First1; } - _Advance_bytes(_First1, 16); - } + _Advance_bytes(_First1, _Last_part_size); + return _First1; + } else { + const void* _Last_needle = _First2; + _Advance_bytes(_Last_needle, _Needle_length & ~size_t{0xF}); - const size_t _Last_part_size = _Haystack_length & 0xF; - const int _Last_part_size_el = static_cast(_Last_part_size / sizeof(_Ty)); + const int _Last_needle_length = static_cast(_Needle_length & 0xF); - alignas(16) uint8_t _Tmp2[16]; - memcpy(_Tmp2, _First1, _Last_part_size); - const __m128i _Haystack_last_part = _mm_load_si128(reinterpret_cast(_Tmp2)); + alignas(16) uint8_t _Tmp1[16]; + memcpy(_Tmp1, _Last_needle, _Last_needle_length); + const __m128i _Last_needle_val = _mm_load_si128(reinterpret_cast(_Tmp1)); + const int _Last_needle_length_el = _Last_needle_length / sizeof(_Ty); - if (_mm_cmpestrc(_Needle, _Needle_length_el, _Haystack_last_part, _Last_part_size_el, _Op)) { - const int _Pos = _mm_cmpestri(_Needle, _Needle_length_el, _Haystack_last_part, _Last_part_size_el, _Op); - _Advance_bytes(_First1, _Pos * sizeof(_Ty)); + constexpr int _Not_found = 16; // arbitrary value greater than any found value + + int _Found_pos = _Not_found; + + const size_t _Haystack_length = _Byte_length(_First1, _Last1); + const void* _Stop_at = _First1; + _Advance_bytes(_Stop_at, _Haystack_length & ~size_t{0xF}); + + while (_First1 != _Stop_at) { + const __m128i _Haystack_part = _mm_loadu_si128(static_cast(_First1)); + + for (const void* _Cur_needle = _First2; _Cur_needle != _Last_needle; + _Advance_bytes(_Cur_needle, 16)) { + const __m128i _Needle = _mm_loadu_si128(static_cast(_Cur_needle)); + if (_mm_cmpestrc(_Needle, _Part_size_el, _Haystack_part, _Part_size_el, _Op)) { + const int _Pos = _mm_cmpestri(_Needle, _Part_size_el, _Haystack_part, _Part_size_el, _Op); + if (_Pos < _Found_pos) { + _Found_pos = _Pos; + } + } + } + + if (const int _Needle_length_el = _Last_needle_length_el; _Needle_length_el != 0) { + const __m128i _Needle = _Last_needle_val; + if (_mm_cmpestrc(_Needle, _Needle_length_el, _Haystack_part, _Part_size_el, _Op)) { + const int _Pos = + _mm_cmpestri(_Needle, _Needle_length_el, _Haystack_part, _Part_size_el, _Op); + if (_Pos < _Found_pos) { + _Found_pos = _Pos; + } + } + } + + if (_Found_pos != _Not_found) { + _Advance_bytes(_First1, _Found_pos * sizeof(_Ty)); + return _First1; + } + + _Advance_bytes(_First1, 16); + } + + const size_t _Last_part_size = _Haystack_length & 0xF; + const int _Last_part_size_el = static_cast(_Last_part_size / sizeof(_Ty)); + + alignas(16) uint8_t _Tmp2[16]; + memcpy(_Tmp2, _First1, _Last_part_size); + const __m128i _Haystack_part = _mm_load_si128(reinterpret_cast(_Tmp2)); + + _Found_pos = _Last_part_size_el; + + for (const void* _Cur_needle = _First2; _Cur_needle != _Last_needle; _Advance_bytes(_Cur_needle, 16)) { + const __m128i _Needle = _mm_loadu_si128(static_cast(_Cur_needle)); + + if (_mm_cmpestrc(_Needle, _Part_size_el, _Haystack_part, _Last_part_size_el, _Op)) { + const int _Pos = _mm_cmpestri(_Needle, _Part_size_el, _Haystack_part, _Last_part_size_el, _Op); + if (_Pos < _Found_pos) { + _Found_pos = _Pos; + } + } + } + + if (const int _Needle_length_el = _Last_needle_length_el; _Needle_length_el != 0) { + const __m128i _Needle = _Last_needle_val; + if (_mm_cmpestrc(_Needle, _Needle_length_el, _Haystack_part, _Last_part_size_el, _Op)) { + const int _Pos = + _mm_cmpestri(_Needle, _Needle_length_el, _Haystack_part, _Last_part_size_el, _Op); + if (_Pos < _Found_pos) { + _Found_pos = _Pos; + } + } + } + + _Advance_bytes(_First1, _Found_pos * sizeof(_Ty)); return _First1; } - - _Advance_bytes(_First1, _Last_part_size); - return _First1; } #endif // !_M_ARM64EC diff --git a/tests/std/tests/VSO_0000000_vector_algorithms/test.cpp b/tests/std/tests/VSO_0000000_vector_algorithms/test.cpp index 39eed458ad..49fec16ea0 100644 --- a/tests/std/tests/VSO_0000000_vector_algorithms/test.cpp +++ b/tests/std/tests/VSO_0000000_vector_algorithms/test.cpp @@ -238,7 +238,7 @@ void test_case_find_first_of(const vector& input_haystack, const vector& i template void test_find_first_of(mt19937_64& gen) { - constexpr size_t needleDataCount = 30; + constexpr size_t needleDataCount = 50; using TD = conditional_t; uniform_int_distribution dis('a', 'z'); vector input_haystack;