Skip to content

Commit

Permalink
Impl ONeill, Canon, Canon-Lemire and Bitmask methods for integer types
Browse files Browse the repository at this point in the history
This is based on @TheIronBorn's work (#1154, #1172), with some changes.
  • Loading branch information
dhardy committed Feb 24, 2022
1 parent 255ff71 commit 00869d7
Showing 1 changed file with 182 additions and 16 deletions.
198 changes: 182 additions & 16 deletions src/distributions/uniform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ pub struct UniformInt<X> {
}

macro_rules! uniform_int_impl {
($ty:ty, $unsigned:ident, $u_large:ident) => {
($ty:ty, $unsigned:ident, $u_large:ident, $u_extra_large:ident) => {
impl SampleUniform for $ty {
type Sampler = UniformInt<$ty>;
}
Expand Down Expand Up @@ -536,9 +536,8 @@ macro_rules! uniform_int_impl {
"UniformSampler::sample_single_inclusive: low > high"
);
let range = high.wrapping_sub(low).wrapping_add(1) as $unsigned as $u_large;
// If the above resulted in wrap-around to 0, the range is $ty::MIN..=$ty::MAX,
// and any integer will do.
if range == 0 {
// Range is MAX+1 (unrepresentable), so we need a special case
return rng.gen();
}

Expand All @@ -564,21 +563,188 @@ macro_rules! uniform_int_impl {
}
}
}

impl UniformInt<$ty> {
/// Samples a single value uniformly from `low..=high` using O'Neill's
/// debiased integer-multiplication method with rejection.
///
/// A random `$u_large` word is multiplied by the size of the range via
/// `wmul`; the first component of the product is the candidate offset from
/// `low` and the second component decides whether the candidate must be
/// rejected and redrawn.
///
/// # Panics
///
/// Panics if `low > high`.
#[inline]
pub fn sample_single_inclusive_oneill<R: Rng + ?Sized, B1, B2>(
    low_b: B1, high_b: B2, rng: &mut R,
) -> $ty
where
    B1: SampleBorrow<$ty> + Sized,
    B2: SampleBorrow<$ty> + Sized,
{
    let (low, high) = (*low_b.borrow(), *high_b.borrow());
    assert!(
        low <= high,
        "UniformSampler::sample_single_inclusive: low > high"
    );

    // Number of values in `low..=high`, widened to the sampling type.
    let range = high.wrapping_sub(low).wrapping_add(1) as $unsigned as $u_large;
    if range == 0 {
        // The count wrapped around to zero, i.e. the range covers every
        // value of the type, so any random value is acceptable.
        return rng.gen();
    }

    // This is the "Debiased Int Mult (t-opt, m-opt)" rejection sampling
    // method described at https://www.pcg-random.org/posts/bounded-rands.html
    // (reference code: https://github.com/imneme/bounded-rands).

    let (mut quotient, mut fraction) = rng.gen::<$u_large>().wmul(range);

    // The rejection threshold is always below `range`, so it only has to be
    // computed when the low word is itself below `range`.
    if fraction < range {
        // Equivalent to `range.wrapping_neg() % range`, but up to two
        // conditional subtractions are tried first so that large ranges
        // can skip the modulo entirely.
        let reject_below = {
            let wrapped = range.wrapping_neg();
            if wrapped < range {
                wrapped
            } else {
                let once = wrapped - range;
                if once < range {
                    once
                } else {
                    once % range
                }
            }
        };

        // Redraw until the low word leaves the rejection zone.
        while fraction < reject_below {
            let redraw = rng.gen::<$u_large>().wmul(range);
            quotient = redraw.0;
            fraction = redraw.1;
        }
    }

    low.wrapping_add(quotient as $ty)
}

/// Samples a single value from `low..=high` using Canon's method.
///
/// One random `$u_extra_large` word is multiplied by the size of the range
/// via `wmul`. If the low part of that product is large enough that further
/// random bits could carry into the high part, exactly one more word is
/// drawn and the carry (if any) is added. No rejection loop is used.
///
/// # Panics
///
/// Panics if `low > high`.
#[inline]
pub fn sample_single_inclusive_canon<R: Rng + ?Sized, B1, B2>(
    low_b: B1, high_b: B2, rng: &mut R,
) -> $ty
where
    B1: SampleBorrow<$ty> + Sized,
    B2: SampleBorrow<$ty> + Sized,
{
    let (low, high) = (*low_b.borrow(), *high_b.borrow());
    assert!(
        low <= high,
        "UniformSampler::sample_single_inclusive: low > high"
    );

    // Number of values in `low..=high`, widened to the sampling type.
    let range = high.wrapping_sub(low).wrapping_add(1) as $unsigned as $u_extra_large;
    if range == 0 {
        // The count wrapped around to zero, i.e. the range covers every
        // value of the type, so any random value is acceptable.
        return rng.gen();
    }

    // First component: candidate offset from `low`.
    // Second component: low-order part of the product.
    let (mut offset, fraction) = rng.gen::<$u_extra_large>().wmul(range);

    // Only when the low-order part exceeds `2^BITS - range` can adding
    // another term below `range` overflow it.
    if fraction > range.wrapping_neg() {
        // Extend the sample by one more `$u_extra_large` word of randomness;
        // only the high part of this second product is needed.
        let (extra_hi, _) = rng.gen::<$u_extra_large>().wmul(range);

        // A carry out of the low-order sum bumps the candidate by one.
        let (_, carried) = fraction.overflowing_add(extra_hi);
        if carried {
            offset += 1;
        }
    }

    low.wrapping_add(offset as $ty)
}

/// Sample a single value from `low..=high`, using Canon's method with
/// Lemire's early-out.
///
/// Structured like `sample_single_inclusive_canon`, except that the test
/// deciding whether a second random word is drawn compares the low part of
/// the product against `range.wrapping_neg() % range`.
///
/// Note that, unlike the sibling samplers, this method's name does not
/// contain `single_`.
///
/// # Panics
///
/// Panics if `low > high`.
#[inline]
pub fn sample_inclusive_canon_lemire<R: Rng + ?Sized, B1, B2>(
low_b: B1, high_b: B2, rng: &mut R,
) -> $ty
where
B1: SampleBorrow<$ty> + Sized,
B2: SampleBorrow<$ty> + Sized,
{
let low = *low_b.borrow();
let high = *high_b.borrow();
assert!(
low <= high,
"UniformSampler::sample_single_inclusive: low > high"
);
// Number of values in `low..=high`, widened to the sampling type.
let range = high.wrapping_sub(low).wrapping_add(1) as $unsigned as $u_extra_large;
if range == 0 {
// The count wrapped to 0: the range covers every value of the type
// (its size is MAX+1, which is unrepresentable), so any value will do.
return rng.gen();
}

// Multiply one random `$u_extra_large` word by `range`: `result` is the
// candidate offset from `low`, `lo_order` the low-order part of the product.
let (mut result, lo_order) = rng.gen::<$u_extra_large>().wmul(range);

// If the sample is considered biased, i.e. `lo_order` is below
// `range.wrapping_neg() % range` (a value that is always < `range`)...
//
// NOTE(review): inside this branch `lo_order < range`, and `new_hi_order`
// is presumably also < `range` (high part of `x * range`), so the
// `checked_add` below can only overflow when `range` exceeds half of the
// `$u_extra_large` value range. For smaller ranges the extra draw would
// never change `result`. `sample_single_inclusive_canon` tests
// `lo_order > range.wrapping_neg()` instead — confirm which condition is
// intended here.
if lo_order < range.wrapping_neg() % range {
// ...generate a second sample (one more `$u_extra_large` word of
// randomness), keeping only the high part of its product with `range`
let (new_hi_order, _) =
(rng.gen::<$u_extra_large>()).wmul(range as $u_extra_large);
// and add 1 to the result if the low-order sum overflows (carries)
result += lo_order
.checked_add(new_hi_order as $u_extra_large)
.is_none() as $u_extra_large;
}

low.wrapping_add(result as $ty)
}

/// Samples a single value uniformly from `low..=high` using the bitmask
/// rejection method.
///
/// Random `$u_large` words are masked down to the smallest all-ones bit
/// pattern covering the largest offset, and redrawn until the masked value
/// is a valid offset from `low`.
///
/// # Panics
///
/// Panics if `low > high`.
#[inline]
pub fn sample_single_inclusive_bitmask<R: Rng + ?Sized, B1, B2>(
    low_b: B1, high_b: B2, rng: &mut R,
) -> $ty
where
    B1: SampleBorrow<$ty> + Sized,
    B2: SampleBorrow<$ty> + Sized,
{
    let (low, high) = (*low_b.borrow(), *high_b.borrow());
    assert!(
        low <= high,
        "UniformSampler::sample_single_inclusive: low > high"
    );

    // Number of values in `low..=high`, widened to the sampling type.
    let span = high.wrapping_sub(low).wrapping_add(1) as $unsigned as $u_large;
    if span == 0 {
        // The count wrapped around to zero, i.e. the range covers every
        // value of the type, so any random value is acceptable.
        return rng.gen();
    }

    // The old implementation used a mix of methods for different integer
    // sizes; only the leading-zeros method is used here for a better
    // comparison.

    // Largest valid offset from `low`.
    let max_offset = span - 1;
    // `| 1` keeps the shift amount below the bit width when `max_offset` is 0.
    let mask = $u_large::max_value() >> (max_offset | 1).leading_zeros();

    loop {
        let candidate = rng.gen::<$u_large>() & mask;
        if candidate <= max_offset {
            break low.wrapping_add(candidate as $ty);
        }
    }
}
}
};
}

uniform_int_impl! { i8, u8, u32 }
uniform_int_impl! { i16, u16, u32 }
uniform_int_impl! { i32, u32, u32 }
uniform_int_impl! { i64, u64, u64 }
uniform_int_impl! { i128, u128, u128 }
uniform_int_impl! { isize, usize, usize }
uniform_int_impl! { u8, u8, u32 }
uniform_int_impl! { u16, u16, u32 }
uniform_int_impl! { u32, u32, u32 }
uniform_int_impl! { u64, u64, u64 }
uniform_int_impl! { usize, usize, usize }
uniform_int_impl! { u128, u128, u128 }
// Instantiate the uniform integer sampling code for each primitive integer type.
// Arguments are `{ $ty, $unsigned, $u_large, $u_extra_large }`:
// - `$ty`: the type being sampled;
// - `$unsigned`: the unsigned type the range size is first cast to;
// - `$u_large`: the word type used by the O'Neill and Bitmask samplers;
// - `$u_extra_large`: the word type used by the Canon and Canon-Lemire samplers.
uniform_int_impl! { i8, u8, u32, u64 }
uniform_int_impl! { i16, u16, u32, u64 }
uniform_int_impl! { i32, u32, u32, u64 }
uniform_int_impl! { i64, u64, u64, u64 }
uniform_int_impl! { i128, u128, u128, u128 }
uniform_int_impl! { u8, u8, u32, u64 }
uniform_int_impl! { u16, u16, u32, u64 }
uniform_int_impl! { u32, u32, u32, u64 }
uniform_int_impl! { u64, u64, u64, u64 }
uniform_int_impl! { u128, u128, u128, u128 }
// `isize`/`usize` on 16- and 32-bit targets: `u64` is used as the extra-large
// word, the same choice made above for the fixed-width types up to 32 bits.
#[cfg(any(target_pointer_width = "16", target_pointer_width = "32",))]
mod isize_int_impls {
use super::*;
uniform_int_impl! { isize, usize, usize, u64 }
uniform_int_impl! { usize, usize, usize, u64 }
}
// All other pointer widths: `usize` itself serves as the extra-large word.
#[cfg(not(any(target_pointer_width = "16", target_pointer_width = "32",)))]
mod isize_int_impls {
use super::*;
uniform_int_impl! { isize, usize, usize, usize }
uniform_int_impl! { usize, usize, usize, usize }
}

#[cfg(feature = "simd_support")]
macro_rules! uniform_simd_int_impl {
Expand Down

0 comments on commit 00869d7

Please sign in to comment.