Skip to content

Commit

Permalink
let's have sized functions as usual
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexGuteniev committed Mar 21, 2024
1 parent f45dbbd commit b0d6ece
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 89 deletions.
34 changes: 18 additions & 16 deletions stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -644,16 +644,17 @@ _NODISCARD _CONSTEXPR20 pair<_InIt1, _InIt2> mismatch(_InIt1 _First1, const _InI
#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Equal_memcmp_is_safe<decltype(_UFirst1), decltype(_UFirst2), _Pr>) {
if (!_STD _Is_constant_evaluated()) {
using _Ty = _Iter_value_t<_InIt1>;
const auto _First1_ptr = _STD _To_address(_UFirst1);
const auto _First2_ptr = _STD _To_address(_UFirst2);
const auto _Count_el = static_cast<size_t>(_ULast1 - _UFirst1);
constexpr size_t _Elem_size = sizeof(_Iter_value_t<_InIt1>);

const auto _Skip = static_cast<decltype(_ULast1 - _UFirst1)>(
::__std_mismatch_byte_helper(_First1_ptr, _First2_ptr, _Count_el * sizeof(_Ty)) / sizeof(_Ty));
const auto _Pos = static_cast<decltype(_ULast1 - _UFirst1)>(_STD __std_mismatch<_Elem_size>(
_STD _To_address(_UFirst1), _STD _To_address(_UFirst2), static_cast<size_t>(_ULast1 - _UFirst1)));

_UFirst1 += _Skip;
_UFirst2 += _Skip;
_UFirst1 += _Pos;
_UFirst2 += _Pos;

_STD _Seek_wrapped(_First2, _UFirst2);
_STD _Seek_wrapped(_First1, _UFirst1);
return {_First1, _First2};
}
}
#endif // ^^^ !_USE_STD_VECTOR_ALGORITHMS ^^^
Expand Down Expand Up @@ -707,16 +708,17 @@ _NODISCARD _CONSTEXPR20 pair<_InIt1, _InIt2> mismatch(
#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Equal_memcmp_is_safe<decltype(_UFirst1), decltype(_UFirst2), _Pr>) {
if (!_STD _Is_constant_evaluated()) {
using _Ty = _Iter_value_t<_InIt1>;
const auto _First1_ptr = _STD _To_address(_UFirst1);
const auto _First2_ptr = _STD _To_address(_UFirst2);
const auto _Count_el = static_cast<size_t>(_Count);
constexpr size_t _Elem_size = sizeof(_Iter_value_t<_InIt1>);

const auto _Pos = static_cast<decltype(_Count)>(_STD __std_mismatch<_Elem_size>(
_STD _To_address(_UFirst1), _STD _To_address(_UFirst2), static_cast<size_t>(_Count)));

const auto _Skip = static_cast<decltype(_Count)>(
::__std_mismatch_byte_helper(_First1_ptr, _First2_ptr, _Count_el * sizeof(_Ty)) / sizeof(_Ty));
_UFirst1 += _Pos;
_UFirst2 += _Pos;

_UFirst1 += _Skip;
_UFirst2 += _Skip;
_STD _Seek_wrapped(_First2, _UFirst2);
_STD _Seek_wrapped(_First1, _UFirst1);
return {_First1, _First2};
}
}
#endif // ^^^ !_USE_STD_VECTOR_ALGORITHMS ^^^
Expand Down
38 changes: 23 additions & 15 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,10 @@ __declspec(noalias) uint64_t __stdcall __std_max_8u(const void* _First, const vo
__declspec(noalias) float __stdcall __std_max_f(const void* _First, const void* _Last) noexcept;
__declspec(noalias) double __stdcall __std_max_d(const void* _First, const void* _Last) noexcept;

// Returns the position to which 'mismatch' can fast forward.
// This position can be the first mismatched byte, or an earlier position.
// The purpose is to handle only the portion that can be vectorized in a more efficient way,
// than element wise comparison, without element size knowledge.
__declspec(noalias) size_t
__stdcall __std_mismatch_byte_helper(const void* _First1, const void* _First2, size_t _Count_bytes);
__declspec(noalias) size_t __stdcall __std_mismatch_1(const void* _First1, const void* _First2, size_t _Count) noexcept;
__declspec(noalias) size_t __stdcall __std_mismatch_2(const void* _First1, const void* _First2, size_t _Count) noexcept;
__declspec(noalias) size_t __stdcall __std_mismatch_4(const void* _First1, const void* _First2, size_t _Count) noexcept;
__declspec(noalias) size_t __stdcall __std_mismatch_8(const void* _First1, const void* _First2, size_t _Count) noexcept;
} // extern "C"

_STD_BEGIN
Expand Down Expand Up @@ -299,6 +297,21 @@ auto __std_max(_Ty* const _First, _Ty* const _Last) noexcept {
static_assert(_Always_false<_Ty>, "Unexpected size");
}
}

template <size_t _Element_size>
size_t __std_mismatch(const void* const _First1, const void* const _First2, const size_t _Count) noexcept {
if constexpr (_Element_size == 1) {
return __std_mismatch_1(_First1, _First2, _Count);
} else if constexpr (_Element_size == 2) {
return __std_mismatch_2(_First1, _First2, _Count);
} else if constexpr (_Element_size == 4) {
return __std_mismatch_4(_First1, _First2, _Count);
} else if constexpr (_Element_size == 8) {
return __std_mismatch_8(_First1, _First2, _Count);
} else {
static_assert(_Always_false<integral_constant<size_t, _Element_size>>, "Unexpected size");
}
}
_STD_END

#endif // _USE_STD_VECTOR_ALGORITHMS
Expand Down Expand Up @@ -5488,17 +5501,12 @@ namespace ranges {
if constexpr (_Equal_memcmp_is_safe<_It1, _It2, _Pr> && is_same_v<_Pj1, identity>
&& is_same_v<_Pj2, identity>) {
if (!_STD is_constant_evaluated()) {
using _Ty = iter_value_t<_It1>;
const auto _First1_ptr = _STD _To_address(_First1);
const auto _First2_ptr = _STD _To_address(_First2);
const auto _Count_el = static_cast<size_t>(_Count);
constexpr size_t _Elem_size = sizeof(iter_value_t<_It1>);

const auto _Skip = static_cast<decltype(_Count)>(
::__std_mismatch_byte_helper(_First1_ptr, _First2_ptr, _Count_el * sizeof(_Ty)) / sizeof(_Ty));
const auto _Pos = static_cast<decltype(_Count)>(_STD __std_mismatch<_Elem_size>(
_STD _To_address(_First1), _STD _To_address(_First2), static_cast<size_t>(_Count)));

_First1 += _Skip;
_First2 += _Skip;
_Count -= _Skip;
return {_STD move(_First1 + _Pos), _STD move(_First2 + _Pos)};
}
}
#endif // ^^^ !_USE_STD_VECTOR_ALGORITHMS ^^^
Expand Down
150 changes: 92 additions & 58 deletions stl/src/vector_algorithms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2027,6 +2027,78 @@ namespace {
}
return _Result;
}

template <class _Traits, class _Ty>
__declspec(noalias) size_t
__stdcall _Mismatch(const void* const _First1, const void* const _First2, const size_t _Count) noexcept {
size_t _Result = 0;
#ifndef _M_ARM64EC
const auto _First1_ch = static_cast<const char*>(_First1);
const auto _First2_ch = static_cast<const char*>(_First2);

if (_Use_avx2()) {
const size_t _Count_bytes = _Count * sizeof(_Ty);
const size_t _Count_bytes_avx_full = _Count_bytes & ~size_t{0x1F};

for (; _Result != _Count_bytes_avx_full; _Result += 0x20) {
const __m256i _Elem1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(_First1_ch + _Result));
const __m256i _Elem2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(_First2_ch + _Result));
const auto _Bingo = ~static_cast<unsigned int>(_mm256_movemask_epi8(_Traits::_Cmp_avx(_Elem1, _Elem2)));
if (_Bingo != 0) {
return (_Result + _tzcnt_u32(_Bingo)) / sizeof(_Ty);
}
}

size_t _Count_tail = _Count_bytes & size_t{0x1C};

if (_Count_tail != 0) {
const __m256i _Tail_mask = _Avx2_tail_mask_32(_Count_tail >> 2);
const __m256i _Elem1 =
_mm256_maskload_epi32(reinterpret_cast<const int*>(_First1_ch + _Result), _Tail_mask);
const __m256i _Elem2 =
_mm256_maskload_epi32(reinterpret_cast<const int*>(_First2_ch + _Result), _Tail_mask);

const auto _Bingo = ~static_cast<unsigned int>(_mm256_movemask_epi8(_Traits::_Cmp_avx(_Elem1, _Elem2)));
if (_Bingo != 0) {
return (_Result + _tzcnt_u32(_Bingo)) / sizeof(_Ty);
}

_Result += _Count_tail;
}

_Result /= sizeof(_Ty);

if constexpr (sizeof(_Ty) >= 4) {
return _Result;
}
} else if (_Use_sse2()) {
const size_t _Count_bytes_sse = (_Count * sizeof(_Ty)) & ~size_t{0xF};

for (; _Result != _Count_bytes_sse; _Result += 0x10) {
const __m128i _Elem1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(_First1_ch + _Result));
const __m128i _Elem2 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(_First2_ch + _Result));
const auto _Bingo = static_cast<unsigned int>(_mm_movemask_epi8(_Traits::_Cmp_sse(_Elem1, _Elem2))) ^ 0xFFFF;
if (_Bingo != 0) {
unsigned long _Offset;
_BitScanForward(&_Offset, _Bingo); // lgtm [cpp/conditionallyuninitializedvariable]
return (_Result + _Offset) / sizeof(_Ty);
}
}

_Result /= sizeof(_Ty);
}
#endif // !defined(_M_ARM64EC)
const auto _First1_el = static_cast<const _Ty*>(_First1);
const auto _First2_el = static_cast<const _Ty*>(_First2);

for (; _Result != _Count; ++_Result) {
if (_First1_el[_Result] != _First2_el[_Result]) {
break;
}
}

return _Result;
}
} // unnamed namespace

extern "C" {
Expand Down Expand Up @@ -2112,6 +2184,26 @@ __declspec(noalias) size_t
return __std_count_trivial_impl<_Find_traits_8>(_First, _Last, _Val);
}

__declspec(noalias) size_t
__stdcall __std_mismatch_1(const void* const _First1, const void* const _First2, const size_t _Count) noexcept {
return _Mismatch<_Find_traits_1, uint8_t>(_First1, _First2, _Count);
}

__declspec(noalias) size_t
__stdcall __std_mismatch_2(const void* const _First1, const void* const _First2, const size_t _Count) noexcept {
return _Mismatch<_Find_traits_2, uint16_t>(_First1, _First2, _Count);
}

__declspec(noalias) size_t
__stdcall __std_mismatch_4(const void* const _First1, const void* const _First2, const size_t _Count) noexcept {
return _Mismatch<_Find_traits_4, uint32_t>(_First1, _First2, _Count);
}

__declspec(noalias) size_t
__stdcall __std_mismatch_8(const void* const _First1, const void* const _First2, const size_t _Count) noexcept {
return _Mismatch<_Find_traits_8, uint64_t>(_First1, _First2, _Count);
}

} // extern "C"

#ifndef _M_ARM64EC
Expand Down Expand Up @@ -2317,63 +2409,5 @@ __declspec(noalias) void __stdcall __std_bitset_to_string_2(
}
}

__declspec(noalias) size_t __stdcall __std_mismatch_byte_helper(
const void* const _First1, const void* const _First2, const size_t _Count_bytes) {
#ifndef _M_ARM64EC
const auto _First1_ch = static_cast<const char*>(_First1);
const auto _First2_ch = static_cast<const char*>(_First2);

if (_Use_avx2()) {
const size_t _Count_bytes_avx_full = _Count_bytes & ~size_t{0x1F};

size_t _Result = 0;
for (; _Result != _Count_bytes_avx_full; _Result += 0x20) {
const __m256i _Elem1 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(_First1_ch + _Result));
const __m256i _Elem2 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(_First2_ch + _Result));
const auto _Bingo = ~static_cast<unsigned int>(_mm256_movemask_epi8(_mm256_cmpeq_epi8(_Elem1, _Elem2)));
if (_Bingo != 0) {
return _Result + _tzcnt_u32(_Bingo);
}
}

size_t _Count_tail = _Count_bytes & size_t{0x1C};

if (_Count_tail == 0) {
return _Result;
}

const __m256i _Tail_mask = _Avx2_tail_mask_32(_Count_tail >> 2);
const __m256i _Elem1 = _mm256_maskload_epi32(reinterpret_cast<const int*>(_First1_ch + _Result), _Tail_mask);
const __m256i _Elem2 = _mm256_maskload_epi32(reinterpret_cast<const int*>(_First2_ch + _Result), _Tail_mask);

const auto _Bingo = ~static_cast<unsigned int>(_mm256_movemask_epi8(_mm256_cmpeq_epi8(_Elem1, _Elem2)));
if (_Bingo != 0) {
return _Result + _tzcnt_u32(_Bingo);
}

return _Result + _Count_tail;
} else if (_Use_sse2()) {
const size_t _Count_bytes_sse = _Count_bytes & ~size_t{0xF};

size_t _Result = 0;
for (; _Result != _Count_bytes_sse; _Result += 0x10) {
const __m128i _Elem1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(_First1_ch + _Result));
const __m128i _Elem2 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(_First2_ch + _Result));
const auto _Bingo = ~static_cast<unsigned short>(_mm_movemask_epi8(_mm_cmpeq_epi8(_Elem1, _Elem2)));
if (_Bingo != 0) {
unsigned long _Offset;
_BitScanForward(&_Offset, _Bingo); // lgtm [cpp/conditionallyuninitializedvariable]
return _Result + _Offset;
}
}

return _Result;
} else
#endif // !defined(_M_ARM64EC)
{
return 0;
}
}

} // extern "C"
#endif // defined(_M_IX86) || defined(_M_X64)

0 comments on commit b0d6ece

Please sign in to comment.