diff --git a/rand_distr/src/poisson.rs b/rand_distr/src/poisson.rs index 96bc8b9b6b4..4f4a0b7a3d9 100644 --- a/rand_distr/src/poisson.rs +++ b/rand_distr/src/poisson.rs @@ -10,8 +10,8 @@ //! The Poisson distribution. use rand::Rng; -use crate::{Distribution, Cauchy}; -use crate::utils::log_gamma; +use crate::{Distribution, Cauchy, Standard}; +use crate::utils::Float; /// The Poisson distribution `Poisson(lambda)`. /// @@ -24,17 +24,17 @@ use crate::utils::log_gamma; /// use rand_distr::{Poisson, Distribution}; /// /// let poi = Poisson::new(2.0).unwrap(); -/// let v = poi.sample(&mut rand::thread_rng()); +/// let v: u64 = poi.sample(&mut rand::thread_rng()); /// println!("{} is from a Poisson(2) distribution", v); /// ``` #[derive(Clone, Copy, Debug)] -pub struct Poisson { - lambda: f64, +pub struct Poisson { + lambda: N, // precalculated values - exp_lambda: f64, - log_lambda: f64, - sqrt_2lambda: f64, - magic_val: f64, + exp_lambda: N, + log_lambda: N, + sqrt_2lambda: N, + magic_val: N, } /// Error type returned from `Poisson::new`. @@ -44,11 +44,13 @@ pub enum Error { ShapeTooSmall, } -impl Poisson { +impl Poisson +where Standard: Distribution +{ /// Construct a new `Poisson` with the given shape parameter /// `lambda`. - pub fn new(lambda: f64) -> Result { - if !(lambda > 0.0) { + pub fn new(lambda: N) -> Result, Error> { + if !(lambda > N::from(0.0)) { return Err(Error::ShapeTooSmall); } let log_lambda = lambda.ln(); @@ -56,36 +58,37 @@ impl Poisson { lambda, exp_lambda: (-lambda).exp(), log_lambda, - sqrt_2lambda: (2.0 * lambda).sqrt(), - magic_val: lambda * log_lambda - log_gamma(1.0 + lambda), + sqrt_2lambda: (N::from(2.0) * lambda).sqrt(), + magic_val: lambda * log_lambda - (N::from(1.0) + lambda).log_gamma(), }) } } -impl Distribution for Poisson { - fn sample(&self, rng: &mut R) -> u64 { +impl Distribution for Poisson +where Standard: Distribution +{ + #[inline] + fn sample(&self, rng: &mut R) -> N { // using the algorithm from Numerical Recipes in C // for low expected values use the Knuth method - if self.lambda < 12.0 { - let mut result = 0; - let mut p = 1.0; + if self.lambda < N::from(12.0) { + let mut result = N::from(0.); + let mut p = N::from(1.0); while p > self.exp_lambda { - p *= rng.gen::(); - result += 1; + p *= rng.gen::(); + result += N::from(1.); } - result - 1 + result - N::from(1.) } // high expected values - rejection method else { - let mut int_result: u64; - // we use the Cauchy distribution as the comparison distribution // f(x) ~ 1/(1+x^2) - let cauchy = Cauchy::new(0.0, 1.0).unwrap(); + let cauchy = Cauchy::new(N::from(0.0), N::from(1.0)).unwrap(); + let mut result; loop { - let mut result; let mut comp_dev; loop { @@ -94,32 +97,41 @@ impl Distribution for Poisson { // shift the peak of the comparison ditribution result = self.sqrt_2lambda * comp_dev + self.lambda; // repeat the drawing until we are in the range of possible values - if result >= 0.0 { + if result >= N::from(0.0) { break; } } // now the result is a random variable greater than 0 with Cauchy distribution // the result should be an integer value result = result.floor(); - int_result = result as u64; // this is the ratio of the Poisson distribution to the comparison distribution // the magic value scales the distribution function to a range of approximately 0-1 // since it is not exact, we multiply the ratio by 0.9 to avoid ratios greater than 1 // this doesn't change the resulting distribution, only increases the rate of failed drawings - let check = 0.9 * (1.0 + comp_dev * comp_dev) - * (result * self.log_lambda - log_gamma(1.0 + result) - self.magic_val).exp(); + let check = N::from(0.9) * (N::from(1.0) + comp_dev * comp_dev) + * (result * self.log_lambda - (N::from(1.0) + result).log_gamma() - self.magic_val).exp(); // check with uniform random value - if below the threshold, we are within the target distribution - if rng.gen::() <= check { + if rng.gen::() <= check { break; } } - int_result + result } } } +impl Distribution for Poisson +where Standard: Distribution +{ + #[inline] + fn sample(&self, rng: &mut R) -> u64 { + let result: N = self.sample(rng); + result.to_u64().unwrap() + } +} + #[cfg(test)] mod test { use crate::Distribution; @@ -129,13 +141,20 @@ mod test { fn test_poisson_10() { let poisson = Poisson::new(10.0).unwrap(); let mut rng = crate::test::rng(123); - let mut sum = 0; + let mut sum_u64 = 0; + let mut sum_f64 = 0.; for _ in 0..1000 { - sum += poisson.sample(&mut rng); + let s_u64: u64 = poisson.sample(&mut rng); + let s_f64: f64 = poisson.sample(&mut rng); + sum_u64 += s_u64; + sum_f64 += s_f64; + } + let avg_u64 = (sum_u64 as f64) / 1000.0; + let avg_f64 = sum_f64 / 1000.0; + println!("Poisson averages: {} (u64) {} (f64)", avg_u64, avg_f64); + for &avg in &[avg_u64, avg_f64] { + assert!((avg - 10.0).abs() < 0.5); // not 100% certain, but probable enough } - let avg = (sum as f64) / 1000.0; - println!("Poisson average: {}", avg); - assert!((avg - 10.0).abs() < 0.5); // not 100% certain, but probable enough } #[test] @@ -143,13 +162,61 @@ mod test { // Take the 'high expected values' path let poisson = Poisson::new(15.0).unwrap(); let mut rng = crate::test::rng(123); - let mut sum = 0; + let mut sum_u64 = 0; + let mut sum_f64 = 0.; for _ in 0..1000 { - sum += poisson.sample(&mut rng); + let s_u64: u64 = poisson.sample(&mut rng); + let s_f64: f64 = poisson.sample(&mut rng); + sum_u64 += s_u64; + sum_f64 += s_f64; + } + let avg_u64 = (sum_u64 as f64) / 1000.0; + let avg_f64 = sum_f64 / 1000.0; + println!("Poisson average: {} (u64) {} (f64)", avg_u64, avg_f64); + for &avg in &[avg_u64, avg_f64] { + assert!((avg - 15.0).abs() < 0.5); // not 100% certain, but probable enough + } + } + + #[test] + fn test_poisson_10_f32() { + let poisson = Poisson::new(10.0f32).unwrap(); + let mut rng = crate::test::rng(123); + let mut sum_u64 = 0; + let mut sum_f32 = 0.; + for _ in 0..1000 { + let s_u64: u64 = poisson.sample(&mut rng); + let s_f32: f32 = poisson.sample(&mut rng); + sum_u64 += s_u64; + sum_f32 += s_f32; + } + let avg_u64 = (sum_u64 as f32) / 1000.0; + let avg_f32 = sum_f32 / 1000.0; + println!("Poisson averages: {} (u64) {} (f32)", avg_u64, avg_f32); + for &avg in &[avg_u64, avg_f32] { + assert!((avg - 10.0).abs() < 0.5); // not 100% certain, but probable enough + } + } + + #[test] + fn test_poisson_15_f32() { + // Take the 'high expected values' path + let poisson = Poisson::new(15.0f32).unwrap(); + let mut rng = crate::test::rng(123); + let mut sum_u64 = 0; + let mut sum_f32 = 0.; + for _ in 0..1000 { + let s_u64: u64 = poisson.sample(&mut rng); + let s_f32: f32 = poisson.sample(&mut rng); + sum_u64 += s_u64; + sum_f32 += s_f32; + } + let avg_u64 = (sum_u64 as f32) / 1000.0; + let avg_f32 = sum_f32 / 1000.0; + println!("Poisson average: {} (u64) {} (f32)", avg_u64, avg_f32); + for &avg in &[avg_u64, avg_f32] { + assert!((avg - 15.0).abs() < 0.5); // not 100% certain, but probable enough } - let avg = (sum as f64) / 1000.0; - println!("Poisson average: {}", avg); - assert!((avg - 15.0).abs() < 0.5); // not 100% certain, but probable enough } #[test] diff --git a/rand_distr/src/utils.rs b/rand_distr/src/utils.rs index e5f107f23e5..c275d0e5cba 100644 --- a/rand_distr/src/utils.rs +++ b/rand_distr/src/utils.rs @@ -33,9 +33,13 @@ pub trait Float: Copy + Sized + cmp::PartialOrd fn pi() -> Self; /// Support approximate representation of a f64 value fn from(x: f64) -> Self; + /// Support converting to an unsigned integer. + fn to_u64(self) -> Option; /// Take the absolute value of self fn abs(self) -> Self; + /// Take the largest integer less than or equal to self + fn floor(self) -> Self; /// Take the exponential of self fn exp(self) -> Self; @@ -48,34 +52,81 @@ pub trait Float: Copy + Sized + cmp::PartialOrd /// Take the tangent of self fn tan(self) -> Self; + /// Take the logarithm of the gamma function of self + fn log_gamma(self) -> Self; } impl Float for f32 { + #[inline] fn pi() -> Self { core::f32::consts::PI } + #[inline] fn from(x: f64) -> Self { x as f32 } + #[inline] + fn to_u64(self) -> Option { + if self >= 0. && self <= ::core::u64::MAX as f32 { + Some(self as u64) + } else { + None + } + } + #[inline] fn abs(self) -> Self { self.abs() } + #[inline] + fn floor(self) -> Self { self.floor() } + #[inline] fn exp(self) -> Self { self.exp() } + #[inline] fn ln(self) -> Self { self.ln() } + #[inline] fn sqrt(self) -> Self { self.sqrt() } + #[inline] fn powf(self, power: Self) -> Self { self.powf(power) } + #[inline] fn tan(self) -> Self { self.tan() } + #[inline] + fn log_gamma(self) -> Self { + let result = log_gamma(self as f64); + assert!(result <= ::core::f32::MAX as f64); + assert!(result >= ::core::f32::MIN as f64); + result as f32 + } } impl Float for f64 { + #[inline] fn pi() -> Self { core::f64::consts::PI } + #[inline] fn from(x: f64) -> Self { x } + #[inline] + fn to_u64(self) -> Option { + if self >= 0. && self <= ::core::u64::MAX as f64 { + Some(self as u64) + } else { + None + } + } + #[inline] fn abs(self) -> Self { self.abs() } + #[inline] + fn floor(self) -> Self { self.floor() } + #[inline] fn exp(self) -> Self { self.exp() } + #[inline] fn ln(self) -> Self { self.ln() } + #[inline] fn sqrt(self) -> Self { self.sqrt() } + #[inline] fn powf(self, power: Self) -> Self { self.powf(power) } + #[inline] fn tan(self) -> Self { self.tan() } + #[inline] + fn log_gamma(self) -> Self { log_gamma(self) } } /// Calculates ln(gamma(x)) (natural logarithm of the gamma @@ -109,7 +160,7 @@ pub(crate) fn log_gamma(x: f64) -> f64 { // the first few terms of the series for Ag(x) let mut a = 1.000000000190015; let mut denom = x; - for coeff in &coefficients { + for &coeff in &coefficients { denom += 1.0; a += coeff / denom; }