-
Notifications
You must be signed in to change notification settings - Fork 432
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add binomial and Poisson distributions #96
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
// Copyright 2016-2017 The Rust Project Developers. See the COPYRIGHT | ||
// file at the top-level directory of this distribution and at | ||
// https://rust-lang.org/COPYRIGHT. | ||
// | ||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license | ||
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your | ||
// option. This file may not be copied, modified, or distributed | ||
// except according to those terms. | ||
|
||
//! The binomial distribution. | ||
|
||
use Rng; | ||
use distributions::Distribution; | ||
use distributions::log_gamma::log_gamma; | ||
use std::f64::consts::PI; | ||
|
||
/// The binomial distribution `Binomial(n, p)`. | ||
/// | ||
/// This distribution has density function: `f(k) = n!/(k! (n-k)!) p^k (1-p)^(n-k)` for `k >= 0`. | ||
/// | ||
/// # Example | ||
/// | ||
/// ```rust | ||
/// use rand::distributions::{Binomial, Distribution}; | ||
/// | ||
/// let bin = Binomial::new(20, 0.3); | ||
/// let v = bin.sample(&mut rand::thread_rng()); | ||
/// println!("{} is from a binomial distribution", v); | ||
/// ``` | ||
#[derive(Clone, Copy, Debug)] | ||
pub struct Binomial { | ||
n: u64, // number of trials | ||
p: f64, // probability of success | ||
} | ||
|
||
impl Binomial { | ||
/// Construct a new `Binomial` with the given shape parameters | ||
/// `n`, `p`. Panics if `p <= 0` or `p >= 1`. | ||
pub fn new(n: u64, p: f64) -> Binomial { | ||
assert!(p > 0.0, "Binomial::new called with p <= 0"); | ||
assert!(p < 1.0, "Binomial::new called with p >= 1"); | ||
Binomial { n: n, p: p } | ||
} | ||
} | ||
|
||
impl Distribution<u64> for Binomial { | ||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 { | ||
// binomial distribution is symmetrical with respect to p -> 1-p, k -> n-k | ||
// switch p so that it is less than 0.5 - this allows for lower expected values | ||
// we will just invert the result at the end | ||
let p = if self.p <= 0.5 { | ||
self.p | ||
} else { | ||
1.0 - self.p | ||
}; | ||
|
||
// expected value of the sample | ||
let expected = self.n as f64 * p; | ||
|
||
let result = | ||
// for low expected values we just simulate n drawings | ||
if expected < 25.0 { | ||
let mut lresult = 0.0; | ||
for _ in 0 .. self.n { | ||
if rng.gen::<f64>() < p { | ||
lresult += 1.0; | ||
} | ||
} | ||
lresult | ||
} | ||
// high expected value - do the rejection method | ||
else { | ||
// prepare some cached values | ||
let float_n = self.n as f64; | ||
let ln_fact_n = log_gamma(float_n + 1.0); | ||
let pc = 1.0 - p; | ||
let log_p = p.ln(); | ||
let log_pc = pc.ln(); | ||
let sq = (expected * (2.0 * pc)).sqrt(); | ||
|
||
let mut lresult; | ||
|
||
loop { | ||
let mut comp_dev: f64; | ||
// we use the lorentzian distribution as the comparison distribution | ||
// f(x) ~ 1/(1+x/^2) | ||
loop { | ||
// draw from the lorentzian distribution | ||
comp_dev = (PI*rng.gen::<f64>()).tan(); | ||
// shift the peak of the comparison ditribution | ||
lresult = expected + sq * comp_dev; | ||
// repeat the drawing until we are in the range of possible values | ||
if lresult >= 0.0 && lresult < float_n + 1.0 { | ||
break; | ||
} | ||
} | ||
|
||
// the result should be discrete | ||
lresult = lresult.floor(); | ||
|
||
let log_binomial_dist = ln_fact_n - log_gamma(lresult+1.0) - | ||
log_gamma(float_n - lresult + 1.0) + lresult*log_p + (float_n - lresult)*log_pc; | ||
// this is the binomial probability divided by the comparison probability | ||
// we will generate a uniform random value and if it is larger than this, | ||
// we interpret it as a value falling out of the distribution and repeat | ||
let comparison_coeff = (log_binomial_dist.exp() * sq) * (1.2 * (1.0 + comp_dev*comp_dev)); | ||
|
||
if comparison_coeff >= rng.gen() { | ||
break; | ||
} | ||
} | ||
|
||
lresult | ||
}; | ||
|
||
// invert the result for p < 0.5 | ||
if p != self.p { | ||
self.n - result as u64 | ||
} else { | ||
result as u64 | ||
} | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod test { | ||
use distributions::Distribution; | ||
use super::Binomial; | ||
|
||
#[test] | ||
fn test_binomial() { | ||
let binomial = Binomial::new(150, 0.1); | ||
let mut rng = ::test::rng(123); | ||
let mut sum = 0; | ||
for _ in 0..1000 { | ||
sum += binomial.sample(&mut rng); | ||
} | ||
let avg = (sum as f64) / 1000.0; | ||
println!("Binomial average: {}", avg); | ||
assert!((avg - 15.0).abs() < 0.5); // not 100% certain, but probable enough | ||
} | ||
|
||
#[test] | ||
#[should_panic] | ||
#[cfg_attr(target_env = "msvc", ignore)] | ||
fn test_binomial_invalid_lambda_zero() { | ||
Binomial::new(20, 0.0); | ||
} | ||
#[test] | ||
#[should_panic] | ||
#[cfg_attr(target_env = "msvc", ignore)] | ||
fn test_binomial_invalid_lambda_neg() { | ||
Binomial::new(20, -10.0); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
// Copyright 2016-2017 The Rust Project Developers. See the COPYRIGHT | ||
// file at the top-level directory of this distribution and at | ||
// https://rust-lang.org/COPYRIGHT. | ||
// | ||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license | ||
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your | ||
// option. This file may not be copied, modified, or distributed | ||
// except according to those terms. | ||
|
||
/// Calculates ln(gamma(x)) (natural logarithm of the gamma | ||
/// function) using the Lanczos approximation. | ||
/// | ||
/// The approximation expresses the gamma function as: | ||
/// `gamma(z+1) = sqrt(2*pi)*(z+g+0.5)^(z+0.5)*exp(-z-g-0.5)*Ag(z)` | ||
/// `g` is an arbitrary constant; we use the approximation with `g=5`. | ||
/// | ||
/// Noting that `gamma(z+1) = z*gamma(z)` and applying `ln` to both sides: | ||
/// `ln(gamma(z)) = (z+0.5)*ln(z+g+0.5)-(z+g+0.5) + ln(sqrt(2*pi)*Ag(z)/z)` | ||
/// | ||
/// `Ag(z)` is an infinite series with coefficients that can be calculated | ||
/// ahead of time - we use just the first 6 terms, which is good enough | ||
/// for most purposes. | ||
pub fn log_gamma(x: f64) -> f64 { | ||
// precalculated 6 coefficients for the first 6 terms of the series | ||
let coefficients: [f64; 6] = [ | ||
76.18009172947146, | ||
-86.50532032941677, | ||
24.01409824083091, | ||
-1.231739572450155, | ||
0.1208650973866179e-2, | ||
-0.5395239384953e-5, | ||
]; | ||
|
||
// (x+0.5)*ln(x+g+0.5)-(x+g+0.5) | ||
let tmp = x + 5.5; | ||
let log = (x + 0.5) * tmp.ln() - tmp; | ||
|
||
// the first few terms of the series for Ag(x) | ||
let mut a = 1.000000000190015; | ||
let mut denom = x; | ||
for j in 0..6 { | ||
denom += 1.0; | ||
a += coefficients[j] / denom; | ||
} | ||
|
||
// get everything together | ||
// a is Ag(x) | ||
// 2.5066... is sqrt(2pi) | ||
return log + (2.5066282746310005 * a / x).ln(); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,6 +25,10 @@ pub use self::gamma::{Gamma, ChiSquared, FisherF, StudentT}; | |
pub use self::normal::{Normal, LogNormal, StandardNormal}; | ||
#[cfg(feature="std")] | ||
pub use self::exponential::{Exp, Exp1}; | ||
#[cfg(feature = "std")] | ||
pub use self::poisson::Poisson; | ||
#[cfg(feature = "std")] | ||
pub use self::binomial::Binomial; | ||
|
||
pub mod range; | ||
#[cfg(feature="std")] | ||
|
@@ -33,9 +37,14 @@ pub mod gamma; | |
pub mod normal; | ||
#[cfg(feature="std")] | ||
pub mod exponential; | ||
#[cfg(feature = "std")] | ||
pub mod poisson; | ||
#[cfg(feature = "std")] | ||
pub mod binomial; | ||
|
||
mod float; | ||
mod integer; | ||
mod log_gamma; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This module also needs the |
||
mod other; | ||
#[cfg(feature="std")] | ||
mod ziggurat_tables; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
// Copyright 2016-2017 The Rust Project Developers. See the COPYRIGHT | ||
// file at the top-level directory of this distribution and at | ||
// https://rust-lang.org/COPYRIGHT. | ||
// | ||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license | ||
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your | ||
// option. This file may not be copied, modified, or distributed | ||
// except according to those terms. | ||
|
||
//! The Poisson distribution. | ||
|
||
use Rng; | ||
use distributions::Distribution; | ||
use distributions::log_gamma::log_gamma; | ||
use std::f64::consts::PI; | ||
|
||
/// The Poisson distribution `Poisson(lambda)`. | ||
/// | ||
/// This distribution has a density function: | ||
/// `f(k) = lambda^k * exp(-lambda) / k!` for `k >= 0`. | ||
/// | ||
/// # Example | ||
/// | ||
/// ```rust | ||
/// use rand::distributions::{Poisson, Distribution}; | ||
/// | ||
/// let poi = Poisson::new(2.0); | ||
/// let v = poi.sample(&mut rand::thread_rng()); | ||
/// println!("{} is from a Poisson(2) distribution", v); | ||
/// ``` | ||
#[derive(Clone, Copy, Debug)] | ||
pub struct Poisson { | ||
lambda: f64, | ||
// precalculated values | ||
exp_lambda: f64, | ||
log_lambda: f64, | ||
sqrt_2lambda: f64, | ||
magic_val: f64, | ||
} | ||
|
||
impl Poisson { | ||
/// Construct a new `Poisson` with the given shape parameter | ||
/// `lambda`. Panics if `lambda <= 0`. | ||
pub fn new(lambda: f64) -> Poisson { | ||
assert!(lambda > 0.0, "Poisson::new called with lambda <= 0"); | ||
let log_lambda = lambda.ln(); | ||
Poisson { | ||
lambda: lambda, | ||
exp_lambda: (-lambda).exp(), | ||
log_lambda: log_lambda, | ||
sqrt_2lambda: (2.0 * lambda).sqrt(), | ||
magic_val: lambda * log_lambda - log_gamma(1.0 + lambda), | ||
} | ||
} | ||
} | ||
|
||
impl Distribution<u64> for Poisson { | ||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 { | ||
// using the algorithm from Numerical Recipes in C | ||
|
||
// for low expected values use the Knuth method | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it be better to use the inverse transform method for small samples, since it only requires 1 random sample? I don't know a lot about this topic unfortunately. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To be honest, I don't know, my knowledge is also limited. Maybe it is a good idea, sampling just once sounds attractive. I mostly just ported the algorithm from "Numerical Recipes", but it is possible that something else could be better. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm; unless an expert in this area turns up (unlikely), perhaps the best we can do is implement some tests (e.g. plot a high-resolution histogram), then say this is |
||
if self.lambda < 12.0 { | ||
let mut result = 0; | ||
let mut p = 1.0; | ||
while p > self.exp_lambda { | ||
p *= rng.gen::<f64>(); | ||
result += 1; | ||
} | ||
result - 1 | ||
} | ||
// high expected values - rejection method | ||
else { | ||
let mut int_result: u64; | ||
|
||
loop { | ||
let mut result; | ||
let mut comp_dev; | ||
|
||
// we use the lorentzian distribution as the comparison distribution | ||
// f(x) ~ 1/(1+x/^2) | ||
loop { | ||
// draw from the lorentzian distribution | ||
comp_dev = (PI * rng.gen::<f64>()).tan(); | ||
// shift the peak of the comparison ditribution | ||
result = self.sqrt_2lambda * comp_dev + self.lambda; | ||
// repeat the drawing until we are in the range of possible values | ||
if result >= 0.0 { | ||
break; | ||
} | ||
} | ||
// now the result is a random variable greater than 0 with Lorentzian distribution | ||
// the result should be an integer value | ||
result = result.floor(); | ||
int_result = result as u64; | ||
|
||
// this is the ratio of the Poisson distribution to the comparison distribution | ||
// the magic value scales the distribution function to a range of approximately 0-1 | ||
// since it is not exact, we multiply the ratio by 0.9 to avoid ratios greater than 1 | ||
// this doesn't change the resulting distribution, only increases the rate of failed drawings | ||
let check = 0.9 * (1.0 + comp_dev * comp_dev) | ||
* (result * self.log_lambda - log_gamma(1.0 + result) - self.magic_val).exp(); | ||
|
||
// check with uniform random value - if below the threshold, we are within the target distribution | ||
if rng.gen::<f64>() <= check { | ||
break; | ||
} | ||
} | ||
int_result | ||
} | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod test { | ||
use distributions::Distribution; | ||
use super::Poisson; | ||
|
||
#[test] | ||
fn test_poisson() { | ||
let poisson = Poisson::new(10.0); | ||
let mut rng = ::test::rng(123); | ||
let mut sum = 0; | ||
for _ in 0..1000 { | ||
sum += poisson.sample(&mut rng); | ||
} | ||
let avg = (sum as f64) / 1000.0; | ||
println!("Poisson average: {}", avg); | ||
assert!((avg - 10.0).abs() < 0.5); // not 100% certain, but probable enough | ||
} | ||
|
||
#[test] | ||
#[should_panic] | ||
#[cfg_attr(target_env = "msvc", ignore)] | ||
fn test_poisson_invalid_lambda_zero() { | ||
Poisson::new(0.0); | ||
} | ||
#[test] | ||
#[should_panic] | ||
#[cfg_attr(target_env = "msvc", ignore)] | ||
fn test_poisson_invalid_lambda_neg() { | ||
Poisson::new(-10.0); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@pitdicker what do you think about this probability test? I've been wondering if we should add a dedicated Bernoulli distribution for more accurate sampling.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I should first learn a lot before I can make any meaningful comments w.r.t. the distributions...
This single line is pretty much the Bernoulli distribution? It might be more generally useful than
gen_weighted_bool
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I think we should add a Bernoulli distribution, and it would be nice to have it reasonably accurate for small
p
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we can do better than just
rng.gen::<f64>() < p
, then sure, this is a good idea, otherwise I'm not really convinced that it makes sense to make it a separate distribution.And as for doing better, that's a bit over my head, so unfortunately I won't be able to help...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To fill you in @fizyk20, @pitdicker already did quite a bit of work implementing higher-precision floating point sampling, since the default method uses the same precision over the 0-1 range despite the format being able to represent a lot more close to 0 — however, we seem to have decided not to use this sampling method by default. There's also the thing that we use a small offset, which normally isn't an issue, but might be for correct sampling of small probabilities.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Almost available with
Rng::gen_bool(p)
from #308.