Skip to content

Commit

Permalink
Merge pull request #4 from promised-ai/feature/from-parameters
Browse files Browse the repository at this point in the history
Add Parameterized trait and test macro for ConjugatePrior
  • Loading branch information
cscherrer authored Apr 9, 2024
2 parents f69f390 + 716cfc2 commit e91713c
Show file tree
Hide file tree
Showing 56 changed files with 1,143 additions and 864 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ rand_xoshiro = "0.6"
serde1 = ["serde", "nalgebra/serde-serialize"]
arraydist = ["nalgebra"]
process = ["serde", "nalgebra/serde-serialize", "argmin", "argmin-math", "arraydist"]
datum = []
experimental = ["rand_xoshiro"]

[package.metadata.docs.rs]
Expand Down
191 changes: 0 additions & 191 deletions src/data/datum.rs

This file was deleted.

6 changes: 0 additions & 6 deletions src/data/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,6 @@
mod partition;
mod stat;

#[cfg(feature = "datum")]
mod datum;

#[cfg(feature = "datum")]
pub use datum::Datum;

pub use partition::Partition;
pub use stat::BernoulliSuffStat;
pub use stat::BetaSuffStat;
Expand Down
16 changes: 14 additions & 2 deletions src/dist/bernoulli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@ pub struct Bernoulli {
p: f64,
}

impl Parameterized for Bernoulli {
type Parameters = f64;

fn emit_params(&self) -> Self::Parameters {
self.p()
}

fn from_params(params: Self::Parameters) -> Self {
Self::new_unchecked(params)
}
}

#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
Expand Down Expand Up @@ -373,7 +385,7 @@ mod tests {
const N_TRIES: usize = 5;
const X2_PVAL: f64 = 0.2;

test_basic_impls!([binary] Bernoulli::default());
test_basic_impls!(bool, Bernoulli, Bernoulli::default());

#[test]
fn new() {
Expand Down Expand Up @@ -658,7 +670,7 @@ mod tests {
}

#[test]
fn unifrom_entropy() {
fn uniform_entropy() {
let b = Bernoulli::uniform();
assert::close(b.entropy(), f64::consts::LN_2, TOL);
}
Expand Down
24 changes: 22 additions & 2 deletions src/dist/beta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,26 @@ pub struct Beta {
ln_beta_ab: OnceLock<f64>,
}

pub struct BetaParameters {
pub alpha: f64,
pub beta: f64,
}

impl Parameterized for Beta {
type Parameters = BetaParameters;

fn emit_params(&self) -> Self::Parameters {
Self::Parameters {
alpha: self.alpha(),
beta: self.beta(),
}
}

fn from_params(params: Self::Parameters) -> Self {
Self::new_unchecked(params.alpha, params.beta)
}
}

impl PartialEq for Beta {
fn eq(&self, other: &Beta) -> bool {
self.alpha == other.alpha && self.beta == other.beta
Expand Down Expand Up @@ -454,7 +474,7 @@ mod tests {
const KS_PVAL: f64 = 0.2;
const N_TRIES: usize = 5;

test_basic_impls!([continuous] Beta::jeffreys());
test_basic_impls!(f64, Beta, Beta::jeffreys());

#[test]
fn new() {
Expand Down Expand Up @@ -578,7 +598,7 @@ mod tests {
}

#[test]
fn draw_should_resturn_values_within_0_to_1() {
fn draw_should_return_values_within_0_to_1() {
let mut rng = rand::thread_rng();
let beta = Beta::jeffreys();
for _ in 0..100 {
Expand Down
28 changes: 3 additions & 25 deletions src/dist/beta/bernoulli_prior.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,12 @@ impl<X: Booleable> ConjugatePrior<X, Bernoulli> for Beta {
#[cfg(test)]
mod tests {
use super::*;
use crate::test_conjugate_prior;

const TOL: f64 = 1E-12;

test_conjugate_prior!(bool, Bernoulli, Beta, Beta::new(0.5, 1.2).unwrap());

#[test]
fn posterior_from_data_bool() {
let data = vec![false, true, false, true, true];
Expand All @@ -108,29 +111,4 @@ mod tests {
assert::close(posterior.alpha(), 4.0, TOL);
assert::close(posterior.beta(), 3.0, TOL);
}

#[test]
fn bern_bayes_law() {
let mut rng = rand::thread_rng();

// Prior
let prior = Beta::new(5.0, 2.0).unwrap();
let par: f64 = prior.draw(&mut rng);
let prior_f = prior.f(&par);

// Likelihood
let lik = Bernoulli::new(par).unwrap();
let lik_data: bool = lik.draw(&mut rng);
let lik_f = lik.f(&lik_data);

// Evidence
let ev = prior.m(&DataOrSuffStat::Data(&[lik_data]));

// Posterior
let post = prior.posterior(&DataOrSuffStat::Data(&[lik_data]));
let post_f = post.f(&par);

// Bayes' law
assert::close(post_f, prior_f * lik_f / ev, 1e-12);
}
}
28 changes: 27 additions & 1 deletion src/dist/beta_binom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,28 @@ pub struct BetaBinomial {
ln_beta_ab: OnceLock<f64>,
}

pub struct BetaBinomialParameters {
pub n: u32,
pub alpha: f64,
pub beta: f64,
}

impl Parameterized for BetaBinomial {
type Parameters = BetaBinomialParameters;

fn emit_params(&self) -> Self::Parameters {
Self::Parameters {
n: self.n(),
alpha: self.alpha(),
beta: self.beta(),
}
}

fn from_params(params: Self::Parameters) -> Self {
Self::new_unchecked(params.n, params.alpha, params.beta)
}
}

impl PartialEq for BetaBinomial {
fn eq(&self, other: &BetaBinomial) -> bool {
self.n == other.n
Expand Down Expand Up @@ -427,7 +449,11 @@ mod tests {

const TOL: f64 = 1E-12;

test_basic_impls!([count] BetaBinomial::new(10, 0.2, 0.7).unwrap());
test_basic_impls!(
u32,
BetaBinomial,
BetaBinomial::new(10, 0.2, 0.7).unwrap()
);

#[test]
fn new() {
Expand Down
Loading

0 comments on commit e91713c

Please sign in to comment.