Skip to content

Commit

Permalink
Implement Rng.gen_ratio() and Bernoulli::new_ratio()
Browse files Browse the repository at this point in the history
  • Loading branch information
sicking committed Jun 8, 2018
1 parent 276c8be commit a3a9fc3
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 11 deletions.
24 changes: 24 additions & 0 deletions benches/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,30 @@ fn misc_gen_bool_var(b: &mut Bencher) {
})
}

#[bench]
fn misc_gen_ratio_const(b: &mut Bencher) {
let mut rng = StdRng::from_rng(&mut thread_rng()).unwrap();
b.iter(|| {
let mut accum = true;
for _ in 0..::RAND_BENCH_N {
accum ^= rng.gen_ratio(2, 3);
}
accum
})
}

#[bench]
fn misc_gen_ratio_var(b: &mut Bencher) {
let mut rng = StdRng::from_rng(&mut thread_rng()).unwrap();
b.iter(|| {
let mut accum = true;
for i in 2..(::RAND_BENCH_N as u32 + 2) {
accum ^= rng.gen_ratio(i, i + 1);
}
accum
})
}

#[bench]
fn misc_bernoulli_const(b: &mut Bencher) {
let mut rng = StdRng::from_rng(&mut thread_rng()).unwrap();
Expand Down
45 changes: 38 additions & 7 deletions src/distributions/bernoulli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,28 @@ impl Bernoulli {
};
Bernoulli { p_int }
}

/// Construct a new `Bernoulli` with the probability of success of
/// `numerator`-in-`denominator`. I.e. `new_ratio(2, 3)` will return
/// a `Bernoulli` with a 2-in-3 chance, or about 67%, of returning `true`.
///
/// If `numerator == denominator` then the returned `Bernoulli` will always
/// return `true`. If `numerator == 0` it will always return `false`.
///
/// # Panics
///
/// If `denominator == 0` or `numerator > denominator`.
///
#[inline]
pub fn from_ratio(numerator: u32, denominator: u32) -> Bernoulli {
assert!(numerator <= denominator);
if numerator == denominator {
return Bernoulli { p_int: ::core::u64::MAX }
}
const SCALE: f64 = 2.0 * (1u64 << 63) as f64;
let p_int = ((numerator as f64 / denominator as f64) * SCALE) as u64;
Bernoulli { p_int }
}
}

impl Distribution<bool> for Bernoulli {
Expand Down Expand Up @@ -103,18 +125,27 @@ mod test {
#[test]
fn test_average() {
const P: f64 = 0.3;
let d = Bernoulli::new(P);
const N: u32 = 10_000_000;
const NUM: u32 = 3;
const DENOM: u32 = 10;
let d1 = Bernoulli::new(P);
let d2 = Bernoulli::from_ratio(NUM, DENOM);
const N: u32 = 100_000;

let mut sum: u32 = 0;
let mut sum1: u32 = 0;
let mut sum2: u32 = 0;
let mut rng = ::test::rng(2);
for _ in 0..N {
if d.sample(&mut rng) {
sum += 1;
if d1.sample(&mut rng) {
sum1 += 1;
}
if d2.sample(&mut rng) {
sum2 += 1;
}
}
let avg = (sum as f64) / (N as f64);
let avg1 = (sum1 as f64) / (N as f64);
assert!((avg1 - P).abs() < 5e-3);

assert!((avg - P).abs() < 1e-3);
let avg2 = (sum2 as f64) / (N as f64);
assert!((avg2 - (NUM as f64)/(DENOM as f64)).abs() < 5e-3);
}
}
55 changes: 51 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ pub trait Rng: RngCore {
/// ```
///
/// [`Uniform`]: distributions/uniform/struct.Uniform.html
fn gen_range<T: PartialOrd + SampleUniform>(&mut self, low: T, high: T) -> T {
fn gen_range<T: SampleUniform>(&mut self, low: T, high: T) -> T {
T::Sampler::sample_single(low, high, self)
}

Expand Down Expand Up @@ -509,7 +509,8 @@ pub trait Rng: RngCore {

/// Return a bool with a probability `p` of being true.
///
/// This is a wrapper around [`distributions::Bernoulli`].
/// See also the [`Bernoulli`] distribution, which may be faster if
/// sampling from the same probability repeatedly.
///
/// # Example
///
Expand All @@ -522,15 +523,44 @@ pub trait Rng: RngCore {
///
/// # Panics
///
/// If `p` < 0 or `p` > 1.
/// If `p < 0` or `p > 1`.
///
/// [`distributions::Bernoulli`]: distributions/bernoulli/struct.Bernoulli.html
/// [`Bernoulli`]: distributions/bernoulli/struct.Bernoulli.html
#[inline]
fn gen_bool(&mut self, p: f64) -> bool {
let d = distributions::Bernoulli::new(p);
self.sample(d)
}

/// Return a bool with a probability of `numerator/denominator` of being
/// true. I.e. `gen_ratio(2, 3)` has chance of 2 in 3, or about 67%, of
/// returning true. If `numerator == denominator`, then the returned value
/// is guaranteed to be `true`. If `numerator == 0`, then the returned
/// value is guaranteed to be `false`.
///
/// See also the [`Bernoulli`] distribution, which may be faster if
/// sampling from the same `numerator` and `denominator` repeatedly.
///
/// # Panics
///
/// If `denominator == 0` or `numerator > denominator`.
///
/// # Example
///
/// ```
/// use rand::{thread_rng, Rng};
///
/// let mut rng = thread_rng();
/// println!("{}", rng.gen_ratio(2, 3));
/// ```
///
/// [`Bernoulli`]: distributions/bernoulli/struct.Bernoulli.html
#[inline]
fn gen_ratio(&mut self, numerator: u32, denominator: u32) -> bool {
let d = distributions::Bernoulli::from_ratio(numerator, denominator);
self.sample(d)
}

/// Return a random element from `values`.
///
/// Return `None` if `values` is empty.
Expand Down Expand Up @@ -1017,4 +1047,21 @@ mod test {
(u8, i8, u16, i16, u32, i32, u64, i64),
(f32, (f64, (f64,)))) = random();
}

#[test]
fn test_gen_ratio_average() {
const NUM: u32 = 3;
const DENOM: u32 = 10;
const N: u32 = 100_000;

let mut sum: u32 = 0;
let mut rng = rng(111);
for _ in 0..N {
if rng.gen_ratio(NUM, DENOM) {
sum += 1;
}
}
let avg = (sum as f64) / (N as f64);
assert!((avg - (NUM as f64)/(DENOM as f64)).abs() < 1e-3);
}
}

0 comments on commit a3a9fc3

Please sign in to comment.