Skip to content

Commit

Permalink
Merge pull request #523 from pitdicker/simd_support_basic
Browse files Browse the repository at this point in the history
Add basic SIMD support
  • Loading branch information
dhardy authored Jun 29, 2018
2 parents 3af227a + 5c948fe commit 950c0af
Show file tree
Hide file tree
Showing 7 changed files with 344 additions and 149 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ appveyor = { repository = "alexcrichton/rand" }

[features]
default = ["std" ] # without "std" rand uses libcore
nightly = ["i128_support"] # enables all features requiring nightly rust
nightly = ["i128_support", "simd_support"] # enables all features requiring nightly rust
std = ["rand_core/std", "alloc", "libc", "winapi", "cloudabi", "fuchsia-zircon"]
alloc = ["rand_core/alloc"] # enables Vec and Box support (without std)
i128_support = [] # enables i128 and u128 support
simd_support = [] # enables SIMD support
serde1 = ["serde", "serde_derive", "rand_core/serde1"] # enables serialization for PRNGs

[workspace]
Expand Down
167 changes: 110 additions & 57 deletions src/distributions/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
use core::mem;
use Rng;
use distributions::{Distribution, Standard};
use distributions::utils::CastFromInt;
#[cfg(feature="simd_support")]
use core::simd::*;

/// A distribution to sample floating point numbers uniformly in the half-open
/// interval `(0, 1]`, i.e. including 1 but not 0.
Expand Down Expand Up @@ -83,15 +86,16 @@ pub(crate) trait IntoFloat {
}

macro_rules! float_impls {
($ty:ty, $uty:ty, $fraction_bits:expr, $exponent_bias:expr) => {
($ty:ident, $uty:ident, $f_scalar:ident, $u_scalar:ty,
$fraction_bits:expr, $exponent_bias:expr) => {
impl IntoFloat for $uty {
type F = $ty;
#[inline(always)]
fn into_float_with_exponent(self, exponent: i32) -> $ty {
// The exponent is encoded using an offset-binary representation
let exponent_bits =
(($exponent_bias + exponent) as $uty) << $fraction_bits;
unsafe { mem::transmute(self | exponent_bits) }
let exponent_bits: $u_scalar =
(($exponent_bias + exponent) as $u_scalar) << $fraction_bits;
$ty::from_bits(self | exponent_bits)
}
}

Expand All @@ -100,12 +104,13 @@ macro_rules! float_impls {
// Multiply-based method; 24/53 random bits; [0, 1) interval.
// We use the most significant bits because for simple RNGs
// those are usually more random.
let float_size = mem::size_of::<$ty>() * 8;
let float_size = mem::size_of::<$f_scalar>() * 8;
let precision = $fraction_bits + 1;
let scale = 1.0 / ((1 as $uty << precision) as $ty);
let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar);

let value: $uty = rng.gen();
scale * (value >> (float_size - precision)) as $ty
let value = value >> (float_size - precision);
scale * $ty::cast_from_int(value)
}
}

Expand All @@ -114,14 +119,14 @@ macro_rules! float_impls {
// Multiply-based method; 24/53 random bits; (0, 1] interval.
// We use the most significant bits because for simple RNGs
// those are usually more random.
let float_size = mem::size_of::<$ty>() * 8;
let float_size = mem::size_of::<$f_scalar>() * 8;
let precision = $fraction_bits + 1;
let scale = 1.0 / ((1 as $uty << precision) as $ty);
let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar);

let value: $uty = rng.gen();
let value = value >> (float_size - precision);
// Add 1 to shift up; will not overflow because of right-shift:
scale * (value + 1) as $ty
scale * $ty::cast_from_int(value + 1)
}
}

Expand All @@ -130,8 +135,8 @@ macro_rules! float_impls {
// Transmute-based method; 23/52 random bits; (0, 1) interval.
// We use the most significant bits because for simple RNGs
// those are usually more random.
const EPSILON: $ty = 1.0 / (1u64 << $fraction_bits) as $ty;
let float_size = mem::size_of::<$ty>() * 8;
use core::$f_scalar::EPSILON;
let float_size = mem::size_of::<$f_scalar>() * 8;

let value: $uty = rng.gen();
let fraction = value >> (float_size - $fraction_bits);
Expand All @@ -140,67 +145,115 @@ macro_rules! float_impls {
}
}
}
float_impls! { f32, u32, 23, 127 }
float_impls! { f64, u64, 52, 1023 }

float_impls! { f32, u32, f32, u32, 23, 127 }
float_impls! { f64, u64, f64, u64, 52, 1023 }

#[cfg(feature="simd_support")]
float_impls! { f32x2, u32x2, f32, u32, 23, 127 }
#[cfg(feature="simd_support")]
float_impls! { f32x4, u32x4, f32, u32, 23, 127 }
#[cfg(feature="simd_support")]
float_impls! { f32x8, u32x8, f32, u32, 23, 127 }
#[cfg(feature="simd_support")]
float_impls! { f32x16, u32x16, f32, u32, 23, 127 }

#[cfg(feature="simd_support")]
float_impls! { f64x2, u64x2, f64, u64, 52, 1023 }
#[cfg(feature="simd_support")]
float_impls! { f64x4, u64x4, f64, u64, 52, 1023 }
#[cfg(feature="simd_support")]
float_impls! { f64x8, u64x8, f64, u64, 52, 1023 }


#[cfg(test)]
mod tests {
use Rng;
use distributions::{Open01, OpenClosed01};
use rngs::mock::StepRng;
#[cfg(feature="simd_support")]
use core::simd::*;

const EPSILON32: f32 = ::core::f32::EPSILON;
const EPSILON64: f64 = ::core::f64::EPSILON;

#[test]
fn standard_fp_edge_cases() {
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.gen::<f32>(), 0.0);
assert_eq!(zeros.gen::<f64>(), 0.0);

let mut one32 = StepRng::new(1 << 8, 0);
assert_eq!(one32.gen::<f32>(), EPSILON32 / 2.0);

let mut one64 = StepRng::new(1 << 11, 0);
assert_eq!(one64.gen::<f64>(), EPSILON64 / 2.0);

let mut max = StepRng::new(!0, 0);
assert_eq!(max.gen::<f32>(), 1.0 - EPSILON32 / 2.0);
assert_eq!(max.gen::<f64>(), 1.0 - EPSILON64 / 2.0);
}
macro_rules! test_f32 {
($fnn:ident, $ty:ident, $ZERO:expr, $EPSILON:expr) => {
#[test]
fn $fnn() {
// Standard
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.gen::<$ty>(), $ZERO);
let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0);
assert_eq!(one.gen::<$ty>(), $EPSILON / 2.0);
let mut max = StepRng::new(!0, 0);
assert_eq!(max.gen::<$ty>(), 1.0 - $EPSILON / 2.0);

#[test]
fn openclosed01_edge_cases() {
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.sample::<f32, _>(OpenClosed01), 0.0 + EPSILON32 / 2.0);
assert_eq!(zeros.sample::<f64, _>(OpenClosed01), 0.0 + EPSILON64 / 2.0);
// OpenClosed01
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.sample::<$ty, _>(OpenClosed01),
0.0 + $EPSILON / 2.0);
let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0);
assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON);
let mut max = StepRng::new(!0, 0);
assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + 1.0);

let mut one32 = StepRng::new(1 << 8, 0);
assert_eq!(one32.sample::<f32, _>(OpenClosed01), EPSILON32);

let mut one64 = StepRng::new(1 << 11, 0);
assert_eq!(one64.sample::<f64, _>(OpenClosed01), EPSILON64);

let mut max = StepRng::new(!0, 0);
assert_eq!(max.sample::<f32, _>(OpenClosed01), 1.0);
assert_eq!(max.sample::<f64, _>(OpenClosed01), 1.0);
// Open01
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.sample::<$ty, _>(Open01), 0.0 + $EPSILON / 2.0);
let mut one = StepRng::new(1 << 9 | 1 << (9 + 32), 0);
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / 2.0 * 3.0);
let mut max = StepRng::new(!0, 0);
assert_eq!(max.sample::<$ty, _>(Open01), 1.0 - $EPSILON / 2.0);
}
}
}
test_f32! { f32_edge_cases, f32, 0.0, EPSILON32 }
#[cfg(feature="simd_support")]
test_f32! { f32x2_edge_cases, f32x2, f32x2::splat(0.0), f32x2::splat(EPSILON32) }
#[cfg(feature="simd_support")]
test_f32! { f32x4_edge_cases, f32x4, f32x4::splat(0.0), f32x4::splat(EPSILON32) }
#[cfg(feature="simd_support")]
test_f32! { f32x8_edge_cases, f32x8, f32x8::splat(0.0), f32x8::splat(EPSILON32) }
#[cfg(feature="simd_support")]
test_f32! { f32x16_edge_cases, f32x16, f32x16::splat(0.0), f32x16::splat(EPSILON32) }

#[test]
fn open01_edge_cases() {
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.sample::<f32, _>(Open01), 0.0 + EPSILON32 / 2.0);
assert_eq!(zeros.sample::<f64, _>(Open01), 0.0 + EPSILON64 / 2.0);
macro_rules! test_f64 {
($fnn:ident, $ty:ident, $ZERO:expr, $EPSILON:expr) => {
#[test]
fn $fnn() {
// Standard
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.gen::<$ty>(), $ZERO);
let mut one = StepRng::new(1 << 11, 0);
assert_eq!(one.gen::<$ty>(), $EPSILON / 2.0);
let mut max = StepRng::new(!0, 0);
assert_eq!(max.gen::<$ty>(), 1.0 - $EPSILON / 2.0);

let mut one32 = StepRng::new(1 << 9, 0);
assert_eq!(one32.sample::<f32, _>(Open01), EPSILON32 / 2.0 * 3.0);
// OpenClosed01
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.sample::<$ty, _>(OpenClosed01),
0.0 + $EPSILON / 2.0);
let mut one = StepRng::new(1 << 11, 0);
assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON);
let mut max = StepRng::new(!0, 0);
assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + 1.0);

let mut one64 = StepRng::new(1 << 12, 0);
assert_eq!(one64.sample::<f64, _>(Open01), EPSILON64 / 2.0 * 3.0);

let mut max = StepRng::new(!0, 0);
assert_eq!(max.sample::<f32, _>(Open01), 1.0 - EPSILON32 / 2.0);
assert_eq!(max.sample::<f64, _>(Open01), 1.0 - EPSILON64 / 2.0);
// Open01
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.sample::<$ty, _>(Open01), 0.0 + $EPSILON / 2.0);
let mut one = StepRng::new(1 << 12, 0);
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / 2.0 * 3.0);
let mut max = StepRng::new(!0, 0);
assert_eq!(max.sample::<$ty, _>(Open01), 1.0 - $EPSILON / 2.0);
}
}
}
test_f64! { f64_edge_cases, f64, 0.0, EPSILON64 }
#[cfg(feature="simd_support")]
test_f64! { f64x2_edge_cases, f64x2, f64x2::splat(0.0), f64x2::splat(EPSILON64) }
#[cfg(feature="simd_support")]
test_f64! { f64x4_edge_cases, f64x4, f64x4::splat(0.0), f64x4::splat(EPSILON64) }
#[cfg(feature="simd_support")]
test_f64! { f64x8_edge_cases, f64x8, f64x8::splat(0.0), f64x8::splat(EPSILON64) }
}
35 changes: 35 additions & 0 deletions src/distributions/integer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

use {Rng};
use distributions::{Distribution, Standard};
#[cfg(feature="simd_support")]
use core::simd::*;

impl Distribution<u8> for Standard {
#[inline]
Expand Down Expand Up @@ -84,6 +86,39 @@ impl_int_from_uint! { i64, u64 }
#[cfg(feature = "i128_support")] impl_int_from_uint! { i128, u128 }
impl_int_from_uint! { isize, usize }

#[cfg(feature="simd_support")]
macro_rules! simd_impl {
($bits:expr,) => {};
($bits:expr, $ty:ty, $($ty_more:ty,)*) => {
simd_impl!($bits, $($ty_more,)*);

impl Distribution<$ty> for Standard {
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $ty {
let mut vec = Default::default();
unsafe {
let ptr = &mut vec;
let b_ptr = &mut *(ptr as *mut $ty as *mut [u8; $bits/8]);
rng.fill_bytes(b_ptr);
}
vec
}
}
}
}

#[cfg(feature="simd_support")]
simd_impl!(16, u8x2, i8x2,);
#[cfg(feature="simd_support")]
simd_impl!(32, u8x4, i8x4, u16x2, i16x2,);
#[cfg(feature="simd_support")]
simd_impl!(64, u8x8, i8x8, u16x4, i16x4, u32x2, i32x2,);
#[cfg(feature="simd_support")]
simd_impl!(128, u8x16, i8x16, u16x8, i16x8, u32x4, i32x4, u64x2, i64x2,);
#[cfg(feature="simd_support")]
simd_impl!(256, u8x32, i8x32, u16x16, i16x16, u32x8, i32x8, u64x4, i64x4,);
#[cfg(feature="simd_support")]
simd_impl!(512, u8x64, i8x64, u16x32, i16x32, u32x16, i32x16, u64x8, i64x8,);

#[cfg(test)]
mod tests {
Expand Down
1 change: 1 addition & 0 deletions src/distributions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ mod integer;
#[cfg(feature="std")]
mod log_gamma;
mod other;
mod utils;
#[cfg(feature="std")]
mod ziggurat_tables;
#[cfg(feature="std")]
Expand Down
Loading

0 comments on commit 950c0af

Please sign in to comment.