Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ranges::sample and ranges::shuffle #1052

Merged
merged 17 commits into from
Jul 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 198 additions & 16 deletions stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -4634,20 +4634,20 @@ _SampleIt _Sample_reservoir_unchecked(
// pre: _SampleIt is random-access && 0 < _Count && the range [_Dest, _Dest + _Count) is valid
using _Diff_sample = _Iter_diff_t<_SampleIt>;
const auto _SCount = static_cast<_Diff_sample>(_Count);
_Iter_diff_t<_PopIt> _PopSize{};
for (; _PopSize < _SCount; ++_PopSize, (void) ++_First) {
// _PopSize is less than _SCount, and [_Dest, _Dest + _SCount) is valid,
// so [_Dest, _Dest + _PopSize) must be valid, so narrowing to _Diff_sample
_Iter_diff_t<_PopIt> _Pop_size{};
for (; _Pop_size < _SCount; ++_Pop_size, (void) ++_First) {
// _Pop_size is less than _SCount, and [_Dest, _Dest + _SCount) is valid,
// so [_Dest, _Dest + _Pop_size) must be valid, so narrowing to _Diff_sample
// can't overflow
const auto _Sample_pop = static_cast<_Diff_sample>(_PopSize);
const auto _Sample_pop = static_cast<_Diff_sample>(_Pop_size);
if (_First == _Last) {
return _Dest + _Sample_pop;
}

*(_Dest + _Sample_pop) = *_First;
}
for (; _First != _Last; ++_First) {
const auto _Idx = _RngFunc(++_PopSize);
const auto _Idx = _RngFunc(++_Pop_size);
if (_Idx < _SCount) {
*(_Dest + static_cast<_Diff_sample>(_Idx)) = *_First; // again, valid narrowing because _Idx < _SCount
}
Expand All @@ -4657,12 +4657,12 @@ _SampleIt _Sample_reservoir_unchecked(

template <class _PopIt, class _SampleIt, class _Diff, class _RngFn>
_SampleIt _Sample_selection_unchecked(
_PopIt _First, const _PopIt _Last, _Iter_diff_t<_PopIt> _PopSize, _SampleIt _Dest, _Diff _Count, _RngFn& _RngFunc) {
_PopIt _First, _Iter_diff_t<_PopIt> _Pop_size, _SampleIt _Dest, _Diff _Count, _RngFn& _RngFunc) {
// source is forward *and* we know the source range size: use selection sampling (stable)
// pre: _PopIt is forward && _Count <= _PopSize
// pre: _PopIt is forward && _Count <= _Pop_size
using _CT = common_type_t<_Iter_diff_t<_PopIt>, _Diff>;
for (; _Count > 0 && _First != _Last; ++_First, (void) --_PopSize) {
if (static_cast<_CT>(_RngFunc(_PopSize)) < static_cast<_CT>(_Count)) {
for (; _Pop_size > 0; ++_First, (void) --_Pop_size) {
if (static_cast<_CT>(_RngFunc(_Pop_size)) < static_cast<_CT>(_Count)) {
--_Count;
*_Dest = *_First;
++_Dest;
Expand All @@ -4684,15 +4684,15 @@ template <class _PopIt, class _SampleIt, class _Diff, class _RngFn>
_SampleIt _Sample1(_PopIt _First, _PopIt _Last, _SampleIt _Dest, _Diff _Count, _RngFn& _RngFunc, forward_iterator_tag) {
// source is forward: use selection sampling (stable)
// pre: _Count > 0
using _PopDiff = _Iter_diff_t<_PopIt>;
using _CT = common_type_t<_Diff, _PopDiff>;
const auto _PopSize = _STD distance(_First, _Last);
if (static_cast<_CT>(_Count) > static_cast<_CT>(_PopSize)) {
_Count = static_cast<_Diff>(_PopSize); // narrowing OK because _Count is getting smaller
using _PopDiff = _Iter_diff_t<_PopIt>;
using _CT = common_type_t<_Diff, _PopDiff>;
const auto _Pop_size = _STD distance(_First, _Last);
if (static_cast<_CT>(_Count) > static_cast<_CT>(_Pop_size)) {
_Count = static_cast<_Diff>(_Pop_size); // narrowing OK because _Count is getting smaller
}

_Seek_wrapped(
_Dest, _Sample_selection_unchecked(_First, _Last, _PopSize, _Get_unwrapped_n(_Dest, _Count), _Count, _RngFunc));
_Dest, _Sample_selection_unchecked(_First, _Pop_size, _Get_unwrapped_n(_Dest, _Count), _Count, _RngFunc));
return _Dest;
}

Expand All @@ -4710,6 +4710,128 @@ _SampleIt sample(_PopIt _First, _PopIt _Last, _SampleIt _Dest, _Diff _Count,

return _Dest;
}

#ifdef __cpp_lib_concepts
// STRUCT TEMPLATE _Require_constant
// Helper for requires-expressions: naming _Require_constant<E> as a type uses E as a non-type template
// argument, which is only well-formed when E is a constant expression. Only the declaration is needed.
template <auto>
struct _Require_constant; // not defined; _Require_constant<E> is a valid type if E is a constant expression

// CONCEPT uniform_random_bit_generator
// Satisfied by a type _Ty when:
//   * an lvalue of _Ty is invocable with no arguments and the result type is an unsigned integer type;
//   * _Ty::min() and _Ty::max() (called without an object) return exactly that same type;
//   * both calls are constant expressions (checked through _Require_constant); and
//   * min() is strictly less than max().
// NOTE(review): min/max are written as (_Ty::min)() -- presumably to keep function-like min/max macros
// from expanding here; confirm against the library's conventions.
// clang-format off
template <class _Ty>
concept uniform_random_bit_generator = invocable<_Ty&> && unsigned_integral<invoke_result_t<_Ty&>> && requires {
// the bounds have the same type as the values the generator produces
{ (_Ty::min)() } -> same_as<invoke_result_t<_Ty&>>;
{ (_Ty::max)() } -> same_as<invoke_result_t<_Ty&>>;
// the bounds are usable in constant expressions
typename _Require_constant<(_Ty::min)()>;
typename _Require_constant<(_Ty::max)()>;
// the generator's range is non-empty
requires (_Ty::min)() < (_Ty::max)();
};
// clang-format on

namespace ranges {
    // VARIABLE ranges::sample
    // Function object implementing ranges::sample: copies up to _Count randomly chosen elements of the
    // source to _Result and returns the end of the written output. Selection sampling (stable) is used
    // when the source is at least forward; otherwise reservoir sampling is used, which needs a
    // random-access output.
    class _Sample_fn : private _Not_quite_object {
    public:
        using _Not_quite_object::_Not_quite_object;

        // clang-format off
        // Iterator + sentinel overload. Returns _Result unchanged when _Count <= 0.
        template <input_iterator _It, sentinel_for<_It> _Se, weakly_incrementable _Out, class _Urng>
            requires (forward_iterator<_It> || random_access_iterator<_Out>)
                && indirectly_copyable<_It, _Out> && uniform_random_bit_generator<remove_reference_t<_Urng>>
        _Out operator()(_It _First, _Se _Last, _Out _Result, iter_difference_t<_It> _Count, _Urng&& _Func) const {
            _Adl_verify_range(_First, _Last);
            if (_Count <= 0) {
                return _Result;
            }

            // _RngFunc(n) is expected to return a value in [0, n)
            _Rng_from_urng<iter_difference_t<_It>, remove_reference_t<_Urng>> _RngFunc(_Func);
            if constexpr (forward_iterator<_It>) {
                // multi-pass source: measure the population first, then do selection sampling
                auto _UFirst   = _Get_unwrapped(_STD move(_First));
                auto _Pop_size = _RANGES distance(_UFirst, _Get_unwrapped(_STD move(_Last)));
                return _Sample_selection_unchecked(_STD move(_UFirst), _Pop_size, _STD move(_Result), _Count, _RngFunc);
            } else {
                // single-pass source: the population size is unknown, so do reservoir sampling
                return _Sample_reservoir_unchecked(_Get_unwrapped(_STD move(_First)), _Get_unwrapped(_STD move(_Last)),
                    _STD move(_Result), _Count, _RngFunc);
            }
        }

        // Range overload. Returns _Result unchanged when _Count <= 0.
        template <input_range _Rng, weakly_incrementable _Out, class _Urng>
            requires (forward_range<_Rng> || random_access_iterator<_Out>)
                && indirectly_copyable<iterator_t<_Rng>, _Out>
                && uniform_random_bit_generator<remove_reference_t<_Urng>>
        _Out operator()(_Rng&& _Range, _Out _Result, range_difference_t<_Rng> _Count, _Urng&& _Func) const {
            if (_Count <= 0) {
                return _Result;
            }

            // _RngFunc(n) is expected to return a value in [0, n)
            _Rng_from_urng<range_difference_t<_Rng>, remove_reference_t<_Urng>> _RngFunc(_Func);
            if constexpr (forward_range<_Rng>) {
                // multi-pass source: measure the population first, then do selection sampling
                auto _UFirst   = _Ubegin(_Range);
                auto _Pop_size = _RANGES distance(_UFirst, _Uend(_Range));
                return _Sample_selection_unchecked(_STD move(_UFirst), _Pop_size, _STD move(_Result), _Count, _RngFunc);
            } else {
                // single-pass source: the population size is unknown, so do reservoir sampling
                return _Sample_reservoir_unchecked(
                    _Ubegin(_Range), _Uend(_Range), _STD move(_Result), _Count, _RngFunc);
            }
        }
        // clang-format on
    private:
        template <class _It, class _Out, class _Rng>
        _NODISCARD static _Out _Sample_selection_unchecked(
            _It _First, iter_difference_t<_It> _Pop_size, _Out _Result, iter_difference_t<_It> _Count, _Rng& _RngFunc) {
            // randomly select _Count elements from [_First, _First + _Pop_size) into _Result
            // elements are visited in source order, so the relative order of the selected elements is kept
            _STL_INTERNAL_STATIC_ASSERT(forward_iterator<_It>);
            _STL_INTERNAL_STATIC_ASSERT(weakly_incrementable<_Out>);
            _STL_INTERNAL_STATIC_ASSERT(indirectly_copyable<_It, _Out>);

            // cannot select more elements than the population holds
            if (_Count > _Pop_size) {
                _Count = _Pop_size;
            }

            // _Pop_size counts the elements not yet visited, _Count the elements still to be selected
            for (; _Pop_size > 0; ++_First, (void) --_Pop_size) {
                if (_RngFunc(_Pop_size) < _Count) {
                    *_Result = *_First;
                    ++_Result;
                    if (--_Count == 0) {
                        break; // sample complete; no need to walk the rest of the population
                    }
                }
            }

            return _Result;
        }

        template <class _It, class _Se, class _Out, class _Rng>
        _NODISCARD static _Out _Sample_reservoir_unchecked(
            _It _First, const _Se _Last, _Out _Result, const iter_difference_t<_It> _Count, _Rng& _RngFunc) {
            // randomly select _Count elements from [_First, _Last) into _Result
            _STL_INTERNAL_STATIC_ASSERT(input_iterator<_It>);
            _STL_INTERNAL_STATIC_ASSERT(sentinel_for<_Se, _It>);
            _STL_INTERNAL_STATIC_ASSERT(random_access_iterator<_Out>);
            _STL_INTERNAL_STATIC_ASSERT(indirectly_copyable<_It, _Out>);

            // The source and output iterators may have different difference types. Offsets applied to
            // _Result are converted explicitly instead of relying on an implicit (possibly narrowing, or
            // for integer-class types nonexistent) conversion.
            using _Out_diff = iter_difference_t<_Out>;

            // fill the reservoir with the first _Count elements (or fewer, if the source runs out)
            iter_difference_t<_It> _Pop_size{};
            for (; _Pop_size < _Count; ++_Pop_size, (void) ++_First) {
                // _Pop_size never exceeds the number of elements written, which must fit in the output
                // range, so narrowing to _Out_diff can't overflow
                const auto _Out_off = static_cast<_Out_diff>(_Pop_size);
                if (_First == _Last) {
                    return _Result + _Out_off;
                }

                *(_Result + _Out_off) = *_First;
            }

            // the source held at least _Count elements, so [_Result, _Result + _Count) is a valid range;
            // each further element replaces a random reservoir slot when its drawn index is below _Count
            for (; _First != _Last; ++_First) {
                const auto _Idx = _RngFunc(++_Pop_size);
                if (_Idx < _Count) {
                    *(_Result + static_cast<_Out_diff>(_Idx)) = *_First; // valid narrowing because _Idx < _Count
                }
            }

            return _Result + static_cast<_Out_diff>(_Count);
        }
    };

    inline constexpr _Sample_fn sample{_Not_quite_object::_Construct_tag{}};
} // namespace ranges
#endif // __cpp_lib_concepts
#endif // _HAS_CXX17

// FUNCTION TEMPLATE shuffle WITH URNG
Expand Down Expand Up @@ -4742,6 +4864,66 @@ void shuffle(_RanIt _First, _RanIt _Last, _Urng&& _Func) { // shuffle [_First, _
_Random_shuffle1(_First, _Last, _RngFunc);
}

#ifdef __cpp_lib_concepts
namespace ranges {
    // VARIABLE ranges::shuffle
    // Function object implementing ranges::shuffle: randomly permutes a random-access range in place,
    // drawing randomness from the supplied generator, and returns an iterator to the end of the range.
    class _Shuffle_fn : private _Not_quite_object {
    public:
        using _Not_quite_object::_Not_quite_object;

        // clang-format off
        // Iterator + sentinel overload.
        template <random_access_iterator _It, sentinel_for<_It> _Se, class _Urng>
            requires permutable<_It> && uniform_random_bit_generator<remove_reference_t<_Urng>>
        _It operator()(_It _First, _Se _Last, _Urng&& _Func) const {
            _Adl_verify_range(_First, _Last);

            _Rng_from_urng<iter_difference_t<_It>, remove_reference_t<_Urng>> _RngFunc(_Func);
            auto _UResult =
                _Shuffle_unchecked(_Get_unwrapped(_STD move(_First)), _Get_unwrapped(_STD move(_Last)), _RngFunc);

            // move the wrapped iterator to the end position computed on the unwrapped iterators
            _Seek_wrapped(_First, _STD move(_UResult));
            return _First;
        }

        // Range overload.
        template <random_access_range _Rng, class _Urng>
            requires permutable<iterator_t<_Rng>> && uniform_random_bit_generator<remove_reference_t<_Urng>>
        borrowed_iterator_t<_Rng> operator()(_Rng&& _Range, _Urng&& _Func) const {
            _Rng_from_urng<range_difference_t<_Rng>, remove_reference_t<_Urng>> _RngFunc(_Func);

            return _Rewrap_iterator(_Range, _Shuffle_unchecked(_Ubegin(_Range), _Uend(_Range), _RngFunc));
        }
        // clang-format on
    private:
        template <class _It, class _Se, class _Rng>
        _NODISCARD static _It _Shuffle_unchecked(_It _First, const _Se _Last, _Rng& _Func) {
            // shuffle [_First, _Last) using random function _Func; returns an iterator equal to _Last
            _STL_INTERNAL_STATIC_ASSERT(random_access_iterator<_It>);
            _STL_INTERNAL_STATIC_ASSERT(sentinel_for<_Se, _It>);
            _STL_INTERNAL_STATIC_ASSERT(permutable<_It>);

            if (_First == _Last) {
                return _First; // nothing to shuffle
            }
            using _Diff = iter_difference_t<_It>;

            // _Target walks from the second element to the end; _Target_index is its offset from _First
            auto _Target        = _First;
            _Diff _Target_index = 1;
            for (; ++_Target != _Last; ++_Target_index) {
                // randomly place an element from [_First, _Target] at _Target
                const _Diff _Off = _Func(_Target_index + 1);
                _STL_ASSERT(0 <= _Off && _Off <= _Target_index, "random value out of range");
                if (_Off != _Target_index) { // avoid self-move-assignment
                    _RANGES iter_swap(_Target, _First + _Off);
                }
            }
            return _Target;
        }
    };

    inline constexpr _Shuffle_fn shuffle{_Not_quite_object::_Construct_tag{}};
} // namespace ranges
#endif // __cpp_lib_concepts

#if _HAS_AUTO_PTR_ETC
// FUNCTION TEMPLATE random_shuffle WITH RANDOM FN
template <class _RanIt, class _RngFn>
Expand Down
23 changes: 0 additions & 23 deletions stl/inc/random
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@
#include <vector>
#include <xstring>

#ifdef __cpp_lib_concepts
#include <concepts>
#endif // __cpp_lib_concepts

#pragma pack(push, _CRT_PACKING)
#pragma warning(push, _STL_WARNING_LEVEL)
#pragma warning(disable : _STL_DISABLED_WARNINGS)
Expand Down Expand Up @@ -52,25 +48,6 @@ _STD_BEGIN
"unsigned short, unsigned int, unsigned long, or unsigned long long"); \
_RNG_PROHIBIT_CHAR(_CheckedType)


#ifdef __cpp_lib_concepts
// STRUCT TEMPLATE _Require_constant
// Helper for requires-expressions: naming _Require_constant<E> as a type uses E as a non-type template
// argument, which is only well-formed when E is a constant expression.
template <auto>
struct _Require_constant; // not defined; _Require_constant<E> is a valid type if E is a constant expression

// CONCEPT uniform_random_bit_generator
// Satisfied by a type _Ty when an lvalue of _Ty is invocable with no arguments yielding an unsigned
// integer type, _Ty::min() and _Ty::max() return that same type, both are constant expressions, and
// min() is strictly less than max().
// clang-format off
template <class _Ty>
concept uniform_random_bit_generator = invocable<_Ty&> && unsigned_integral<invoke_result_t<_Ty&>> && requires {
// the bounds have the same type as the values the generator produces
{ (_Ty::min)() } -> same_as<invoke_result_t<_Ty&>>;
{ (_Ty::max)() } -> same_as<invoke_result_t<_Ty&>>;
// the bounds are usable in constant expressions
typename _Require_constant<(_Ty::min)()>;
typename _Require_constant<(_Ty::max)()>;
// the generator's range is non-empty
requires (_Ty::min)() < (_Ty::max)();
};
// clang-format on
#endif // __cpp_lib_concepts

// ALIAS TEMPLATE _Enable_if_seed_seq_t
template <class _Seed_seq, class _Self, class _Engine = _Self>
using _Enable_if_seed_seq_t = enable_if_t<
Expand Down
2 changes: 2 additions & 0 deletions tests/std/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,10 @@ tests\P0896R4_ranges_alg_replace_copy
tests\P0896R4_ranges_alg_replace_copy_if
tests\P0896R4_ranges_alg_replace_if
tests\P0896R4_ranges_alg_reverse
tests\P0896R4_ranges_alg_sample
tests\P0896R4_ranges_alg_search
tests\P0896R4_ranges_alg_search_n
tests\P0896R4_ranges_alg_shuffle
tests\P0896R4_ranges_alg_swap_ranges
tests\P0896R4_ranges_alg_transform_binary
tests\P0896R4_ranges_alg_transform_unary
Expand Down
4 changes: 4 additions & 0 deletions tests/std/tests/P0896R4_ranges_alg_sample/env.lst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

RUNALL_INCLUDE ..\concepts_matrix.lst
84 changes: 84 additions & 0 deletions tests/std/tests/P0896R4_ranges_alg_sample/test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <algorithm>
#include <cassert>
#include <concepts>
#include <cstdio>
#include <random>
#include <ranges>
#include <utility>

#include <range_algorithm_support.hpp>
using namespace std;

// Seed drawn from random_device once at program start-up; main() prints it.
const unsigned int seed = random_device{}();
// Generator seeded with `seed`; instantiator::call draws from it and from copies of it.
mt19937 gen{seed};

struct instantiator {
static constexpr int reservoir[5] = {13, 42, 71, 112, 1729};

template <ranges::input_range Read, indirectly_writable<ranges::range_reference_t<Read>> Write>
static void call() {
using ranges::sample, ranges::equal, ranges::is_sorted, ranges::iterator_t;

if constexpr (forward_iterator<iterator_t<Read>> || random_access_iterator<Write>) {
auto copy_gen = gen;

{ // Validate iterator + sentinel overload
int output1[3] = {-1, -1, -1};
int output2[3] = {-1, -1, -1};
Read wrapped_input{reservoir};

auto result1 = sample(wrapped_input.begin(), wrapped_input.end(), Write{output1}, 3, gen);
STATIC_ASSERT(same_as<decltype(result1), Write>);
assert(result1.peek() == end(output1));

// check repeatability
Read wrapped_input2{reservoir};
auto result2 = sample(wrapped_input2.begin(), wrapped_input2.end(), Write{output2}, 3, copy_gen);
assert(equal(output1, output2));
assert(result2.peek() == end(output2));

if (ranges::forward_range<Read>) {
// verify stability
assert(is_sorted(output1));
} else {
// ensure ordering for equality test
sort(begin(output1), end(output1));
}
assert(includes(cbegin(reservoir), cend(reservoir), cbegin(output1), cend(output1)));
}
{ // Validate range overload
int output1[3] = {-1, -1, -1};
int output2[3] = {-1, -1, -1};
Read wrapped_input{reservoir};

auto result1 = sample(wrapped_input, Write{output1}, 3, gen);
STATIC_ASSERT(same_as<decltype(result1), Write>);
assert(result1.peek() == end(output1));

// check repeatability
Read wrapped_input2{reservoir};
auto result2 = sample(wrapped_input2, Write{output2}, 3, copy_gen);
assert(equal(output1, output2));
assert(result2.peek() == end(output2));

if (ranges::forward_range<Read>) {
// verify stability
assert(is_sorted(output1));
} else {
// ensure ordering for equality test
sort(begin(output1), end(output1));
}
assert(includes(cbegin(reservoir), cend(reservoir), cbegin(output1), cend(output1)));
}
}
}
};

int main() {
printf("Using seed: %u\n", seed);
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved

test_in_write<instantiator, const int, int>();
}
Loading