Skip to content

Commit

Permalink
Merge pull request #758 from dhardy/sized
Browse files Browse the repository at this point in the history
Distribution::sample_iter changes
  • Loading branch information
dhardy authored Apr 19, 2019
2 parents 852988b + 25aed2d commit 9828cdf
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 27 deletions.
62 changes: 45 additions & 17 deletions src/distributions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,29 +182,35 @@ pub trait Distribution<T> {
/// Create an iterator that generates random values of `T`, using `rng` as
/// the source of randomness.
///
/// Note that this function takes `self` by value. This works since
/// `Distribution<T>` is impl'd for `&D` where `D: Distribution<T>`,
/// however borrowing is not automatic hence `distr.sample_iter(...)` may
/// need to be replaced with `(&distr).sample_iter(...)` to borrow or
/// `(&*distr).sample_iter(...)` to reborrow an existing reference.
///
/// # Example
///
/// ```
/// use rand::thread_rng;
/// use rand::distributions::{Distribution, Alphanumeric, Uniform, Standard};
///
/// let mut rng = thread_rng();
/// let rng = thread_rng();
///
/// // Vec of 16 x f32:
/// let v: Vec<f32> = Standard.sample_iter(&mut rng).take(16).collect();
/// let v: Vec<f32> = Standard.sample_iter(rng).take(16).collect();
///
/// // String:
/// let s: String = Alphanumeric.sample_iter(&mut rng).take(7).collect();
/// let s: String = Alphanumeric.sample_iter(rng).take(7).collect();
///
/// // Dice-rolling:
/// let die_range = Uniform::new_inclusive(1, 6);
/// let mut roll_die = die_range.sample_iter(&mut rng);
/// let mut roll_die = die_range.sample_iter(rng);
/// while roll_die.next().unwrap() != 6 {
/// println!("Not a 6; rolling again!");
/// }
/// ```
fn sample_iter<'a, R>(&'a self, rng: &'a mut R) -> DistIter<'a, Self, R, T>
where Self: Sized, R: Rng
fn sample_iter<R>(self, rng: R) -> DistIter<Self, R, T>
where R: Rng, Self: Sized
{
DistIter {
distr: self,
Expand All @@ -229,20 +235,23 @@ impl<'a, T, D: Distribution<T>> Distribution<T> for &'a D {
///
/// [`sample_iter`]: Distribution::sample_iter
#[derive(Debug)]
pub struct DistIter<'a, D: 'a, R: 'a, T> {
distr: &'a D,
rng: &'a mut R,
pub struct DistIter<D, R, T> {
distr: D,
rng: R,
phantom: ::core::marker::PhantomData<T>,
}

impl<'a, D, R, T> Iterator for DistIter<'a, D, R, T>
where D: Distribution<T>, R: Rng + 'a
impl<D, R, T> Iterator for DistIter<D, R, T>
where D: Distribution<T>, R: Rng
{
type Item = T;

#[inline(always)]
fn next(&mut self) -> Option<T> {
Some(self.distr.sample(self.rng))
// Here, self.rng may be a reference, but we must take &mut anyway.
// Even if sample could take an R: Rng by value, we would need to do this
// since Rng is not copyable and we cannot enforce that this is "reborrowable".
Some(self.distr.sample(&mut self.rng))
}

fn size_hint(&self) -> (usize, Option<usize>) {
Expand All @@ -251,12 +260,12 @@ impl<'a, D, R, T> Iterator for DistIter<'a, D, R, T>
}

#[cfg(rustc_1_26)]
impl<'a, D, R, T> iter::FusedIterator for DistIter<'a, D, R, T>
where D: Distribution<T>, R: Rng + 'a {}
impl<D, R, T> iter::FusedIterator for DistIter<D, R, T>
where D: Distribution<T>, R: Rng {}

#[cfg(features = "nightly")]
impl<'a, D, R, T> iter::TrustedLen for DistIter<'a, D, R, T>
where D: Distribution<T>, R: Rng + 'a {}
impl<D, R, T> iter::TrustedLen for DistIter<D, R, T>
where D: Distribution<T>, R: Rng {}


/// A generic random value distribution, implemented for many primitive types.
Expand Down Expand Up @@ -340,7 +349,8 @@ pub struct Standard;

#[cfg(all(test, feature = "std"))]
mod tests {
use super::Distribution;
use ::Rng;
use super::{Distribution, Uniform};

#[test]
fn test_distributions_iter() {
Expand All @@ -350,4 +360,22 @@ mod tests {
let results: Vec<f32> = distr.sample_iter(&mut rng).take(100).collect();
println!("{:?}", results);
}

#[test]
fn test_make_an_iter() {
fn ten_dice_rolls_other_than_five<'a, R: Rng>(rng: &'a mut R) -> impl Iterator<Item = i32> + 'a {
Uniform::new_inclusive(1, 6)
.sample_iter(rng)
.filter(|x| *x != 5)
.take(10)
}

let mut rng = ::test::rng(211);
let mut count = 0;
for val in ten_dice_rolls_other_than_five(&mut rng) {
assert!(val >= 1 && val <= 6 && val != 5);
count += 1;
}
assert_eq!(count, 10);
}
}
22 changes: 13 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,35 +206,39 @@ pub trait Rng: RngCore {

/// Create an iterator that generates values using the given distribution.
///
/// Note that this function takes its arguments by value. This works since
/// `(&mut R): Rng where R: Rng` and
/// `(&D): Distribution where D: Distribution`,
/// however borrowing is not automatic hence `rng.sample_iter(...)` may
/// need to be replaced with `(&mut rng).sample_iter(...)`.
///
/// # Example
///
/// ```
/// use rand::{thread_rng, Rng};
/// use rand::distributions::{Alphanumeric, Uniform, Standard};
///
/// let mut rng = thread_rng();
/// let rng = thread_rng();
///
/// // Vec of 16 x f32:
/// let v: Vec<f32> = thread_rng().sample_iter(&Standard).take(16).collect();
/// let v: Vec<f32> = rng.sample_iter(Standard).take(16).collect();
///
/// // String:
/// let s: String = rng.sample_iter(&Alphanumeric).take(7).collect();
/// let s: String = rng.sample_iter(Alphanumeric).take(7).collect();
///
/// // Combined values
/// println!("{:?}", thread_rng().sample_iter(&Standard).take(5)
/// println!("{:?}", rng.sample_iter(Standard).take(5)
/// .collect::<Vec<(f64, bool)>>());
///
/// // Dice-rolling:
/// let die_range = Uniform::new_inclusive(1, 6);
/// let mut roll_die = rng.sample_iter(&die_range);
/// let mut roll_die = rng.sample_iter(die_range);
/// while roll_die.next().unwrap() != 6 {
/// println!("Not a 6; rolling again!");
/// }
/// ```
fn sample_iter<'a, T, D: Distribution<T>>(
&'a mut self, distr: &'a D,
) -> distributions::DistIter<'a, D, Self, T>
where Self: Sized {
fn sample_iter<T, D>(self, distr: D) -> distributions::DistIter<D, Self, T>
where D: Distribution<T>, Self: Sized {
distr.sample_iter(self)
}

Expand Down
2 changes: 1 addition & 1 deletion src/rngs/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ const THREAD_RNG_RESEED_THRESHOLD: u64 = 32*1024*1024; // 32 MiB
/// [`ReseedingRng`]: crate::rngs::adapter::ReseedingRng
/// [`StdRng`]: crate::rngs::StdRng
/// [HC-128]: rand_hc::Hc128Rng
#[derive(Clone, Debug)]
#[derive(Copy, Clone, Debug)]
pub struct ThreadRng {
// use of raw pointer implies type is neither Send nor Sync
rng: *mut ReseedingRng<Hc128Core, OsRng>,
Expand Down

0 comments on commit 9828cdf

Please sign in to comment.