Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UniformFloat: allow inclusion of high in all cases #1462

Merged
merged 8 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.
- Move all benchmarks to new `benches` crate (#1439)
- Annotate panicking methods with `#[track_caller]` (#1442, #1447)
- Enable feature `small_rng` by default (#1455)
- Allow `UniformFloat::new` samples and `UniformFloat::sample_single` to yield `high` (#1462)

## [0.9.0-alpha.1] - 2024-03-18
- Add the `Slice::num_choices` method to the Slice distribution (#1402)
Expand Down
165 changes: 58 additions & 107 deletions src/distributions/uniform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@
//! Those methods should include an assertion to check the range is valid (i.e.
//! `low < high`). The example below merely wraps another back-end.
//!
//! The `new`, `new_inclusive` and `sample_single` functions use arguments of
//! The `new`, `new_inclusive`, `sample_single` and `sample_single_inclusive`
//! functions use arguments of
//! type `SampleBorrow<X>` to support passing in values by reference or
//! by value. In the implementation of these functions, you can choose to
//! simply use the reference returned by [`SampleBorrow::borrow`], or you can choose
Expand Down Expand Up @@ -207,6 +208,11 @@ impl<X: SampleUniform> Uniform<X> {
/// Create a new `Uniform` instance, which samples uniformly from the half
/// open range `[low, high)` (excluding `high`).
///
/// For discrete types (e.g. integers), samples will always be strictly less
/// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`),
/// samples may equal `high` due to loss of precision but may not be
/// greater than `high`.
///
/// Fails if `low >= high`, or if `low`, `high` or the range `high - low` is
/// non-finite. In release mode, only the range is checked.
pub fn new<B1, B2>(low: B1, high: B2) -> Result<Uniform<X>, Error>
Expand Down Expand Up @@ -265,6 +271,11 @@ pub trait UniformSampler: Sized {

/// Construct self, with inclusive lower bound and exclusive upper bound `[low, high)`.
///
/// For discrete types (e.g. integers), samples will always be strictly less
/// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`),
/// samples may equal `high` due to loss of precision but may not be
/// greater than `high`.
///
/// Usually users should not call this directly but prefer to use
/// [`Uniform::new`].
fn new<B1, B2>(low: B1, high: B2) -> Result<Self, Error>
Expand All @@ -287,6 +298,11 @@ pub trait UniformSampler: Sized {
/// Sample a single value uniformly from a range with inclusive lower bound
/// and exclusive upper bound `[low, high)`.
///
/// For discrete types (e.g. integers), samples will always be strictly less
/// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`),
/// samples may equal `high` due to loss of precision but may not be
/// greater than `high`.
///
/// By default this is implemented using
/// `UniformSampler::new(low, high).sample(rng)`. However, for some types
/// more optimal implementations for single usage may be provided via this
Expand Down Expand Up @@ -908,6 +924,33 @@ pub struct UniformFloat<X> {

macro_rules! uniform_float_impl {
($($meta:meta)?, $ty:ty, $uty:ident, $f_scalar:ident, $u_scalar:ident, $bits_to_discard:expr) => {
$(#[cfg($meta)])?
impl UniformFloat<$ty> {
/// Construct, reducing `scale` as required to ensure that rounding
/// can never yield values greater than `high`.
///
/// Note: though it may be tempting to use a variant of this method
/// to ensure that samples from `[low, high)` are always strictly
/// less than `high`, this approach may be very slow where
/// `scale.abs()` is much smaller than `high.abs()`
/// (example: `low=0.99999999997819644, high=1.`).
fn new_bounded(low: $ty, high: $ty, mut scale: $ty) -> Self {
let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON);

loop {
let mask = (scale * max_rand + low).gt_mask(high);
if !mask.any() {
break;
}
scale = scale.decrease_masked(mask);
}

debug_assert!(<$ty>::splat(0.0).all_le(scale));

UniformFloat { low, scale }
}
}

$(#[cfg($meta)])?
impl SampleUniform for $ty {
type Sampler = UniformFloat<$ty>;
Expand All @@ -931,26 +974,13 @@ macro_rules! uniform_float_impl {
if !(low.all_lt(high)) {
return Err(Error::EmptyRange);
}
let max_rand = <$ty>::splat(
($u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0,
);

let mut scale = high - low;
let scale = high - low;
if !(scale.all_finite()) {
return Err(Error::NonFinite);
}

loop {
let mask = (scale * max_rand + low).ge_mask(high);
if !mask.any() {
break;
}
scale = scale.decrease_masked(mask);
}

debug_assert!(<$ty>::splat(0.0).all_le(scale));

Ok(UniformFloat { low, scale })
Ok(Self::new_bounded(low, high, scale))
}

fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, Error>
Expand All @@ -967,26 +997,14 @@ macro_rules! uniform_float_impl {
if !low.all_le(high) {
return Err(Error::EmptyRange);
}
let max_rand = <$ty>::splat(
($u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0,
);

let mut scale = (high - low) / max_rand;
let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON);
let scale = (high - low) / max_rand;
if !scale.all_finite() {
return Err(Error::NonFinite);
}

loop {
let mask = (scale * max_rand + low).gt_mask(high);
if !mask.any() {
break;
}
scale = scale.decrease_masked(mask);
}

debug_assert!(<$ty>::splat(0.0).all_le(scale));

Ok(UniformFloat { low, scale })
Ok(Self::new_bounded(low, high, scale))
}

fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
Expand All @@ -1010,72 +1028,7 @@ macro_rules! uniform_float_impl {
B1: SampleBorrow<Self::X> + Sized,
B2: SampleBorrow<Self::X> + Sized,
{
let low = *low_b.borrow();
let high = *high_b.borrow();
#[cfg(debug_assertions)]
if !low.all_finite() || !high.all_finite() {
return Err(Error::NonFinite);
}
if !low.all_lt(high) {
return Err(Error::EmptyRange);
}
let mut scale = high - low;
if !scale.all_finite() {
return Err(Error::NonFinite);
}

loop {
// Generate a value in the range [1, 2)
let value1_2 =
(rng.random::<$uty>() >> $uty::splat($bits_to_discard)).into_float_with_exponent(0);

// Get a value in the range [0, 1) to avoid overflow when multiplying by scale
let value0_1 = value1_2 - <$ty>::splat(1.0);

// Doing multiply before addition allows some architectures
// to use a single instruction.
let res = value0_1 * scale + low;

debug_assert!(low.all_le(res) || !scale.all_finite());
if res.all_lt(high) {
return Ok(res);
}

// This handles a number of edge cases.
// * `low` or `high` is NaN. In this case `scale` and
// `res` are going to end up as NaN.
// * `low` is negative infinity and `high` is finite.
// `scale` is going to be infinite and `res` will be
// NaN.
// * `high` is positive infinity and `low` is finite.
// `scale` is going to be infinite and `res` will
// be infinite or NaN (if value0_1 is 0).
// * `low` is negative infinity and `high` is positive
// infinity. `scale` will be infinite and `res` will
// be NaN.
// * `low` and `high` are finite, but `high - low`
// overflows to infinite. `scale` will be infinite
// and `res` will be infinite or NaN (if value0_1 is 0).
// So if `high` or `low` are non-finite, we are guaranteed
// to fail the `res < high` check above and end up here.
//
// While we technically should check for non-finite `low`
// and `high` before entering the loop, by doing the checks
// here instead, we allow the common case to avoid these
// checks. But we are still guaranteed that if `low` or
// `high` are non-finite we'll end up here and can do the
// appropriate checks.
//
// Likewise, `high - low` overflowing to infinity is also
// rare, so handle it here after the common case.
let mask = !scale.finite_mask();
if mask.any() {
if !(low.all_finite() && high.all_finite()) {
return Err(Error::NonFinite);
}
scale = scale.decrease_masked(mask);
}
}
Self::sample_single_inclusive(low_b, high_b, rng)
}

#[inline]
Expand Down Expand Up @@ -1465,14 +1418,14 @@ mod tests {
let my_incl_uniform = Uniform::new_inclusive(low, high).unwrap();
for _ in 0..100 {
let v = rng.sample(my_uniform).extract(lane);
assert!(low_scalar <= v && v < high_scalar);
assert!(low_scalar <= v && v <= high_scalar);
let v = rng.sample(my_incl_uniform).extract(lane);
assert!(low_scalar <= v && v <= high_scalar);
let v =
<$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng)
.unwrap()
.extract(lane);
assert!(low_scalar <= v && v < high_scalar);
assert!(low_scalar <= v && v <= high_scalar);
let v = <$ty as SampleUniform>::Sampler::sample_single_inclusive(
low, high, &mut rng,
)
Expand Down Expand Up @@ -1510,12 +1463,12 @@ mod tests {
low_scalar
);

assert!(max_rng.sample(my_uniform).extract(lane) < high_scalar);
assert!(max_rng.sample(my_uniform).extract(lane) <= high_scalar);
assert!(max_rng.sample(my_incl_uniform).extract(lane) <= high_scalar);
// sample_single cannot cope with max_rng:
// assert!(<$ty as SampleUniform>::Sampler
// ::sample_single(low, high, &mut max_rng).unwrap()
// .extract(lane) < high_scalar);
// .extract(lane) <= high_scalar);
assert!(
<$ty as SampleUniform>::Sampler::sample_single_inclusive(
low,
Expand Down Expand Up @@ -1543,7 +1496,7 @@ mod tests {
)
.unwrap()
.extract(lane)
< high_scalar
<= high_scalar
);
}
}
Expand Down Expand Up @@ -1590,10 +1543,9 @@ mod tests {
#[cfg(all(feature = "std", panic = "unwind"))]
fn test_float_assertions() {
use super::SampleUniform;
use std::panic::catch_unwind;
fn range<T: SampleUniform>(low: T, high: T) {
fn range<T: SampleUniform>(low: T, high: T) -> Result<T, Error> {
let mut rng = crate::test::rng(253);
T::Sampler::sample_single(low, high, &mut rng).unwrap();
T::Sampler::sample_single(low, high, &mut rng)
}

macro_rules! t {
Expand All @@ -1616,10 +1568,9 @@ mod tests {
for lane in 0..<$ty>::LEN {
let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar);
let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar);
assert!(catch_unwind(|| range(low, high)).is_err());
assert!(range(low, high).is_err());
assert!(Uniform::new(low, high).is_err());
assert!(Uniform::new_inclusive(low, high).is_err());
assert!(catch_unwind(|| range(low, low)).is_err());
assert!(Uniform::new(low, low).is_err());
}
}
Expand Down
22 changes: 0 additions & 22 deletions src/distributions/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,7 @@ pub(crate) trait FloatSIMDUtils {
fn all_finite(self) -> bool;

type Mask;
fn finite_mask(self) -> Self::Mask;
fn gt_mask(self, other: Self) -> Self::Mask;
fn ge_mask(self, other: Self) -> Self::Mask;

// Decrease all lanes where the mask is `true` to the next lower value
// representable by the floating-point type. At least one of the lanes
Expand Down Expand Up @@ -292,21 +290,11 @@ macro_rules! scalar_float_impl {
self.is_finite()
}

#[inline(always)]
fn finite_mask(self) -> Self::Mask {
self.is_finite()
}

#[inline(always)]
fn gt_mask(self, other: Self) -> Self::Mask {
self > other
}

#[inline(always)]
fn ge_mask(self, other: Self) -> Self::Mask {
self >= other
}

#[inline(always)]
fn decrease_masked(self, mask: Self::Mask) -> Self {
debug_assert!(mask, "At least one lane must be set");
Expand Down Expand Up @@ -368,21 +356,11 @@ macro_rules! simd_impl {
self.is_finite().all()
}

#[inline(always)]
fn finite_mask(self) -> Self::Mask {
self.is_finite()
}

#[inline(always)]
fn gt_mask(self, other: Self) -> Self::Mask {
self.simd_gt(other)
}

#[inline(always)]
fn ge_mask(self, other: Self) -> Self::Mask {
self.simd_ge(other)
}

#[inline(always)]
fn decrease_masked(self, mask: Self::Mask) -> Self {
// Casting a mask into ints will produce all bits set for
Expand Down