Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize std::search of 1 and 2 bytes elements with pcmpestri #4745

Merged
merged 49 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 37 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
73c96da
vectorize search
AlexGuteniev May 5, 2024
0c17a53
very tail fix
AlexGuteniev May 5, 2024
11c05ee
I 🧡 ADL
AlexGuteniev May 5, 2024
d4fcc96
unify ipsum
AlexGuteniev May 5, 2024
da5cf2e
-newline
AlexGuteniev May 5, 2024
da157b1
`strstr` for competition
AlexGuteniev May 5, 2024
772c513
missing progress
AlexGuteniev May 5, 2024
2c6c329
coverage
AlexGuteniev May 5, 2024
81a6000
these tests are too long
AlexGuteniev May 5, 2024
0b59b2e
missing include
AlexGuteniev May 5, 2024
f2806c5
default_searcher
AlexGuteniev May 5, 2024
15e54a9
ADL again
AlexGuteniev May 5, 2024
26646fe
avoid `memcmp` in fallback
AlexGuteniev May 5, 2024
0c473a4
partial review comment
AlexGuteniev Jun 7, 2024
3452fcc
Merge branch 'main' into search
StephanTLavavej Jun 10, 2024
629afd4
Internal static assert `sizeof(_Ty1) == sizeof(_Ty2)`.
StephanTLavavej Jun 10, 2024
a24e6eb
Use `+=` and `+` instead of `_RANGES next`.
StephanTLavavej Jun 10, 2024
9d07a40
Style: Return `_Ptr_res1` instead of `_Ptr_last1` when they're equal.
StephanTLavavej Jun 10, 2024
d57f9b6
Style: In `<algorithm>` and `<functional>`, `_Ptr_last1` doesn't need…
StephanTLavavej Jun 10, 2024
e51b98d
Restore top-level constness for `_UFirst2`.
StephanTLavavej Jun 10, 2024
d4462a5
Benchmark classic search().
StephanTLavavej Jun 10, 2024
95ba820
Simplify `last_known_good_search()`.
StephanTLavavej Jun 10, 2024
72a0d29
Revert vectorized implementation.
StephanTLavavej Jun 10, 2024
38b32d6
Drop `memcmp` paths from `_Equal_rev_pred_unchecked` and `_Equal_rev_…
StephanTLavavej Jun 10, 2024
1e16233
Merge remote-tracking branch 'upstream/main' into search
AlexGuteniev Jun 20, 2024
f269d6c
Revert "Revert vectorized implementation."
AlexGuteniev Jun 20, 2024
dc7eb5b
drop 4 and 8 bytes search optimization for now
AlexGuteniev Jun 24, 2024
0926486
SSE4.2 madness
AlexGuteniev Jun 24, 2024
ba63dbb
better approach
AlexGuteniev Jun 25, 2024
c293748
elegant tail
AlexGuteniev Jun 25, 2024
004d431
big needle benchmark
AlexGuteniev Jun 28, 2024
709ed47
large needle optimization
AlexGuteniev Jun 28, 2024
1c66f01
prevent found data withing overflown part
AlexGuteniev Jun 29, 2024
43e0eec
proper tail length
AlexGuteniev Jun 29, 2024
1420757
better match coverage
AlexGuteniev Jun 29, 2024
fa9d52f
bring back optimization
AlexGuteniev Jun 29, 2024
dfd69e8
i consistent
AlexGuteniev Jun 29, 2024
93cdcf0
Merge branch 'main' into search
StephanTLavavej Sep 4, 2024
96a4d58
Avoid truncation warnings in `_First1 + _Count2`.
StephanTLavavej Sep 5, 2024
2a239a7
Style and comment nitpicks.
StephanTLavavej Sep 5, 2024
3bc1d56
Benchmark: Use a constexpr array of string_view.
StephanTLavavej Sep 5, 2024
c1aaba7
Add const.
StephanTLavavej Sep 5, 2024
6276567
Don't help the compiler - let it deduce `_Ty`.
StephanTLavavej Sep 5, 2024
05e435d
Drop inconsistent `_CSTD`.
StephanTLavavej Sep 5, 2024
e7ec67a
input_needle is guaranteed non-empty here.
StephanTLavavej Sep 5, 2024
9d11dcc
Avoid permanently modifying the haystack.
StephanTLavavej Sep 5, 2024
abae4ed
Bugfix: Use an unaligned load from `_First2`.
StephanTLavavej Sep 5, 2024
e96407b
`_Count2` is more natural than `_Last2`
AlexGuteniev Sep 7, 2024
e0c843d
-hiding
AlexGuteniev Sep 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 29 additions & 14 deletions benchmarks/src/search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,14 @@ const char src_haystack[] =
"euismod eros, ut posuere ligula ullamcorper id. Nullam aliquet malesuada est at dignissim. Pellentesque finibus "
"sagittis libero nec bibendum. Phasellus dolor ipsum, finibus quis turpis quis, mollis interdum felis.";

const char src_needle[] = "aliquet";
const std::vector<std::string> patterns = {
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
"aliquet",
"aliquet malesuada",
};

void c_strstr(benchmark::State& state) {
const auto& src_needle = patterns[static_cast<size_t>(state.range())];

const std::string haystack(std::begin(src_haystack), std::end(src_haystack));
const std::string needle(std::begin(src_needle), std::end(src_needle));

Expand All @@ -56,6 +61,8 @@ void c_strstr(benchmark::State& state) {

template <class T>
void classic_search(benchmark::State& state) {
const auto& src_needle = patterns[static_cast<size_t>(state.range())];

const std::vector<T> haystack(std::begin(src_haystack), std::end(src_haystack));
const std::vector<T> needle(std::begin(src_needle), std::end(src_needle));

Expand All @@ -69,6 +76,8 @@ void classic_search(benchmark::State& state) {

template <class T>
void ranges_search(benchmark::State& state) {
const auto& src_needle = patterns[static_cast<size_t>(state.range())];

const std::vector<T> haystack(std::begin(src_haystack), std::end(src_haystack));
const std::vector<T> needle(std::begin(src_needle), std::end(src_needle));

Expand All @@ -82,6 +91,8 @@ void ranges_search(benchmark::State& state) {

template <class T>
void search_default_searcher(benchmark::State& state) {
const auto& src_needle = patterns[static_cast<size_t>(state.range())];

const std::vector<T> haystack(std::begin(src_haystack), std::end(src_haystack));
const std::vector<T> needle(std::begin(src_needle), std::end(src_needle));

Expand All @@ -93,22 +104,26 @@ void search_default_searcher(benchmark::State& state) {
}
}

BENCHMARK(c_strstr);
void common_args(auto bm) {
bm->Range(0, patterns.size() - 1);
}

BENCHMARK(c_strstr)->Apply(common_args);

BENCHMARK(classic_search<std::uint8_t>);
BENCHMARK(classic_search<std::uint16_t>);
BENCHMARK(classic_search<std::uint32_t>);
BENCHMARK(classic_search<std::uint64_t>);
BENCHMARK(classic_search<std::uint8_t>)->Apply(common_args);
BENCHMARK(classic_search<std::uint16_t>)->Apply(common_args);
BENCHMARK(classic_search<std::uint32_t>)->Apply(common_args);
BENCHMARK(classic_search<std::uint64_t>)->Apply(common_args);

BENCHMARK(ranges_search<std::uint8_t>);
BENCHMARK(ranges_search<std::uint16_t>);
BENCHMARK(ranges_search<std::uint32_t>);
BENCHMARK(ranges_search<std::uint64_t>);
BENCHMARK(ranges_search<std::uint8_t>)->Apply(common_args);
BENCHMARK(ranges_search<std::uint16_t>)->Apply(common_args);
BENCHMARK(ranges_search<std::uint32_t>)->Apply(common_args);
BENCHMARK(ranges_search<std::uint64_t>)->Apply(common_args);

BENCHMARK(search_default_searcher<std::uint8_t>);
BENCHMARK(search_default_searcher<std::uint16_t>);
BENCHMARK(search_default_searcher<std::uint32_t>);
BENCHMARK(search_default_searcher<std::uint64_t>);
BENCHMARK(search_default_searcher<std::uint8_t>)->Apply(common_args);
BENCHMARK(search_default_searcher<std::uint16_t>)->Apply(common_args);
BENCHMARK(search_default_searcher<std::uint32_t>)->Apply(common_args);
BENCHMARK(search_default_searcher<std::uint64_t>)->Apply(common_args);


BENCHMARK_MAIN();
19 changes: 19 additions & 0 deletions stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -2142,6 +2142,25 @@ _NODISCARD _CONSTEXPR20 _FwdItHaystack search(_FwdItHaystack _First1, _FwdItHays
if constexpr (_Is_ranges_random_iter_v<_FwdItHaystack> && _Is_ranges_random_iter_v<_FwdItPat>) {
const _Iter_diff_t<_FwdItPat> _Count2 = _ULast2 - _UFirst2;
if (_ULast1 - _UFirst1 >= _Count2) {
#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Vector_alg_in_search_is_safe<decltype(_UFirst1), decltype(_UFirst2), _Pr>) {
if (!_STD _Is_constant_evaluated()) {
const auto _Ptr1 = _STD _To_address(_UFirst1);

const auto _Ptr_res1 = _STD _Search_vectorized(
_Ptr1, _STD _To_address(_ULast1), _STD _To_address(_UFirst2), _STD _To_address(_ULast2));

if constexpr (is_pointer_v<decltype(_UFirst1)>) {
_UFirst1 = _Ptr_res1;
} else {
_UFirst1 += _Ptr_res1 - _Ptr1;
}

_STD _Seek_wrapped(_Last1, _UFirst1);
return _Last1;
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS
const auto _Last_possible = _ULast1 - static_cast<_Iter_diff_t<_FwdItHaystack>>(_Count2);
for (;; ++_UFirst1) {
if (_STD _Equal_rev_pred_unchecked(_UFirst1, _UFirst2, _ULast2, _STD _Pass_fn(_Pred))) {
Expand Down
23 changes: 23 additions & 0 deletions stl/inc/functional
Original file line number Diff line number Diff line change
Expand Up @@ -2459,6 +2459,29 @@ _CONSTEXPR20 pair<_FwdItHaystack, _FwdItHaystack> _Search_pair_unchecked(
_Iter_diff_t<_FwdItHaystack> _Count1 = _Last1 - _First1;
_Iter_diff_t<_FwdItPat> _Count2 = _Last2 - _First2;

#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Vector_alg_in_search_is_safe<_FwdItHaystack, _FwdItPat, _Pred_eq>) {
if (!_STD _Is_constant_evaluated()) {
const auto _Ptr1 = _STD _To_address(_First1);

const auto _Ptr_res1 = _STD _Search_vectorized(
_Ptr1, _STD _To_address(_Last1), _STD _To_address(_First2), _STD _To_address(_Last2));

if constexpr (is_pointer_v<_FwdItHaystack>) {
_First1 = _Ptr_res1;
} else {
_First1 += _Ptr_res1 - _Ptr1;
}

if (_First1 != _Last1) {
return {_First1, _First1 + _Count2};
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
} else {
return {_Last1, _Last1};
}
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS

for (; _Count2 <= _Count1; ++_First1, (void) --_Count1) { // room for match, try it
_FwdItHaystack _Mid1 = _First1;
for (_FwdItPat _Mid2 = _First2;; ++_Mid1, (void) ++_Mid2) {
Expand Down
62 changes: 62 additions & 0 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ const void* __stdcall __std_find_trivial_2(const void* _First, const void* _Last
const void* __stdcall __std_find_trivial_4(const void* _First, const void* _Last, uint32_t _Val) noexcept;
const void* __stdcall __std_find_trivial_8(const void* _First, const void* _Last, uint64_t _Val) noexcept;

const void* __stdcall __std_search_1(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;
const void* __stdcall __std_search_2(
const void* _First1, const void* _Last1, const void* _First2, const void* _Last2) noexcept;

const void* __stdcall __std_min_element_1(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_min_element_2(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_min_element_4(const void* _First, const void* _Last, bool _Signed) noexcept;
Expand Down Expand Up @@ -195,6 +200,18 @@ _Ty* _Find_vectorized(_Ty* const _First, _Ty* const _Last, const _TVal _Val) noe
}
}

template <class _Ty1, class _Ty2>
_Ty1* _Search_vectorized(_Ty1* const _First1, _Ty1* const _Last1, _Ty2* const _First2, _Ty2* const _Last2) noexcept {
_STL_INTERNAL_STATIC_ASSERT(sizeof(_Ty1) == sizeof(_Ty2));
if constexpr (sizeof(_Ty1) == 1) {
return const_cast<_Ty1*>(static_cast<const _Ty1*>(::__std_search_1(_First1, _Last1, _First2, _Last2)));
} else if constexpr (sizeof(_Ty1) == 2) {
return const_cast<_Ty1*>(static_cast<const _Ty1*>(::__std_search_2(_First1, _Last1, _First2, _Last2)));
} else {
_STL_INTERNAL_STATIC_ASSERT(false); // unexpected size
}
}

template <class _Ty>
_Ty* _Min_element_vectorized(_Ty* const _First, _Ty* const _Last) noexcept {
constexpr bool _Signed = is_signed_v<_Ty>;
Expand Down Expand Up @@ -5358,6 +5375,11 @@ template <class _Iter1, class _Iter2, class _Pr>
constexpr bool _Equal_memcmp_is_safe =
_Equal_memcmp_is_safe_helper<remove_const_t<_Iter1>, remove_const_t<_Iter2>, remove_const_t<_Pr>>;

// Can we activate the vector algorithms for std::search?
template <class _It1, class _It2, class _Pr>
constexpr bool _Vector_alg_in_search_is_safe = _Equal_memcmp_is_safe<_It1, _It2, _Pr> && // can search bitwise
sizeof(_Iter_value_t<_It1>) <= 2; // pcmpestri compatible element size
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved

template <class _CtgIt1, class _CtgIt2>
_NODISCARD int _Memcmp_ranges(_CtgIt1 _First1, _CtgIt1 _Last1, _CtgIt2 _First2) {
_STL_INTERNAL_STATIC_ASSERT(sizeof(_Iter_value_t<_CtgIt1>) == sizeof(_Iter_value_t<_CtgIt2>));
Expand Down Expand Up @@ -6721,6 +6743,46 @@ namespace ranges {
_STL_INTERNAL_CHECK(_RANGES distance(_First1, _Last1) == _Count1);
_STL_INTERNAL_CHECK(_RANGES distance(_First2, _Last2) == _Count2);

#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Vector_alg_in_search_is_safe<_It1, _It2, _Pr> && is_same_v<_Pj1, identity>
&& is_same_v<_Pj2, identity>) {
if (!_STD is_constant_evaluated()) {
const auto _Ptr1 = _STD to_address(_First1);
const auto _Ptr2 = _STD to_address(_First2);
remove_const_t<decltype(_Ptr1)> _Ptr_last1;
remove_const_t<decltype(_Ptr2)> _Ptr_last2;

if constexpr (is_same_v<_It1, _Se1>) {
_Ptr_last1 = _STD to_address(_Last1);
} else {
_Ptr_last1 = _Ptr1 + _Count1;
}

if constexpr (is_same_v<_It2, _Se2>) {
_Ptr_last2 = _STD to_address(_Last2);
} else {
_Ptr_last2 = _Ptr2 + _Count2;
}

const auto _Ptr_res1 = _STD _Search_vectorized(_Ptr1, _Ptr_last1, _Ptr2, _Ptr_last2);

if constexpr (is_pointer_v<_It1>) {
if (_Ptr_res1 != _Ptr_last1) {
return {_Ptr_res1, _Ptr_res1 + _Count2};
} else {
return {_Ptr_res1, _Ptr_res1};
}
} else {
_First1 += _Ptr_res1 - _Ptr1;
if (_First1 != _Last1) {
return {_First1, _First1 + _Count2};
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
} else {
return {_First1, _First1};
}
}
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS
for (; _Count1 >= _Count2; ++_First1, (void) --_Count1) {
auto _Match_and_mid1 = _RANGES _Equal_rev_pred(_First1, _First2, _Last2, _Pred, _Proj1, _Proj2);
if (_Match_and_mid1.first) {
Expand Down
Loading