Skip to content

Commit

Permalink
vectorize find_first_of for long needle (#4557)
Browse files: browse the repository at this point in the history
  • Loading branch information
AlexGuteniev authored Apr 9, 2024
1 parent d872543 commit 5ced5d2
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 28 deletions.
142 changes: 115 additions & 27 deletions stl/src/vector_algorithms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1997,50 +1997,138 @@ namespace {
const void* __stdcall __std_find_first_of_trivial_impl(
const void* _First1, const void* const _Last1, const void* const _First2, const void* const _Last2) noexcept {
#ifndef _M_ARM64EC
const size_t _Needle_length = _Byte_length(_First2, _Last2);

if (_Use_sse42() && _Needle_length <= 16) {
if (_Use_sse42()) {
constexpr int _Op =
(sizeof(_Ty) == 1 ? _SIDD_UBYTE_OPS : _SIDD_UWORD_OPS) | _SIDD_CMP_EQUAL_ANY | _SIDD_LEAST_SIGNIFICANT;
constexpr int _Part_size_el = sizeof(_Ty) == 1 ? 16 : 8;
const size_t _Needle_length = _Byte_length(_First2, _Last2);

if (_Needle_length <= 16) {
// Special handling of small needle
// The generic branch could also handle it but with slightly worse performance

const int _Needle_length_el = static_cast<int>(_Needle_length / sizeof(_Ty));

alignas(16) uint8_t _Tmp1[16];
memcpy(_Tmp1, _First2, _Needle_length);
const __m128i _Needle = _mm_load_si128(reinterpret_cast<const __m128i*>(_Tmp1));

const size_t _Haystack_length = _Byte_length(_First1, _Last1);
const void* _Stop_at = _First1;
_Advance_bytes(_Stop_at, _Haystack_length & ~size_t{0xF});

const int _Needle_length_el = static_cast<int>(_Needle_length / sizeof(_Ty));
while (_First1 != _Stop_at) {
const __m128i _Haystack_part = _mm_loadu_si128(static_cast<const __m128i*>(_First1));
if (_mm_cmpestrc(_Needle, _Needle_length_el, _Haystack_part, _Part_size_el, _Op)) {
const int _Pos = _mm_cmpestri(_Needle, _Needle_length_el, _Haystack_part, _Part_size_el, _Op);
_Advance_bytes(_First1, _Pos * sizeof(_Ty));
return _First1;
}

alignas(16) uint8_t _Tmp1[16];
memcpy(_Tmp1, _First2, _Needle_length);
const __m128i _Needle = _mm_load_si128(reinterpret_cast<const __m128i*>(_Tmp1));
_Advance_bytes(_First1, 16);
}

const size_t _Haystack_length = _Byte_length(_First1, _Last1);
const void* _Stop_at = _First1;
_Advance_bytes(_Stop_at, _Haystack_length & ~size_t{0xF});
const size_t _Last_part_size = _Haystack_length & 0xF;
const int _Last_part_size_el = static_cast<int>(_Last_part_size / sizeof(_Ty));

while (_First1 != _Stop_at) {
const __m128i _Haystack_part = _mm_loadu_si128(static_cast<const __m128i*>(_First1));
alignas(16) uint8_t _Tmp2[16];
memcpy(_Tmp2, _First1, _Last_part_size);
const __m128i _Haystack_part = _mm_load_si128(reinterpret_cast<const __m128i*>(_Tmp2));

if (_mm_cmpestrc(_Needle, _Needle_length_el, _Haystack_part, _Part_size_el, _Op)) {
const int _Pos = _mm_cmpestri(_Needle, _Needle_length_el, _Haystack_part, _Part_size_el, _Op);
if (_mm_cmpestrc(_Needle, _Needle_length_el, _Haystack_part, _Last_part_size_el, _Op)) {
const int _Pos = _mm_cmpestri(_Needle, _Needle_length_el, _Haystack_part, _Last_part_size_el, _Op);
_Advance_bytes(_First1, _Pos * sizeof(_Ty));
return _First1;
}

_Advance_bytes(_First1, 16);
}
_Advance_bytes(_First1, _Last_part_size);
return _First1;
} else {
const void* _Last_needle = _First2;
_Advance_bytes(_Last_needle, _Needle_length & ~size_t{0xF});

const size_t _Last_part_size = _Haystack_length & 0xF;
const int _Last_part_size_el = static_cast<int>(_Last_part_size / sizeof(_Ty));
const int _Last_needle_length = static_cast<int>(_Needle_length & 0xF);

alignas(16) uint8_t _Tmp2[16];
memcpy(_Tmp2, _First1, _Last_part_size);
const __m128i _Haystack_last_part = _mm_load_si128(reinterpret_cast<const __m128i*>(_Tmp2));
alignas(16) uint8_t _Tmp1[16];
memcpy(_Tmp1, _Last_needle, _Last_needle_length);
const __m128i _Last_needle_val = _mm_load_si128(reinterpret_cast<const __m128i*>(_Tmp1));
const int _Last_needle_length_el = _Last_needle_length / sizeof(_Ty);

if (_mm_cmpestrc(_Needle, _Needle_length_el, _Haystack_last_part, _Last_part_size_el, _Op)) {
const int _Pos = _mm_cmpestri(_Needle, _Needle_length_el, _Haystack_last_part, _Last_part_size_el, _Op);
_Advance_bytes(_First1, _Pos * sizeof(_Ty));
constexpr int _Not_found = 16; // arbitrary value greater than any found value

int _Found_pos = _Not_found;

const size_t _Haystack_length = _Byte_length(_First1, _Last1);
const void* _Stop_at = _First1;
_Advance_bytes(_Stop_at, _Haystack_length & ~size_t{0xF});

while (_First1 != _Stop_at) {
const __m128i _Haystack_part = _mm_loadu_si128(static_cast<const __m128i*>(_First1));

for (const void* _Cur_needle = _First2; _Cur_needle != _Last_needle;
_Advance_bytes(_Cur_needle, 16)) {
const __m128i _Needle = _mm_loadu_si128(static_cast<const __m128i*>(_Cur_needle));
if (_mm_cmpestrc(_Needle, _Part_size_el, _Haystack_part, _Part_size_el, _Op)) {
const int _Pos = _mm_cmpestri(_Needle, _Part_size_el, _Haystack_part, _Part_size_el, _Op);
if (_Pos < _Found_pos) {
_Found_pos = _Pos;
}
}
}

if (const int _Needle_length_el = _Last_needle_length_el; _Needle_length_el != 0) {
const __m128i _Needle = _Last_needle_val;
if (_mm_cmpestrc(_Needle, _Needle_length_el, _Haystack_part, _Part_size_el, _Op)) {
const int _Pos =
_mm_cmpestri(_Needle, _Needle_length_el, _Haystack_part, _Part_size_el, _Op);
if (_Pos < _Found_pos) {
_Found_pos = _Pos;
}
}
}

if (_Found_pos != _Not_found) {
_Advance_bytes(_First1, _Found_pos * sizeof(_Ty));
return _First1;
}

_Advance_bytes(_First1, 16);
}

const size_t _Last_part_size = _Haystack_length & 0xF;
const int _Last_part_size_el = static_cast<int>(_Last_part_size / sizeof(_Ty));

alignas(16) uint8_t _Tmp2[16];
memcpy(_Tmp2, _First1, _Last_part_size);
const __m128i _Haystack_part = _mm_load_si128(reinterpret_cast<const __m128i*>(_Tmp2));

_Found_pos = _Last_part_size_el;

for (const void* _Cur_needle = _First2; _Cur_needle != _Last_needle; _Advance_bytes(_Cur_needle, 16)) {
const __m128i _Needle = _mm_loadu_si128(static_cast<const __m128i*>(_Cur_needle));

if (_mm_cmpestrc(_Needle, _Part_size_el, _Haystack_part, _Last_part_size_el, _Op)) {
const int _Pos = _mm_cmpestri(_Needle, _Part_size_el, _Haystack_part, _Last_part_size_el, _Op);
if (_Pos < _Found_pos) {
_Found_pos = _Pos;
}
}
}

if (const int _Needle_length_el = _Last_needle_length_el; _Needle_length_el != 0) {
const __m128i _Needle = _Last_needle_val;
if (_mm_cmpestrc(_Needle, _Needle_length_el, _Haystack_part, _Last_part_size_el, _Op)) {
const int _Pos =
_mm_cmpestri(_Needle, _Needle_length_el, _Haystack_part, _Last_part_size_el, _Op);
if (_Pos < _Found_pos) {
_Found_pos = _Pos;
}
}
}

_Advance_bytes(_First1, _Found_pos * sizeof(_Ty));
return _First1;
}

_Advance_bytes(_First1, _Last_part_size);
return _First1;
}
#endif // !_M_ARM64EC

Expand Down
2 changes: 1 addition & 1 deletion tests/std/tests/VSO_0000000_vector_algorithms/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ void test_case_find_first_of(const vector<T>& input_haystack, const vector<T>& i

template <class T>
void test_find_first_of(mt19937_64& gen) {
constexpr size_t needleDataCount = 30;
constexpr size_t needleDataCount = 50;
using TD = conditional_t<sizeof(T) == 1, int, T>;
uniform_int_distribution<TD> dis('a', 'z');
vector<T> input_haystack;
Expand Down

0 comments on commit 5ced5d2

Please sign in to comment.