Skip to content

Commit

Permalink
Implement ranges::sample and ranges::shuffle (#1052)
Browse files Browse the repository at this point in the history
Co-authored-by: statementreply <[email protected]>
  • Loading branch information
miscco and statementreply authored Jul 27, 2020
1 parent 28efc70 commit 99241dc
Show file tree
Hide file tree
Showing 7 changed files with 350 additions and 39 deletions.
214 changes: 198 additions & 16 deletions stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -4856,20 +4856,20 @@ _SampleIt _Sample_reservoir_unchecked(
// pre: _SampleIt is random-access && 0 < _Count && the range [_Dest, _Dest + _Count) is valid
using _Diff_sample = _Iter_diff_t<_SampleIt>;
const auto _SCount = static_cast<_Diff_sample>(_Count);
_Iter_diff_t<_PopIt> _PopSize{};
for (; _PopSize < _SCount; ++_PopSize, (void) ++_First) {
// _PopSize is less than _SCount, and [_Dest, _Dest + _SCount) is valid,
// so [_Dest, _Dest + _PopSize) must be valid, so narrowing to _Diff_sample
_Iter_diff_t<_PopIt> _Pop_size{};
for (; _Pop_size < _SCount; ++_Pop_size, (void) ++_First) {
// _Pop_size is less than _SCount, and [_Dest, _Dest + _SCount) is valid,
// so [_Dest, _Dest + _Pop_size) must be valid, so narrowing to _Diff_sample
// can't overflow
const auto _Sample_pop = static_cast<_Diff_sample>(_PopSize);
const auto _Sample_pop = static_cast<_Diff_sample>(_Pop_size);
if (_First == _Last) {
return _Dest + _Sample_pop;
}

*(_Dest + _Sample_pop) = *_First;
}
for (; _First != _Last; ++_First) {
const auto _Idx = _RngFunc(++_PopSize);
const auto _Idx = _RngFunc(++_Pop_size);
if (_Idx < _SCount) {
*(_Dest + static_cast<_Diff_sample>(_Idx)) = *_First; // again, valid narrowing because _Idx < _SCount
}
Expand All @@ -4879,12 +4879,12 @@ _SampleIt _Sample_reservoir_unchecked(

template <class _PopIt, class _SampleIt, class _Diff, class _RngFn>
_SampleIt _Sample_selection_unchecked(
_PopIt _First, const _PopIt _Last, _Iter_diff_t<_PopIt> _PopSize, _SampleIt _Dest, _Diff _Count, _RngFn& _RngFunc) {
_PopIt _First, _Iter_diff_t<_PopIt> _Pop_size, _SampleIt _Dest, _Diff _Count, _RngFn& _RngFunc) {
// source is forward *and* we know the source range size: use selection sampling (stable)
// pre: _PopIt is forward && _Count <= _PopSize
// pre: _PopIt is forward && _Count <= _Pop_size
using _CT = common_type_t<_Iter_diff_t<_PopIt>, _Diff>;
for (; _Count > 0 && _First != _Last; ++_First, (void) --_PopSize) {
if (static_cast<_CT>(_RngFunc(_PopSize)) < static_cast<_CT>(_Count)) {
for (; _Pop_size > 0; ++_First, (void) --_Pop_size) {
if (static_cast<_CT>(_RngFunc(_Pop_size)) < static_cast<_CT>(_Count)) {
--_Count;
*_Dest = *_First;
++_Dest;
Expand All @@ -4906,15 +4906,15 @@ template <class _PopIt, class _SampleIt, class _Diff, class _RngFn>
_SampleIt _Sample1(_PopIt _First, _PopIt _Last, _SampleIt _Dest, _Diff _Count, _RngFn& _RngFunc, forward_iterator_tag) {
// source is forward: use selection sampling (stable)
// pre: _Count > 0
using _PopDiff = _Iter_diff_t<_PopIt>;
using _CT = common_type_t<_Diff, _PopDiff>;
const auto _PopSize = _STD distance(_First, _Last);
if (static_cast<_CT>(_Count) > static_cast<_CT>(_PopSize)) {
_Count = static_cast<_Diff>(_PopSize); // narrowing OK because _Count is getting smaller
using _PopDiff = _Iter_diff_t<_PopIt>;
using _CT = common_type_t<_Diff, _PopDiff>;
const auto _Pop_size = _STD distance(_First, _Last);
if (static_cast<_CT>(_Count) > static_cast<_CT>(_Pop_size)) {
_Count = static_cast<_Diff>(_Pop_size); // narrowing OK because _Count is getting smaller
}

_Seek_wrapped(
_Dest, _Sample_selection_unchecked(_First, _Last, _PopSize, _Get_unwrapped_n(_Dest, _Count), _Count, _RngFunc));
_Dest, _Sample_selection_unchecked(_First, _Pop_size, _Get_unwrapped_n(_Dest, _Count), _Count, _RngFunc));
return _Dest;
}

Expand All @@ -4932,6 +4932,128 @@ _SampleIt sample(_PopIt _First, _PopIt _Last, _SampleIt _Dest, _Diff _Count,

return _Dest;
}

#ifdef __cpp_lib_concepts
// STRUCT TEMPLATE _Require_constant
template <auto>
struct _Require_constant; // not defined; _Require_constant<E> is a valid type if E is a constant expression

// CONCEPT uniform_random_bit_generator
// clang-format off
template <class _Ty>
concept uniform_random_bit_generator = invocable<_Ty&> && unsigned_integral<invoke_result_t<_Ty&>> && requires {
{ (_Ty::min)() } -> same_as<invoke_result_t<_Ty&>>;
{ (_Ty::max)() } -> same_as<invoke_result_t<_Ty&>>;
typename _Require_constant<(_Ty::min)()>;
typename _Require_constant<(_Ty::max)()>;
requires (_Ty::min)() < (_Ty::max)();
};
// clang-format on

namespace ranges {
// VARIABLE ranges::sample
class _Sample_fn : private _Not_quite_object {
public:
using _Not_quite_object::_Not_quite_object;

// clang-format off
template <input_iterator _It, sentinel_for<_It> _Se, weakly_incrementable _Out, class _Urng>
requires (forward_iterator<_It> || random_access_iterator<_Out>)
&& indirectly_copyable<_It, _Out> && uniform_random_bit_generator<remove_reference_t<_Urng>>
_Out operator()(_It _First, _Se _Last, _Out _Result, iter_difference_t<_It> _Count, _Urng&& _Func) const {
_Adl_verify_range(_First, _Last);
if (_Count <= 0) {
return _Result;
}

_Rng_from_urng<iter_difference_t<_It>, remove_reference_t<_Urng>> _RngFunc(_Func);
if constexpr (forward_iterator<_It>) {
auto _UFirst = _Get_unwrapped(_STD move(_First));
auto _Pop_size = _RANGES distance(_UFirst, _Get_unwrapped(_STD move(_Last)));
return _Sample_selection_unchecked(_STD move(_UFirst), _Pop_size, _STD move(_Result), _Count, _RngFunc);
} else {
return _Sample_reservoir_unchecked(_Get_unwrapped(_STD move(_First)), _Get_unwrapped(_STD move(_Last)),
_STD move(_Result), _Count, _RngFunc);
}
}

template <input_range _Rng, weakly_incrementable _Out, class _Urng>
requires (forward_range<_Rng> || random_access_iterator<_Out>)
&& indirectly_copyable<iterator_t<_Rng>, _Out>
&& uniform_random_bit_generator<remove_reference_t<_Urng>>
_Out operator()(_Rng&& _Range, _Out _Result, range_difference_t<_Rng> _Count, _Urng&& _Func) const {
if (_Count <= 0) {
return _Result;
}

_Rng_from_urng<range_difference_t<_Rng>, remove_reference_t<_Urng>> _RngFunc(_Func);
if constexpr (forward_range<_Rng>) {
auto _UFirst = _Ubegin(_Range);
auto _Pop_size = _RANGES distance(_UFirst, _Uend(_Range));
return _Sample_selection_unchecked(_STD move(_UFirst), _Pop_size, _STD move(_Result), _Count, _RngFunc);
} else {
return _Sample_reservoir_unchecked(
_Ubegin(_Range), _Uend(_Range), _STD move(_Result), _Count, _RngFunc);
}
}
// clang-format on
private:
template <class _It, class _Out, class _Rng>
_NODISCARD static _Out _Sample_selection_unchecked(
_It _First, iter_difference_t<_It> _Pop_size, _Out _Result, iter_difference_t<_It> _Count, _Rng& _RngFunc) {
// randomly select _Count elements from [_First, _First + _Pop_size) into _Result
_STL_INTERNAL_STATIC_ASSERT(forward_iterator<_It>);
_STL_INTERNAL_STATIC_ASSERT(weakly_incrementable<_Out>);
_STL_INTERNAL_STATIC_ASSERT(indirectly_copyable<_It, _Out>);

if (_Count > _Pop_size) {
_Count = _Pop_size;
}

for (; _Pop_size > 0; ++_First, (void) --_Pop_size) {
if (_RngFunc(_Pop_size) < _Count) {
*_Result = *_First;
++_Result;
if (--_Count == 0) {
break;
}
}
}

return _Result;
}

template <class _It, class _Se, class _Out, class _Rng>
_NODISCARD static _Out _Sample_reservoir_unchecked(
_It _First, const _Se _Last, _Out _Result, const iter_difference_t<_It> _Count, _Rng& _RngFunc) {
// randomly select _Count elements from [_First, _Last) into _Result
_STL_INTERNAL_STATIC_ASSERT(input_iterator<_It>);
_STL_INTERNAL_STATIC_ASSERT(sentinel_for<_Se, _It>);
_STL_INTERNAL_STATIC_ASSERT(random_access_iterator<_Out>);
_STL_INTERNAL_STATIC_ASSERT(indirectly_copyable<_It, _Out>);

iter_difference_t<_It> _Pop_size{};
for (; _Pop_size < _Count; ++_Pop_size, (void) ++_First) {
if (_First == _Last) {
return _Result + _Pop_size;
}

*(_Result + _Pop_size) = *_First;
}
for (; _First != _Last; ++_First) {
const auto _Idx = _RngFunc(++_Pop_size);
if (_Idx < _Count) {
*(_Result + _Idx) = *_First;
}
}

return _Result + _Count;
}
};

inline constexpr _Sample_fn sample{_Not_quite_object::_Construct_tag{}};
} // namespace ranges
#endif // __cpp_lib_concepts
#endif // _HAS_CXX17

// FUNCTION TEMPLATE shuffle WITH URNG
Expand Down Expand Up @@ -4964,6 +5086,66 @@ void shuffle(_RanIt _First, _RanIt _Last, _Urng&& _Func) { // shuffle [_First, _
_Random_shuffle1(_First, _Last, _RngFunc);
}

#ifdef __cpp_lib_concepts
namespace ranges {
// VARIABLE ranges::shuffle
class _Shuffle_fn : private _Not_quite_object {
public:
using _Not_quite_object::_Not_quite_object;

// clang-format off
template <random_access_iterator _It, sentinel_for<_It> _Se, class _Urng>
requires permutable<_It> && uniform_random_bit_generator<remove_reference_t<_Urng>>
_It operator()(_It _First, _Se _Last, _Urng&& _Func) const {
_Adl_verify_range(_First, _Last);

_Rng_from_urng<iter_difference_t<_It>, remove_reference_t<_Urng>> _RngFunc(_Func);
auto _UResult =
_Shuffle_unchecked(_Get_unwrapped(_STD move(_First)), _Get_unwrapped(_STD move(_Last)), _RngFunc);

_Seek_wrapped(_First, _STD move(_UResult));
return _First;
}

template <random_access_range _Rng, class _Urng>
requires permutable<iterator_t<_Rng>> && uniform_random_bit_generator<remove_reference_t<_Urng>>
borrowed_iterator_t<_Rng> operator()(_Rng&& _Range, _Urng&& _Func) const {
_Rng_from_urng<range_difference_t<_Rng>, remove_reference_t<_Urng>> _RngFunc(_Func);

return _Rewrap_iterator(_Range, _Shuffle_unchecked(_Ubegin(_Range), _Uend(_Range), _RngFunc));
}
// clang-format on
private:
template <class _It, class _Se, class _Rng>
_NODISCARD static _It _Shuffle_unchecked(_It _First, const _Se _Last, _Rng& _Func) {
// shuffle [_First, _Last) using random function _Func
_STL_INTERNAL_STATIC_ASSERT(random_access_iterator<_It>);
_STL_INTERNAL_STATIC_ASSERT(sentinel_for<_Se, _It>);
_STL_INTERNAL_STATIC_ASSERT(permutable<_It>);

if (_First == _Last) {
return _First;
}
using _Diff = iter_difference_t<_It>;

auto _Target = _First;
_Diff _Target_index = 1;
for (; ++_Target != _Last; ++_Target_index) {
// randomly place an element from [_First, _Target] at _Target
const _Diff _Off = _Func(_Target_index + 1);
_STL_ASSERT(0 <= _Off && _Off <= _Target_index, "random value out of range");
if (_Off != _Target_index) { // avoid self-move-assignment
_RANGES iter_swap(_Target, _First + _Off);
}
}
return _Target;
}
};

inline constexpr _Shuffle_fn shuffle{_Not_quite_object::_Construct_tag{}};
} // namespace ranges
#endif // __cpp_lib_concepts

#if _HAS_AUTO_PTR_ETC
// FUNCTION TEMPLATE random_shuffle WITH RANDOM FN
template <class _RanIt, class _RngFn>
Expand Down
23 changes: 0 additions & 23 deletions stl/inc/random
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@
#include <vector>
#include <xstring>

#ifdef __cpp_lib_concepts
#include <concepts>
#endif // __cpp_lib_concepts

#pragma pack(push, _CRT_PACKING)
#pragma warning(push, _STL_WARNING_LEVEL)
#pragma warning(disable : _STL_DISABLED_WARNINGS)
Expand Down Expand Up @@ -52,25 +48,6 @@ _STD_BEGIN
"unsigned short, unsigned int, unsigned long, or unsigned long long"); \
_RNG_PROHIBIT_CHAR(_CheckedType)


#ifdef __cpp_lib_concepts
// STRUCT TEMPLATE _Require_constant
template <auto>
struct _Require_constant; // not defined; _Require_constant<E> is a valid type if E is a constant expression

// CONCEPT uniform_random_bit_generator
// clang-format off
template <class _Ty>
concept uniform_random_bit_generator = invocable<_Ty&> && unsigned_integral<invoke_result_t<_Ty&>> && requires {
{ (_Ty::min)() } -> same_as<invoke_result_t<_Ty&>>;
{ (_Ty::max)() } -> same_as<invoke_result_t<_Ty&>>;
typename _Require_constant<(_Ty::min)()>;
typename _Require_constant<(_Ty::max)()>;
requires (_Ty::min)() < (_Ty::max)();
};
// clang-format on
#endif // __cpp_lib_concepts

// ALIAS TEMPLATE _Enable_if_seed_seq_t
template <class _Seed_seq, class _Self, class _Engine = _Self>
using _Enable_if_seed_seq_t = enable_if_t<
Expand Down
2 changes: 2 additions & 0 deletions tests/std/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,10 @@ tests\P0896R4_ranges_alg_replace_copy
tests\P0896R4_ranges_alg_replace_copy_if
tests\P0896R4_ranges_alg_replace_if
tests\P0896R4_ranges_alg_reverse
tests\P0896R4_ranges_alg_sample
tests\P0896R4_ranges_alg_search
tests\P0896R4_ranges_alg_search_n
tests\P0896R4_ranges_alg_shuffle
tests\P0896R4_ranges_alg_swap_ranges
tests\P0896R4_ranges_alg_transform_binary
tests\P0896R4_ranges_alg_transform_unary
Expand Down
4 changes: 4 additions & 0 deletions tests/std/tests/P0896R4_ranges_alg_sample/env.lst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

RUNALL_INCLUDE ..\concepts_matrix.lst
84 changes: 84 additions & 0 deletions tests/std/tests/P0896R4_ranges_alg_sample/test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <algorithm>
#include <cassert>
#include <concepts>
#include <cstdio>
#include <random>
#include <ranges>
#include <utility>

#include <range_algorithm_support.hpp>
using namespace std;

const unsigned int seed = random_device{}();
mt19937 gen{seed};

struct instantiator {
static constexpr int reservoir[5] = {13, 42, 71, 112, 1729};

template <ranges::input_range Read, indirectly_writable<ranges::range_reference_t<Read>> Write>
static void call() {
using ranges::sample, ranges::equal, ranges::is_sorted, ranges::iterator_t;

if constexpr (forward_iterator<iterator_t<Read>> || random_access_iterator<Write>) {
auto copy_gen = gen;

{ // Validate iterator + sentinel overload
int output1[3] = {-1, -1, -1};
int output2[3] = {-1, -1, -1};
Read wrapped_input{reservoir};

auto result1 = sample(wrapped_input.begin(), wrapped_input.end(), Write{output1}, 3, gen);
STATIC_ASSERT(same_as<decltype(result1), Write>);
assert(result1.peek() == end(output1));

// check repeatability
Read wrapped_input2{reservoir};
auto result2 = sample(wrapped_input2.begin(), wrapped_input2.end(), Write{output2}, 3, copy_gen);
assert(equal(output1, output2));
assert(result2.peek() == end(output2));

if (ranges::forward_range<Read>) {
// verify stability
assert(is_sorted(output1));
} else {
// ensure ordering for equality test
sort(begin(output1), end(output1));
}
assert(includes(cbegin(reservoir), cend(reservoir), cbegin(output1), cend(output1)));
}
{ // Validate range overload
int output1[3] = {-1, -1, -1};
int output2[3] = {-1, -1, -1};
Read wrapped_input{reservoir};

auto result1 = sample(wrapped_input, Write{output1}, 3, gen);
STATIC_ASSERT(same_as<decltype(result1), Write>);
assert(result1.peek() == end(output1));

// check repeatability
Read wrapped_input2{reservoir};
auto result2 = sample(wrapped_input2, Write{output2}, 3, copy_gen);
assert(equal(output1, output2));
assert(result2.peek() == end(output2));

if (ranges::forward_range<Read>) {
// verify stability
assert(is_sorted(output1));
} else {
// ensure ordering for equality test
sort(begin(output1), end(output1));
}
assert(includes(cbegin(reservoir), cend(reservoir), cbegin(output1), cend(output1)));
}
}
}
};

int main() {
printf("Using seed: %u\n", seed);

test_in_write<instantiator, const int, int>();
}
Loading

0 comments on commit 99241dc

Please sign in to comment.