diff --git a/src/distributions/uniform.rs b/src/distributions/uniform.rs
index 7e1bdcb2c81..1029a3f8885 100644
--- a/src/distributions/uniform.rs
+++ b/src/distributions/uniform.rs
@@ -433,7 +433,7 @@ pub struct UniformInt<X> {
 }
 
 macro_rules! uniform_int_impl {
-    ($ty:ty, $unsigned:ident, $u_large:ident) => {
+    ($ty:ty, $unsigned:ident, $u_large:ident, $u_extra_large:ident) => {
         impl SampleUniform for $ty {
             type Sampler = UniformInt<$ty>;
         }
@@ -536,9 +536,8 @@ macro_rules! uniform_int_impl {
                 "UniformSampler::sample_single_inclusive: low > high"
             );
             let range = high.wrapping_sub(low).wrapping_add(1) as $unsigned as $u_large;
-            // If the above resulted in wrap-around to 0, the range is $ty::MIN..=$ty::MAX,
-            // and any integer will do.
             if range == 0 {
+                // Range is MAX+1 (unrepresentable), so we need a special case
                 return rng.gen();
             }
 
@@ -564,21 +563,188 @@ macro_rules! uniform_int_impl {
                 }
             }
         }
+
+        impl UniformInt<$ty> {
+            /// Sample single inclusive, using O'Neill's method
+            #[inline]
+            pub fn sample_single_inclusive_oneill<R: Rng + ?Sized, B1, B2>(
+                low_b: B1, high_b: B2, rng: &mut R,
+            ) -> $ty
+            where
+                B1: SampleBorrow<$ty> + Sized,
+                B2: SampleBorrow<$ty> + Sized,
+            {
+                let low = *low_b.borrow();
+                let high = *high_b.borrow();
+                assert!(
+                    low <= high,
+                    "UniformSampler::sample_single_inclusive: low > high"
+                );
+                let range = high.wrapping_sub(low).wrapping_add(1) as $unsigned as $u_large;
+                if range == 0 {
+                    // Range is MAX+1 (unrepresentable), so we need a special case
+                    return rng.gen();
+                }
+
+                // we use the "Debiased Int Mult (t-opt, m-opt)" rejection sampling method
+                // described here https://www.pcg-random.org/posts/bounded-rands.html
+                // and here https://github.com/imneme/bounded-rands
+
+                let (mut hi, mut lo) = rng.gen::<$u_large>().wmul(range);
+                if lo < range {
+                    let mut threshold = range.wrapping_neg();
+                    // this shortcut works best with large ranges
+                    if threshold >= range {
+                        threshold -= range;
+                        if threshold >= range {
+                            threshold %= range;
+                        }
+                    }
+                    while lo < threshold {
+                        let (new_hi, new_lo) = rng.gen::<$u_large>().wmul(range);
+                        hi = new_hi;
+                        lo = new_lo;
+                    }
+                }
+                low.wrapping_add(hi as $ty)
+            }
+
+            /// Sample single inclusive, using Canon's method
+            #[inline]
+            pub fn sample_single_inclusive_canon<R: Rng + ?Sized, B1, B2>(
+                low_b: B1, high_b: B2, rng: &mut R,
+            ) -> $ty
+            where
+                B1: SampleBorrow<$ty> + Sized,
+                B2: SampleBorrow<$ty> + Sized,
+            {
+                let low = *low_b.borrow();
+                let high = *high_b.borrow();
+                assert!(
+                    low <= high,
+                    "UniformSampler::sample_single_inclusive: low > high"
+                );
+                let range = high.wrapping_sub(low).wrapping_add(1) as $unsigned as $u_extra_large;
+                if range == 0 {
+                    // Range is MAX+1 (unrepresentable), so we need a special case
+                    return rng.gen();
+                }
+
+                // generate a sample using a sensible integer type
+                let (mut result, lo_order) = rng.gen::<$u_extra_large>().wmul(range);
+
+                // if the sample is biased...
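+                // (the high-order word added below is always < range, so the
+                // carry that bumps `result` can occur only when
+                // lo_order > 2^BITS - range, i.e. lo_order > range.wrapping_neg())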
+                if lo_order > range.wrapping_neg() {
+                    // ...generate a new sample with 64 more bits, enough that bias is undetectable
+                    let (new_hi_order, _) =
+                        (rng.gen::<$u_extra_large>()).wmul(range as $u_extra_large);
+                    // and adjust if needed
+                    result += lo_order
+                        .checked_add(new_hi_order as $u_extra_large)
+                        .is_none() as $u_extra_large;
+                }
+
+                low.wrapping_add(result as $ty)
+            }
+
+            /// Sample single inclusive, using Canon's method with Lemire's early-out
+            #[inline]
+            pub fn sample_inclusive_canon_lemire<R: Rng + ?Sized, B1, B2>(
+                low_b: B1, high_b: B2, rng: &mut R,
+            ) -> $ty
+            where
+                B1: SampleBorrow<$ty> + Sized,
+                B2: SampleBorrow<$ty> + Sized,
+            {
+                let low = *low_b.borrow();
+                let high = *high_b.borrow();
+                assert!(
+                    low <= high,
+                    "UniformSampler::sample_single_inclusive: low > high"
+                );
+                let range = high.wrapping_sub(low).wrapping_add(1) as $unsigned as $u_extra_large;
+                if range == 0 {
+                    // Range is MAX+1 (unrepresentable), so we need a special case
+                    return rng.gen();
+                }
+
+                // generate a sample using a sensible integer type
+                let (mut result, lo_order) = rng.gen::<$u_extra_large>().wmul(range);
+
+                // if the sample is biased... (since range won't be changing we can further
+                // improve this check with a modulo)
+                if lo_order < range.wrapping_neg() % range {
+                    // ...generate a new sample with 64 more bits, enough that bias is undetectable
+                    let (new_hi_order, _) =
+                        (rng.gen::<$u_extra_large>()).wmul(range as $u_extra_large);
+                    // and adjust if needed
+                    result += lo_order
+                        .checked_add(new_hi_order as $u_extra_large)
+                        .is_none() as $u_extra_large;
+                }
+
+                low.wrapping_add(result as $ty)
+            }
+
+            /// Sample single inclusive, using the Bitmask method
+            #[inline]
+            pub fn sample_single_inclusive_bitmask<R: Rng + ?Sized, B1, B2>(
+                low_b: B1, high_b: B2, rng: &mut R,
+            ) -> $ty
+            where
+                B1: SampleBorrow<$ty> + Sized,
+                B2: SampleBorrow<$ty> + Sized,
+            {
+                let low = *low_b.borrow();
+                let high = *high_b.borrow();
+                assert!(
+                    low <= high,
+                    "UniformSampler::sample_single_inclusive: low > high"
+                );
+                let mut range = high.wrapping_sub(low).wrapping_add(1) as $unsigned as $u_large;
+                if range == 0 {
+                    // Range is MAX+1 (unrepresentable), so we need a special case
+                    return rng.gen();
+                }
+
+                // the old impl used a mix of methods for different integer sizes; we only
+                // use the leading-zeros method here for a better comparison.
+
+                let mut mask = $u_large::max_value();
+                range -= 1;
+                mask >>= (range | 1).leading_zeros();
+                loop {
+                    let x = rng.gen::<$u_large>() & mask;
+                    if x <= range {
+                        return low.wrapping_add(x as $ty);
+                    }
+                }
+            }
+        }
     };
 }
-
-uniform_int_impl! { i8, u8, u32 }
-uniform_int_impl! { i16, u16, u32 }
-uniform_int_impl! { i32, u32, u32 }
-uniform_int_impl! { i64, u64, u64 }
-uniform_int_impl! { i128, u128, u128 }
-uniform_int_impl! { isize, usize, usize }
-uniform_int_impl! { u8, u8, u32 }
-uniform_int_impl! { u16, u16, u32 }
-uniform_int_impl! { u32, u32, u32 }
-uniform_int_impl! { u64, u64, u64 }
-uniform_int_impl! { usize, usize, usize }
-uniform_int_impl! { u128, u128, u128 }
+uniform_int_impl! { i8, u8, u32, u64 }
+uniform_int_impl! { i16, u16, u32, u64 }
+uniform_int_impl! { i32, u32, u32, u64 }
+uniform_int_impl! { i64, u64, u64, u64 }
+uniform_int_impl! { i128, u128, u128, u128 }
+uniform_int_impl! { u8, u8, u32, u64 }
+uniform_int_impl! { u16, u16, u32, u64 }
+uniform_int_impl! { u32, u32, u32, u64 }
+uniform_int_impl! { u64, u64, u64, u64 }
+uniform_int_impl! { u128, u128, u128, u128 }
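+// usize/isize: the debiasing step above draws 64 extra bits, so $u_extra_large
+// must be at least 64 bits wide; 16- and 32-bit targets therefore widen to u64,
+// while on wider targets usize itself suffices.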
+#[cfg(any(target_pointer_width = "16", target_pointer_width = "32",))]
+mod isize_int_impls {
+    use super::*;
+    uniform_int_impl! { isize, usize, usize, u64 }
+    uniform_int_impl! { usize, usize, usize, u64 }
+}
+#[cfg(not(any(target_pointer_width = "16", target_pointer_width = "32",)))]
+mod isize_int_impls {
+    use super::*;
+    uniform_int_impl! { isize, usize, usize, usize }
+    uniform_int_impl! { usize, usize, usize, usize }
+}
 
 #[cfg(feature = "simd_support")]
 macro_rules! uniform_simd_int_impl {
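For reference, a minimal usage sketch, not part of the patch itself: it assumes the diff above has been applied to a local `rand` checkout, since the four samplers are new inherent methods on `UniformInt`. The `1..=6` bounds are arbitrary; all four variants share one signature, which is what makes them directly comparable in benchmarks.

```rust
use rand::distributions::uniform::UniformInt;

fn main() {
    let mut rng = rand::thread_rng();

    // Each call draws one sample from the inclusive range [1, 6].
    let canon = UniformInt::<u32>::sample_single_inclusive_canon(1u32, 6u32, &mut rng);
    let oneill = UniformInt::<u32>::sample_single_inclusive_oneill(1u32, 6u32, &mut rng);
    let bitmask = UniformInt::<u32>::sample_single_inclusive_bitmask(1u32, 6u32, &mut rng);
    let lemire = UniformInt::<u32>::sample_inclusive_canon_lemire(1u32, 6u32, &mut rng);

    for x in [canon, oneill, bitmask, lemire] {
        assert!((1..=6).contains(&x));
    }
    println!("{} {} {} {}", canon, oneill, bitmask, lemire);
}
```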