diff --git a/.gitignore b/.gitignore index 940030b..bf7cf47 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ .vim .helix *.zip +lcov.info diff --git a/CHANGELOG.md b/CHANGELOG.md index b7d4649..f78d0d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,25 @@ # Changelog +## [0.17.0] - 2024-05-15 +- Add `experimental` module with `stick_breaking_process` submodule containing: + - `StickBreaking` struct representing a stick-breaking process + - `StickBreakingDiscrete` struct representing a discrete distribution based on a stick-breaking process + - `StickSequence` struct representing a sequence of stick breaks + - `BreakSequence` struct representing a sequence of break points + - `posterior` method on `StickBreaking` to compute the posterior distribution given data + - Various helper methods and trait implementations +- Update `ConjugatePrior` + - Change `LnMCache` to `MCache` + - Change `LnPpCache` to `PpCache` +- Split `Rv` into `Sampleable` and `HasDensity` +- Add `Process` generalizing `Rv` +- Add `Parameterized` trait +- Add `impl ConjugatePrior for UnitPowerLaw` +- Add `ConvergentSequence` implementing Aitken's delta-squared method +- Add `sorted_uniforms` helper function +- Minor stylistic changes suggested by Clippy + + ## [0.16.5] - 2024-03-14 - Moved repository to GitHub. @@ -172,7 +192,7 @@ - Remove dependency on `quadrature` crate in favor of hand-rolled adaptive Simpson's rule, which handles multimodal distributions better. - +[0.17.0]: https://github.com/promise-ai/rv/compare/v0.16.5...v0.17.0 [0.16.5]: https://github.com/promise-ai/rv/compare/v0.16.4...v0.16.5 [0.16.4]: https://github.com/promise-ai/rv/compare/v0.16.3...v0.16.4 [0.16.3]: https://github.com/promise-ai/rv/compare/v0.16.2...v0.16.3 diff --git a/Cargo.lock b/Cargo.lock index c762c5a..6360cd2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -121,6 +121,21 @@ dependencies = [ "serde", ] +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + [[package]] name = "bitflags" version = "1.3.2" @@ -221,7 +236,7 @@ dependencies = [ "clap", "criterion-plot", "is-terminal", - "itertools", + "itertools 0.10.5", "num-traits", "once_cell", "oorandom", @@ -242,7 +257,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" dependencies = [ "cast", - "itertools", + "itertools 0.10.5", ] [[package]] @@ -333,6 +348,28 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +[[package]] +name = "errno" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "fastrand" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "658bd65b1cf4c852a3cc96f18a8ce7b5640f6b703f905c7d74532294c2a63984" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + [[package]] name = "getrandom" version = "0.2.12" @@ -380,15 +417,15 @@ dependencies = [ [[package]] name = "hermit-abi" -version = "0.3.4" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d3d0e0f38255e7fa3cf31335b3a56f05febd18025f4db5ef7a0cfb4f8da651f" +checksum = "d0c62115964e08cb8039170eb33c1d0e2388a256930279edca206fff675f82c3" [[package]] name = "indexmap" -version = "2.2.1" +version = "2.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "433de089bd45971eecf4668ee0ee8f4cec17db4f8bd8f7bc3197a6ce37aa7d9b" +checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" dependencies = [ "equivalent", "hashbrown 0.14.3", @@ -396,9 +433,9 @@ dependencies = [ [[package]] name = "indoc" -version = "2.0.4" +version = "2.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" +checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" [[package]] name = "instant" @@ -415,7 +452,7 @@ version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b" dependencies = [ - "hermit-abi 0.3.4", + "hermit-abi 0.3.5", "libc", "windows-sys", ] @@ -429,6 +466,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.10" @@ -444,11 +490,17 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + [[package]] name = "libc" -version = "0.2.152" +version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] name = "libm" @@ -467,6 +519,12 @@ dependencies = [ "redox_syscall", ] +[[package]] +name = "linux-raw-sys" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" + [[package]] name = "log" version = "0.4.20" @@ -564,6 +622,12 @@ dependencies = [ "serde", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-integer" version = "0.1.45" @@ -613,7 +677,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi 0.3.4", + "hermit-abi 0.3.5", "libc", ] @@ -723,12 +787,38 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "proptest" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31b476131c3c86cb68032fdc5cb6d5a1045e3e42d96b69fa599fd77701e1f5bf" +dependencies = [ + "bit-set", + "bit-vec", + "bitflags 2.4.2", + "lazy_static", + "num-traits", + "rand", + "rand_chacha", + 
"rand_xorshift", + "regex-syntax", + "rusty-fork", + "tempfile", + "unarray", +] + [[package]] name = "puruspe" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe7765e19fb2ba6fd4373b8d90399f5321683ea7c11b598c6bbaa3a72e9c83b8" +[[package]] +name = "quick-error" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + [[package]] name = "quote" version = "1.0.35" @@ -780,6 +870,15 @@ dependencies = [ "rand", ] +[[package]] +name = "rand_xorshift" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d25bf25ec5ae4a3f1b92f929810509a2f53d7dca2f50b794ff57e3face536c8f" +dependencies = [ + "rand_core", +] + [[package]] name = "rand_xoshiro" version = "0.6.0" @@ -865,15 +964,40 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +[[package]] +name = "rustix" +version = "0.38.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65e04861e65f21776e67888bfbea442b3642beaa0138fdb1dd7a84a52dffdb89" +dependencies = [ + "bitflags 2.4.2", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + [[package]] name = "rustversion" version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" +[[package]] +name = "rusty-fork" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb3dcc6e454c328bb824492db107ab7c0ae8fcffe4ad210136ef014458c1bc4f" +dependencies = [ + "fnv", + "quick-error", + "tempfile", + "wait-timeout", +] + [[package]] name = "rv" -version = "0.16.5" +version = "0.17.0" dependencies = [ "approx", "argmin", @@ -882,11 +1006,13 @@ dependencies = [ "criterion", "doc-comment", "indoc", + "itertools 0.12.1", "lru", "nalgebra", "num", "num-traits", "peroxide", + "proptest", "rand", "rand_distr", "rand_xoshiro", @@ -953,9 +1079,9 @@ dependencies = [ [[package]] name = "serde_yaml" -version = "0.9.31" +version = "0.9.34+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adf8a49373e98a4c5f0ceb5d05aa7c648d75f63774981ed95b7c7443bbd50c6e" +checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" dependencies = [ "indexmap", "itoa", @@ -1057,6 +1183,18 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f764005d11ee5f36500a149ace24e00e3da98b0158b3e2d53a7495660d3f4d60" +[[package]] +name = "tempfile" +version = "3.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +dependencies = [ + "cfg-if", + "fastrand", + "rustix", + "windows-sys", +] + [[package]] name = "term" version = "0.7.0" @@ -1109,13 +1247,14 @@ dependencies = [ [[package]] name = "time" -version = "0.3.31" +version = "0.3.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f657ba42c3f86e7680e53c8cd3af8abbe56b5491790b46e22e19c0d57463583e" +checksum = "c8248b6521bb14bc45b4067159b9b6ad792e2d6d754d6c41fb50e29fefe38749" dependencies = [ "deranged", "itoa", "libc", + "num-conv", "num_threads", "powerfmt", "serde", @@ -1131,10 +1270,11 @@ checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = 
"time-macros" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26197e33420244aeb70c3e8c78376ca46571bc4e701e4791c2cd9f57dcb3a43f" +checksum = "7ba3a3ef41e6672a2f0f001392bb5dcd3ff0a9992d618ca761a11c3121547774" dependencies = [ + "num-conv", "time-core", ] @@ -1154,6 +1294,12 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +[[package]] +name = "unarray" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" + [[package]] name = "unicode-ident" version = "1.0.12" @@ -1162,9 +1308,9 @@ checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unsafe-libyaml" -version = "0.2.10" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab4c90930b95a82d00dc9e9ac071b4991924390d46cbd0dfe566148667605e4b" +checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" [[package]] name = "version_check" @@ -1172,6 +1318,15 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "wait-timeout" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f200f5b12eb75f8c1ed65abd4b2db8a6e1b138a20de009dacee265a2498f3f6" +dependencies = [ + "libc", +] + [[package]] name = "walkdir" version = "2.4.0" diff --git a/Cargo.toml b/Cargo.toml index 16f1b36..45295be 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "rv" -version = "0.16.5" -authors = ["Baxter Eaves", "Michael Schmidt"] +version = "0.17.0" +authors = ["Baxter Eaves", "Michael Schmidt", "Chad Scherrer"] description = "Random variables" repository = "https://github.com/promised-ai/rv" readme = "README.md" @@ -27,13 +27,16 @@ rand = { version = "0.8.5", features = ["small_rng"] } rand_distr = "0.4.3" serde = {version = "1", features = ["derive"], optional = true} special = "0.10" -peroxide = { version = "0.32.1" } num-traits = "0.2.17" +rand_xoshiro = { version = "0.6", optional = true, features=["serde1"]} +itertools = "0.12.1" [dev-dependencies] assert = "0.7" criterion = { version = "0.5", features = ["html_reports"] } indoc = "2" +peroxide = { version = "0.32.1" } +proptest = "1.4.0" serde_yaml = "0.9" serde_json = "1" approx = "0.5" @@ -43,11 +46,14 @@ rand_xoshiro = "0.6" serde1 = ["serde", "nalgebra/serde-serialize"] arraydist = ["nalgebra"] process = ["serde", "nalgebra/serde-serialize", "argmin", "argmin-math", "arraydist"] -datum = [] +experimental = ["rand_xoshiro"] [package.metadata.docs.rs] all-features = true +[profile.test.package.proptest] +opt-level = 3 + # Benchmarks # ========== [[bench]] diff --git a/benches/beta.rs b/benches/beta.rs index 9ffe5a8..b6e5345 100644 --- a/benches/beta.rs +++ b/benches/beta.rs @@ -3,7 +3,7 @@ use std::f64; use criterion::Criterion; use criterion::{criterion_group, criterion_main}; use rand_distr::Beta; -use rv::traits::Rv; +use rv::traits::*; fn draw_rand_distr(rng: &mut R) -> f64 { let beta = Beta::new(5.0, 2.0).unwrap(); diff --git a/benches/categorical.rs b/benches/categorical.rs index 812e911..599d92f 100644 --- a/benches/categorical.rs +++ b/benches/categorical.rs @@ -2,7 +2,7 @@ use criterion::BatchSize; use 
criterion::Criterion; use criterion::{criterion_group, criterion_main}; use rv::dist::Categorical; -use rv::traits::Rv; +use rv::traits::*; fn bench_cat_draw(c: &mut Criterion) { let mut group = c.benchmark_group("Categorical draw compare"); diff --git a/benches/gev.rs b/benches/gev.rs index 2d5c0f7..b2f0006 100644 --- a/benches/gev.rs +++ b/benches/gev.rs @@ -2,7 +2,7 @@ use criterion::BatchSize; use criterion::Criterion; use criterion::{criterion_group, criterion_main}; use rv::dist::Gev; -use rv::traits::Rv; +use rv::traits::*; fn bench_gev_draw_0(c: &mut Criterion) { let gev = Gev::new(0.0, 1.0, 0.0).unwrap(); diff --git a/benches/mixture_entropy.rs b/benches/mixture_entropy.rs index fffdc96..b20fb46 100644 --- a/benches/mixture_entropy.rs +++ b/benches/mixture_entropy.rs @@ -2,7 +2,7 @@ use criterion::BatchSize; use criterion::Criterion; use criterion::{criterion_group, criterion_main}; use rv::dist::{Gaussian, Mixture, NormalGamma, SymmetricDirichlet}; -use rv::traits::{Entropy, Rv}; +use rv::traits::*; fn bench_gmm_entropy(c: &mut Criterion) { let ng = NormalGamma::new_unchecked(0.0, 1.0, 1.0, 1.0); diff --git a/benches/mvg.rs b/benches/mvg.rs index 1a8e52f..7ae9b53 100644 --- a/benches/mvg.rs +++ b/benches/mvg.rs @@ -4,7 +4,7 @@ use criterion::Criterion; use criterion::{criterion_group, criterion_main}; use nalgebra::DVector; use rv::dist::MvGaussian; -use rv::traits::{ContinuousDistr, Rv}; +use rv::traits::*; fn bench_mvg_draw(c: &mut Criterion) { let mut group = c.benchmark_group("MvGaussian, draw 1"); diff --git a/benches/rv.rs b/benches/rv.rs index ac6f496..25227a6 100644 --- a/benches/rv.rs +++ b/benches/rv.rs @@ -58,7 +58,7 @@ macro_rules! benchrv { let fx = $ctor; b.iter(|| { let _count: usize = - <$fxtype as Rv<$xtype>>::sample_stream(&fx, &mut rng) + <$fxtype as Sampleable<$xtype>>::sample_stream(&fx, &mut rng) .take(5) .count(); }) diff --git a/benches/unit_powerlaw.rs b/benches/unit_powerlaw.rs index 47f2fee..5fc33b7 100644 --- a/benches/unit_powerlaw.rs +++ b/benches/unit_powerlaw.rs @@ -2,7 +2,7 @@ use std::f64; use criterion::Criterion; use criterion::{criterion_group, criterion_main}; -use rv::traits::Rv; +use rv::traits::*; fn draw_rv(mut rng: &mut R) -> f64 { let powlaw = rv::dist::UnitPowerLaw::new(5.0).unwrap(); @@ -24,7 +24,6 @@ fn draw_2u_recip(rng: &mut R) -> f64 { fn bench_powlaw_draw(c: &mut Criterion) { let mut group = c.benchmark_group("unit_powerlaw_draw"); - group.bench_function("draw_rv", |b| { let mut rng = rand::thread_rng(); b.iter(|| draw_rv(&mut rng)) diff --git a/benches/wishart.rs b/benches/wishart.rs index 57f6a44..8b46199 100644 --- a/benches/wishart.rs +++ b/benches/wishart.rs @@ -4,7 +4,7 @@ use criterion::Criterion; use criterion::{criterion_group, criterion_main}; use nalgebra::DMatrix; use rv::dist::InvWishart; -use rv::traits::Rv; +use rv::traits::*; fn bench_wishart(c: &mut Criterion) { let mut group = c.benchmark_group("InvWishart"); diff --git a/examples/coin_flips.rs b/examples/coin_flips.rs index 78f9663..497e0fa 100644 --- a/examples/coin_flips.rs +++ b/examples/coin_flips.rs @@ -1,5 +1,4 @@ use rand::Rng; -use rv::data::DataOrSuffStat; use rv::dist::{Bernoulli, Beta}; use rv::prelude::BernoulliData; use rv::traits::*; diff --git a/examples/die_rolls.rs b/examples/die_rolls.rs index be1595c..24925e2 100644 --- a/examples/die_rolls.rs +++ b/examples/die_rolls.rs @@ -1,4 +1,3 @@ -use rv::data::DataOrSuffStat; use rv::dist::{Categorical, SymmetricDirichlet}; use rv::prelude::CategoricalData; use rv::traits::*; diff --git 
a/examples/estimate_pi.rs b/examples/estimate_pi.rs index 25c2cd4..18e5834 100644 --- a/examples/estimate_pi.rs +++ b/examples/estimate_pi.rs @@ -8,7 +8,7 @@ // A_square 4 * r^2 4 # in square // use rv::dist::Uniform; -use rv::traits::Rv; +use rv::traits::*; use std::f64::consts::PI; fn main() { diff --git a/examples/sbd.rs b/examples/sbd.rs new file mode 100644 index 0000000..224018c --- /dev/null +++ b/examples/sbd.rs @@ -0,0 +1,33 @@ +use rand::SeedableRng; +use rv::prelude::*; + +#[cfg(feature = "experimental")] +use rv::experimental::stick_breaking_process::{ + StickBreaking, StickBreakingDiscrete, StickSequence, +}; + +fn main() { + #[cfg(feature = "experimental")] + { + // Instantiate a stick-breaking process + let alpha = 10.0; + let sbp = StickBreaking::new(UnitPowerLaw::new(alpha).unwrap()); + + // Sample from it to get a StickSequence + let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(42); + let sticks: StickSequence = sbp.draw(&mut rng); + + // Use the StickSequence to instantiate a stick-breaking discrete distribution + let sbd = StickBreakingDiscrete::new(sticks.clone()); + + let start = std::time::Instant::now(); + let entropy = sbd.entropy(); + let duration = start.elapsed(); + println!("Entropy: {}", entropy); + println!("Time elapsed in entropy() is: {:?}", duration); + + let num_weights = sbd.stick_sequence().num_weights_unstable(); + + println!("num weights: {}", num_weights); + } +} diff --git a/examples/stickbreaking_posterior.rs b/examples/stickbreaking_posterior.rs new file mode 100644 index 0000000..d72eb43 --- /dev/null +++ b/examples/stickbreaking_posterior.rs @@ -0,0 +1,63 @@ +use itertools::Either; +use peroxide::statistics::stat::Statistics; +use rv::prelude::*; + +#[cfg(feature = "experimental")] +use rv::experimental::stick_breaking_process::*; + +fn main() { + #[cfg(feature = "experimental")] + { + let mut rng = rand::thread_rng(); + let sb = StickBreaking::new(UnitPowerLaw::new(3.0).unwrap()); + + let num_samples = 1_000_000; + + // Our computed posterior + let data = [10]; + let dist = sb.posterior(&DataOrSuffStat::Data(&data[..])); + // let dist = sb.clone(); + + // An approximation using rejection sampling + let mut approx: Vec> = Vec::new(); + while approx.len() < num_samples { + let seq: StickSequence = sb.draw(&mut rng); + let sbd = StickBreakingDiscrete::new(seq.clone()); + if sbd.draw(&mut rng) == 10 { + approx.push(BreakSequence::from(&seq.weights(20)).0); + } + } + + let mut counts: Vec> = vec![]; + for j in 0..20 { + counts.push( + approx + .iter() + .map(|breaks: &Vec| *breaks.get(j).unwrap()) + .collect(), + ) + } + + let break_dists: Vec = dist + .break_dists() + .take(20) + .map(|x| match x { + Either::Left(p) => p.clone(), + Either::Right(p) => Beta::new_unchecked(p.alpha(), 1.0), + }) + .collect(); + + counts.iter().enumerate().for_each(|(n, c)| { + let data_mean: f64 = c.mean(); + let beta_mean: f64 = break_dists[n].mean().unwrap(); + println!( + "n: {}\tmean: {:.3} (pred {:.3})\t var: {:.3} (pred {:.3})", + n, + data_mean, + beta_mean, + c.var(), + break_dists[n].variance().unwrap(), + ); + }); + } +} diff --git a/src/data/datum.rs b/src/data/datum.rs deleted file mode 100644 index 02366f9..0000000 --- a/src/data/datum.rs +++ /dev/null @@ -1,184 +0,0 @@ -#[cfg(feature = "serde1")] -use serde::{Deserialize, Serialize}; - -#[cfg(feature = "arraydist")] -use crate::nalgebra::{DMatrix, DVector}; -use crate::traits::Rv; -use std::convert::TryInto; - -/// Represents any Datum/Value, X, for which Rv may be implemented on a -/// 
`Distribution`. -#[non_exhaustive] -#[derive(Clone, Debug, PartialEq, PartialOrd)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub enum Datum { - F64(f64), - F32(f32), - Bool(bool), - U8(u8), - U16(u16), - U32(u32), - U64(u64), - I8(i8), - I16(i16), - I32(i32), - I64(i64), - ISize(isize), - USize(usize), - Vec(Vec), - #[cfg(feature = "arraydist")] - DVector(DVector), - #[cfg(feature = "arraydist")] - DMatrix(DMatrix), - Compound(Vec), -} - -pub trait RvDatum -where - Self: Rv<::Support>, -{ - type Support: From + Into; -} - -impl Rv for Fx -where - Fx: RvDatum, -{ - fn ln_f(&self, x: &Datum) -> f64 { - let y = ::Support::from(x.clone()); - ::Support>>::ln_f(self, &y) - } - - fn draw(&self, rng: &mut R) -> Datum { - let x = ::Support>>::draw(self, rng); - x.into() - } - - fn sample(&self, n: usize, rng: &mut R) -> Vec { - ::Support>>::sample(self, n, rng) - .drain(..) - .map(|x| x.into()) - .collect() - } - - fn sample_stream<'r, R: rand::Rng>( - &'r self, - rng: &'r mut R, - ) -> Box + 'r> { - let iter = - ::Support>>::sample_stream(self, rng) - .map(|x| x.try_into().unwrap()); - Box::new(iter) - } -} - -macro_rules! impl_rvdatum { - ($dist: ident, $type: ty) => { - #[cfg(feature = "datum")] - impl RvDatum for $crate::dist::$dist { - type Support = $type; - } - }; -} - -macro_rules! convert_datum { - ($self:ty | $primary:ident, $( $variant:ident ),*) => ( - impl From<$self> for Datum { - fn from(x: $self) -> Datum { - Datum::$primary(x) - } - } - - impl From for $self { - fn from(datum: Datum) -> $self { - match datum { - Datum::$primary(x) => x, - $(Datum::$variant(x) => x.into(),)* - Datum::Compound(mut xs) => { - if xs.len() == 1 { - xs.pop().unwrap().into() - } else { - panic!("failed") - } - } - _ => { - panic!("failed") - } - } - } - } - ); - ($self:ty | $primary:ident) => ( - convert_datum!($self | $primary, ); - ) -} - -convert_datum!(f32 | F32); -convert_datum!(f64 | F64, F32); -convert_datum!(bool | Bool); -convert_datum!(u8 | U8); -convert_datum!(u16 | U16, U8); -convert_datum!(u32 | U32, U8, U16); -convert_datum!(u64 | U64, U8, U16, U32); -convert_datum!(usize | USize, U8, U16); -convert_datum!(i8 | I8); -convert_datum!(i16 | I16, I8); -convert_datum!(i32 | I32, I16); -convert_datum!(i64 | I64, I16, I32, I8); -convert_datum!(isize | ISize, I8, I16); -convert_datum!(Vec | Vec); -#[cfg(feature = "arraydist")] -convert_datum!(DVector | DVector); -#[cfg(feature = "arraydist")] -convert_datum!(DMatrix | DMatrix); - -impl From> for Datum { - fn from(xs: Vec) -> Self { - Datum::Compound(xs) - } -} - -impl From for Vec { - fn from(x: Datum) -> Self { - match x { - Datum::Compound(xs) => xs, - _ => panic!("invalid From type for Datum::Compound"), - } - } -} - -impl_rvdatum!(Bernoulli, bool); -impl_rvdatum!(Beta, f64); -impl_rvdatum!(BetaBinomial, u32); -impl_rvdatum!(Binomial, u32); -impl_rvdatum!(Categorical, u32); -impl_rvdatum!(Cauchy, f64); -impl_rvdatum!(ChiSquared, f64); -impl_rvdatum!(Dirichlet, Vec); -// impl_rvdatum!(DiscreteUniform, u32); -impl_rvdatum!(Empirical, f64); -impl_rvdatum!(Exponential, f64); -impl_rvdatum!(Gamma, f64); -impl_rvdatum!(Gaussian, f64); -impl_rvdatum!(Geometric, u32); -impl_rvdatum!(Gev, f64); -impl_rvdatum!(InvChiSquared, f64); -impl_rvdatum!(InvGamma, f64); -impl_rvdatum!(InvGaussian, f64); -impl_rvdatum!(KsTwoAsymptotic, f64); -impl_rvdatum!(Kumaraswamy, f64); -impl_rvdatum!(Laplace, f64); -impl_rvdatum!(LogNormal, f64); -#[cfg(feature = "arraydist")] -impl_rvdatum!(MvGaussian, DVector); -impl_rvdatum!(NegBinomial, u32); 
-impl_rvdatum!(Pareto, f64); -impl_rvdatum!(Poisson, u32); -impl_rvdatum!(ScaledInvChiSquared, f64); -impl_rvdatum!(Skellam, i32); -impl_rvdatum!(StudentsT, f64); -impl_rvdatum!(SymmetricDirichlet, Vec); -impl_rvdatum!(Uniform, f64); -impl_rvdatum!(VonMises, f64); -#[cfg(feature = "arraydist")] -impl_rvdatum!(InvWishart, DMatrix); diff --git a/src/data/mod.rs b/src/data/mod.rs index ed4c68a..80e2af8 100644 --- a/src/data/mod.rs +++ b/src/data/mod.rs @@ -2,12 +2,6 @@ mod partition; mod stat; -#[cfg(feature = "datum")] -mod datum; - -#[cfg(feature = "datum")] -pub use datum::Datum; - pub use partition::Partition; pub use stat::BernoulliSuffStat; pub use stat::BetaSuffStat; @@ -154,8 +148,6 @@ where Data(&'a [X]), /// A sufficient statistic SuffStat(&'a Fx::Stat), - /// No data - None, } impl<'a, X, Fx> DataOrSuffStat<'a, X, Fx> @@ -168,7 +160,6 @@ where match &self { DataOrSuffStat::Data(data) => data.len(), DataOrSuffStat::SuffStat(s) => s.n(), - DataOrSuffStat::None => 0, } } @@ -217,33 +208,6 @@ where pub fn is_suffstat(&self) -> bool { matches!(&self, DataOrSuffStat::SuffStat(..)) } - - /// Determine whether the object is empty - /// - /// # Example - /// - /// ``` - /// # use rv::data::DataOrSuffStat; - /// use rv::dist::Gaussian; - /// use rv::data::GaussianSuffStat; - /// - /// let xs = vec![1.0_f64]; - /// let data: DataOrSuffStat = DataOrSuffStat::Data(&xs); - /// - /// assert!(!data.is_none()); - /// - /// let gauss_stats = GaussianSuffStat::new(); - /// let suffstat: DataOrSuffStat = DataOrSuffStat::SuffStat(&gauss_stats); - /// - /// assert!(!suffstat.is_none()); - /// - /// let none: DataOrSuffStat = DataOrSuffStat::None; - /// - /// assert!(none.is_none()); - /// ``` - pub fn is_none(&self) -> bool { - matches!(&self, DataOrSuffStat::None) - } } /// Convert a `DataOrSuffStat` into a `Stat` @@ -264,7 +228,6 @@ where xs.iter().for_each(|y| stat.observe(y)); stat } - DataOrSuffStat::None => stat_ctor(), } } diff --git a/src/data/stat/categorical.rs b/src/data/stat/categorical.rs index eb4c848..4d24b3f 100644 --- a/src/data/stat/categorical.rs +++ b/src/data/stat/categorical.rs @@ -41,7 +41,7 @@ impl CategoricalSuffStat { /// ``` /// # use rv::data::CategoricalSuffStat; /// # use rv::traits::SuffStat; - /// let mut stat = CategoricalSuffStat::new(3); + /// let mut stat = CategoricalSuffStat::new(2); /// /// stat.observe(&0_u8); /// stat.observe(&1_u8); @@ -64,10 +64,10 @@ impl CategoricalSuffStat { /// let mut stat = CategoricalSuffStat::new(3); /// /// stat.observe(&0_u8); - /// stat.observe(&1_u8); - /// stat.observe(&1_u8); + /// stat.observe(&2_u8); + /// stat.observe(&2_u8); /// - /// assert_eq!(*stat.counts(), vec![1.0, 2.0, 0.0]); + /// assert_eq!(*stat.counts(), vec![1.0, 0.0, 2.0]); /// ``` #[inline] pub fn counts(&self) -> &Vec { diff --git a/src/dist/bernoulli.rs b/src/dist/bernoulli.rs index 0f4f74e..484806b 100644 --- a/src/dist/bernoulli.rs +++ b/src/dist/bernoulli.rs @@ -39,6 +39,18 @@ pub struct Bernoulli { p: f64, } +impl Parameterized for Bernoulli { + type Parameters = f64; + + fn emit_params(&self) -> Self::Parameters { + self.p() + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params) + } +} + #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] @@ -58,7 +70,7 @@ impl Bernoulli { /// /// ```rust /// # use rv::dist::Bernoulli; - /// # use rv::traits::Rv; + /// # use rv::traits::*; /// # let mut rng = rand::thread_rng(); 
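The `Parameterized` impls introduced in the hunks above (`type Parameters`, `emit_params`, `from_params`) expose a distribution's parameters as a plain value. A minimal sketch of how that can be used, assuming `Parameterized` is reachable through `rv::traits` as the `use crate::traits::*` imports in these files suggest; the `roundtrip` helper is hypothetical and not part of the crate:

```rust
use rv::dist::{Bernoulli, Beta};
use rv::traits::*;

// Hypothetical helper: rebuild a distribution from its emitted parameters.
fn roundtrip<D: Parameterized>(dist: &D) -> D {
    D::from_params(dist.emit_params())
}

fn main() {
    // Beta's parameters are a struct with `alpha` and `beta` fields.
    let beta = Beta::new(2.0, 3.5).unwrap();
    let beta2 = roundtrip(&beta);
    assert_eq!(beta2.alpha(), 2.0);
    assert_eq!(beta2.beta(), 3.5);

    // Bernoulli's parameters are just the success probability.
    let bern = Bernoulli::new(0.25).unwrap();
    let p: f64 = bern.emit_params();
    assert_eq!(p, 0.25);
}
```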
/// let b = Bernoulli::new(0.5).unwrap(); /// @@ -146,9 +158,9 @@ impl Bernoulli { /// assert!(b.set_p(1.0).is_ok()); /// assert!(b.set_p(-1.0).is_err()); /// assert!(b.set_p(1.1).is_err()); - /// assert!(b.set_p(std::f64::INFINITY).is_err()); + /// assert!(b.set_p(f64::INFINITY).is_err()); /// assert!(b.set_p(std::f64::NEG_INFINITY).is_err()); - /// assert!(b.set_p(std::f64::NAN).is_err()); + /// assert!(b.set_p(f64::NAN).is_err()); /// ``` #[inline] pub fn set_p(&mut self, p: f64) -> Result<(), BernoulliError> { @@ -200,7 +212,7 @@ impl From<&Bernoulli> for String { impl_display!(Bernoulli); -impl Rv for Bernoulli { +impl HasDensity for Bernoulli { fn f(&self, x: &X) -> f64 { let val: bool = x.into_bool(); if val { @@ -214,7 +226,9 @@ impl Rv for Bernoulli { // TODO: this is really slow, we should cache ln(p) and ln(q) self.f(x).ln() } +} +impl Sampleable for Bernoulli { fn draw(&self, rng: &mut R) -> X { let u = rand_distr::Open01; let x: f64 = rng.sample(u); @@ -365,13 +379,12 @@ mod tests { use super::*; use crate::misc::x2_test; use crate::test_basic_impls; - use std::f64; const TOL: f64 = 1E-12; const N_TRIES: usize = 5; const X2_PVAL: f64 = 0.2; - test_basic_impls!([binary] Bernoulli::default()); + test_basic_impls!(bool, Bernoulli, Bernoulli::default()); #[test] fn new() { @@ -656,7 +669,7 @@ mod tests { } #[test] - fn unifrom_entropy() { + fn uniform_entropy() { let b = Bernoulli::uniform(); assert::close(b.entropy(), f64::consts::LN_2, TOL); } diff --git a/src/dist/beta.rs b/src/dist/beta.rs index 500e6b5..c8ddd88 100644 --- a/src/dist/beta.rs +++ b/src/dist/beta.rs @@ -29,7 +29,7 @@ pub mod bernoulli_prior; /// /// // The posterior predictive probability that a coin will come up heads given /// // no new observations. -/// let p_prior_heads = beta.pp(&true, &DataOrSuffStat::None); // 0.5 +/// let p_prior_heads = beta.pp(&true, &DataOrSuffStat::from(&vec![])); // 0.5 /// assert!((p_prior_heads - 0.5).abs() < 1E-12); /// /// // Five Bernoulli trials. 
We flipped a coin five times and it came up head @@ -52,6 +52,26 @@ pub struct Beta { ln_beta_ab: OnceLock, } +pub struct BetaParameters { + pub alpha: f64, + pub beta: f64, +} + +impl Parameterized for Beta { + type Parameters = BetaParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + alpha: self.alpha(), + beta: self.beta(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.alpha, params.beta) + } +} + impl PartialEq for Beta { fn eq(&self, other: &Beta) -> bool { self.alpha == other.alpha && self.beta == other.beta @@ -190,8 +210,8 @@ impl Beta { /// assert!(beta.set_alpha(0.1).is_ok()); /// assert!(beta.set_alpha(0.0).is_err()); /// assert!(beta.set_alpha(-1.0).is_err()); - /// assert!(beta.set_alpha(std::f64::INFINITY).is_err()); - /// assert!(beta.set_alpha(std::f64::NAN).is_err()); + /// assert!(beta.set_alpha(f64::INFINITY).is_err()); + /// assert!(beta.set_alpha(f64::NAN).is_err()); /// ``` #[inline] pub fn set_alpha(&mut self, alpha: f64) -> Result<(), BetaError> { @@ -246,8 +266,8 @@ impl Beta { /// assert!(beta.set_beta(0.1).is_ok()); /// assert!(beta.set_beta(0.0).is_err()); /// assert!(beta.set_beta(-1.0).is_err()); - /// assert!(beta.set_beta(std::f64::INFINITY).is_err()); - /// assert!(beta.set_beta(std::f64::NAN).is_err()); + /// assert!(beta.set_beta(f64::INFINITY).is_err()); + /// assert!(beta.set_beta(f64::NAN).is_err()); /// ``` #[inline] pub fn set_beta(&mut self, beta: f64) -> Result<(), BetaError> { @@ -293,14 +313,16 @@ impl_display!(Beta); macro_rules! impl_traits { ($kind:ty) => { - impl Rv<$kind> for Beta { + impl HasDensity<$kind> for Beta { fn ln_f(&self, x: &$kind) -> f64 { (self.alpha - 1.0).mul_add( f64::from(*x).ln(), (self.beta - 1.0) * (1.0 - f64::from(*x)).ln(), ) - self.ln_beta_ab() } + } + impl Sampleable<$kind> for Beta { fn draw(&self, rng: &mut R) -> $kind { let b = rand_distr::Beta::new(self.alpha, self.beta).unwrap(); rng.sample(b) as $kind @@ -446,13 +468,12 @@ mod tests { use super::*; use crate::misc::ks_test; use crate::test_basic_impls; - use std::f64; const TOL: f64 = 1E-12; const KS_PVAL: f64 = 0.2; const N_TRIES: usize = 5; - test_basic_impls!([continuous] Beta::jeffreys()); + test_basic_impls!(f64, Beta, Beta::jeffreys()); #[test] fn new() { @@ -576,7 +597,7 @@ mod tests { } #[test] - fn draw_should_resturn_values_within_0_to_1() { + fn draw_should_return_values_within_0_to_1() { let mut rng = rand::thread_rng(); let beta = Beta::jeffreys(); for _ in 0..100 { diff --git a/src/dist/beta/bernoulli_prior.rs b/src/dist/beta/bernoulli_prior.rs index 9a58387..d58a573 100644 --- a/src/dist/beta/bernoulli_prior.rs +++ b/src/dist/beta/bernoulli_prior.rs @@ -1,15 +1,17 @@ use rand::Rng; use special::Beta as SBeta; -use crate::data::{BernoulliSuffStat, Booleable, DataOrSuffStat}; +use crate::data::{BernoulliSuffStat, Booleable}; use crate::dist::{Bernoulli, Beta}; use crate::traits::*; -impl Rv for Beta { +impl HasDensity for Beta { fn ln_f(&self, x: &Bernoulli) -> f64 { self.ln_f(&x.p()) } +} +impl Sampleable for Beta { fn draw(&self, mut rng: &mut R) -> Bernoulli { let p: f64 = self.draw(&mut rng); Bernoulli::new(p).expect("Failed to draw valid weight") @@ -26,8 +28,8 @@ impl ContinuousDistr for Beta {} impl ConjugatePrior for Beta { type Posterior = Self; - type LnMCache = f64; - type LnPpCache = (f64, f64); + type MCache = f64; + type PpCache = (f64, f64); #[allow(clippy::many_single_char_names)] fn posterior(&self, x: &DataOrSuffStat) -> Self { @@ -38,7 +40,6 @@ impl 
ConjugatePrior for Beta { (stat.n(), stat.k()) } DataOrSuffStat::SuffStat(stat) => (stat.n(), stat.k()), - DataOrSuffStat::None => (0, 0), }; let a = self.alpha() + k as f64; @@ -48,13 +49,13 @@ impl ConjugatePrior for Beta { } #[inline] - fn ln_m_cache(&self) -> Self::LnMCache { + fn ln_m_cache(&self) -> Self::MCache { self.alpha().ln_beta(self.beta()) } fn ln_m_with_cache( &self, - cache: &Self::LnMCache, + cache: &Self::MCache, x: &DataOrSuffStat, ) -> f64 { let post = self.posterior(x); @@ -62,14 +63,14 @@ impl ConjugatePrior for Beta { } #[inline] - fn ln_pp_cache(&self, x: &DataOrSuffStat) -> Self::LnPpCache { + fn ln_pp_cache(&self, x: &DataOrSuffStat) -> Self::PpCache { // P(y=1 | xs) happens to be the posterior mean let post = self.posterior(x); let p: f64 = post.mean().expect("Mean undefined"); (p.ln(), (1.0 - p).ln()) } - fn ln_pp_with_cache(&self, cache: &Self::LnPpCache, y: &X) -> f64 { + fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &X) -> f64 { // P(y=1 | xs) happens to be the posterior mean if y.into_bool() { cache.0 @@ -82,9 +83,12 @@ impl ConjugatePrior for Beta { #[cfg(test)] mod tests { use super::*; + use crate::test_conjugate_prior; const TOL: f64 = 1E-12; + test_conjugate_prior!(bool, Bernoulli, Beta, Beta::new(0.5, 1.2).unwrap()); + #[test] fn posterior_from_data_bool() { let data = vec![false, true, false, true, true]; diff --git a/src/dist/beta_binom.rs b/src/dist/beta_binom.rs index 87e5c87..354ae32 100644 --- a/src/dist/beta_binom.rs +++ b/src/dist/beta_binom.rs @@ -66,6 +66,28 @@ pub struct BetaBinomial { ln_beta_ab: OnceLock, } +pub struct BetaBinomialParameters { + pub n: u32, + pub alpha: f64, + pub beta: f64, +} + +impl Parameterized for BetaBinomial { + type Parameters = BetaBinomialParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + n: self.n(), + alpha: self.alpha(), + beta: self.beta(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.n, params.alpha, params.beta) + } +} + impl PartialEq for BetaBinomial { fn eq(&self, other: &BetaBinomial) -> bool { self.n == other.n @@ -192,8 +214,8 @@ impl BetaBinomial { /// assert!(bb.set_alpha(0.1).is_ok()); /// assert!(bb.set_alpha(0.0).is_err()); /// assert!(bb.set_alpha(-1.0).is_err()); - /// assert!(bb.set_alpha(std::f64::INFINITY).is_err()); - /// assert!(bb.set_alpha(std::f64::NAN).is_err()); + /// assert!(bb.set_alpha(f64::INFINITY).is_err()); + /// assert!(bb.set_alpha(f64::NAN).is_err()); /// ``` #[inline] pub fn set_alpha(&mut self, alpha: f64) -> Result<(), BetaBinomialError> { @@ -248,8 +270,8 @@ impl BetaBinomial { /// assert!(bb.set_beta(0.1).is_ok()); /// assert!(bb.set_beta(0.0).is_err()); /// assert!(bb.set_beta(-1.0).is_err()); - /// assert!(bb.set_beta(std::f64::INFINITY).is_err()); - /// assert!(bb.set_beta(std::f64::NAN).is_err()); + /// assert!(bb.set_beta(f64::INFINITY).is_err()); + /// assert!(bb.set_beta(f64::NAN).is_err()); /// ``` #[inline] pub fn set_beta(&mut self, beta: f64) -> Result<(), BetaBinomialError> { @@ -320,7 +342,7 @@ impl_display!(BetaBinomial); macro_rules! impl_int_traits { ($kind:ty) => { - impl Rv<$kind> for BetaBinomial { + impl HasDensity<$kind> for BetaBinomial { fn ln_f(&self, k: &$kind) -> f64 { let nf = f64::from(self.n); let kf = *k as f64; @@ -328,7 +350,9 @@ macro_rules! 
impl_int_traits { + (kf + self.alpha).ln_beta(nf - kf + self.beta) - self.ln_beta_ab() } + } + impl Sampleable<$kind> for BetaBinomial { fn draw(&self, mut rng: &mut R) -> $kind { self.sample(1, &mut rng)[0] } @@ -422,11 +446,14 @@ impl fmt::Display for BetaBinomialError { mod tests { use super::*; use crate::test_basic_impls; - use std::f64; const TOL: f64 = 1E-12; - test_basic_impls!([count] BetaBinomial::new(10, 0.2, 0.7).unwrap()); + test_basic_impls!( + u32, + BetaBinomial, + BetaBinomial::new(10, 0.2, 0.7).unwrap() + ); #[test] fn new() { diff --git a/src/dist/binomial.rs b/src/dist/binomial.rs index 938945a..1289619 100644 --- a/src/dist/binomial.rs +++ b/src/dist/binomial.rs @@ -51,6 +51,26 @@ pub struct Binomial { p: f64, } +pub struct BinomialParameters { + pub n: u64, + pub p: f64, +} + +impl Parameterized for Binomial { + type Parameters = BinomialParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + n: self.n(), + p: self.p(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.n, params.p) + } +} + #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] @@ -195,9 +215,9 @@ impl Binomial { /// assert!(binom.set_p(1.0).is_ok()); /// assert!(binom.set_p(-1.0).is_err()); /// assert!(binom.set_p(1.1).is_err()); - /// assert!(binom.set_p(std::f64::INFINITY).is_err()); - /// assert!(binom.set_p(std::f64::NEG_INFINITY).is_err()); - /// assert!(binom.set_p(std::f64::NAN).is_err()); + /// assert!(binom.set_p(f64::INFINITY).is_err()); + /// assert!(binom.set_p(f64::NEG_INFINITY).is_err()); + /// assert!(binom.set_p(f64::NAN).is_err()); /// ``` #[inline] pub fn set_p(&mut self, p: f64) -> Result<(), BinomialError> { @@ -244,7 +264,7 @@ impl_display!(Binomial); macro_rules! impl_int_traits { ($kind:ty) => { - impl Rv<$kind> for Binomial { + impl HasDensity<$kind> for Binomial { fn ln_f(&self, k: &$kind) -> f64 { let nf = self.n as f64; let kf = *k as f64; @@ -254,6 +274,9 @@ macro_rules! 
impl_int_traits { self.p.ln().mul_add(kf, ln_binom(nf, kf)), ) } + } + + impl Sampleable<$kind> for Binomial { fn draw(&self, rng: &mut R) -> $kind { let b = rand_distr::Binomial::new(self.n, self.p).unwrap(); rng.sample(b) as $kind @@ -337,13 +360,12 @@ mod tests { use super::*; use crate::misc::x2_test; use crate::test_basic_impls; - use std::f64; const TOL: f64 = 1E-12; const N_TRIES: usize = 5; const X2_PVAL: f64 = 0.2; - test_basic_impls!([count] Binomial::uniform(10)); + test_basic_impls!(u32, Binomial, Binomial::uniform(10)); #[test] fn new() { diff --git a/src/dist/categorical.rs b/src/dist/categorical.rs index eea5532..9dc2a41 100644 --- a/src/dist/categorical.rs +++ b/src/dist/categorical.rs @@ -19,6 +19,24 @@ pub struct Categorical { ln_weights: Vec, } +pub struct CategoricalParameters { + pub ln_weights: Vec, +} + +impl Parameterized for Categorical { + type Parameters = CategoricalParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + ln_weights: self.ln_weights().clone(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.ln_weights) + } +} + #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] @@ -118,7 +136,7 @@ impl Categorical { .enumerate() .try_for_each(|(ix, &weight)| { // Manually check for -Inf - if weight.is_finite() || weight == std::f64::NEG_INFINITY { + if weight.is_finite() || weight == f64::NEG_INFINITY { Ok(()) } else { // Catch Inf and NaN @@ -188,12 +206,14 @@ impl From<&Categorical> for String { impl_display!(Categorical); -impl Rv for Categorical { +impl HasDensity for Categorical { fn ln_f(&self, x: &X) -> f64 { let ix: usize = x.into_usize(); self.ln_weights[ix] } +} +impl Sampleable for Categorical { fn draw(&self, mut rng: &mut R) -> X { let ix = ln_pflip(&self.ln_weights, 1, true, &mut rng)[0]; CategoricalDatum::from_usize(ix) @@ -302,17 +322,16 @@ mod tests { use crate::misc::x2_test; use crate::test_basic_impls; use std::f64::consts::LN_2; - use std::f64::NEG_INFINITY; const TOL: f64 = 1E-12; const N_TRIES: usize = 5; const X2_PVAL: f64 = 0.2; - test_basic_impls!([categorical] Categorical::uniform(3)); + test_basic_impls!(u8, Categorical, Categorical::uniform(3)); #[test] fn from_ln_weights_with_zero_weight_should_work() { - let ln_weights: Vec = vec![-LN_2, NEG_INFINITY, -LN_2]; + let ln_weights: Vec = vec![-LN_2, f64::NEG_INFINITY, -LN_2]; let res = Categorical::from_ln_weights(ln_weights); assert!(res.is_ok()); } diff --git a/src/dist/cauchy.rs b/src/dist/cauchy.rs index 7a597d6..d0ad0e8 100644 --- a/src/dist/cauchy.rs +++ b/src/dist/cauchy.rs @@ -33,6 +33,26 @@ pub struct Cauchy { scale: f64, } +pub struct CauchyParameters { + pub loc: f64, + pub scale: f64, +} + +impl Parameterized for Cauchy { + type Parameters = CauchyParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + loc: self.loc(), + scale: self.scale(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.loc, params.scale) + } +} + #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] @@ -101,9 +121,9 @@ impl Cauchy { /// # use rv::dist::Cauchy; /// # let mut c = Cauchy::new(0.1, 1.0).unwrap(); /// assert!(c.set_loc(2.0).is_ok()); - /// assert!(c.set_loc(std::f64::INFINITY).is_err()); - /// 
assert!(c.set_loc(std::f64::NEG_INFINITY).is_err()); - /// assert!(c.set_loc(std::f64::NAN).is_err()); + /// assert!(c.set_loc(f64::INFINITY).is_err()); + /// assert!(c.set_loc(f64::NEG_INFINITY).is_err()); + /// assert!(c.set_loc(f64::NAN).is_err()); /// ``` #[inline] pub fn set_loc(&mut self, loc: f64) -> Result<(), CauchyError> { @@ -154,8 +174,8 @@ impl Cauchy { /// # let mut c = Cauchy::new(0.1, 1.0).unwrap(); /// assert!(c.set_scale(0.0).is_err()); /// assert!(c.set_scale(-1.0).is_err()); - /// assert!(c.set_scale(std::f64::NAN).is_err()); - /// assert!(c.set_scale(std::f64::INFINITY).is_err()); + /// assert!(c.set_scale(f64::NAN).is_err()); + /// assert!(c.set_scale(f64::INFINITY).is_err()); /// ``` #[inline] pub fn set_scale(&mut self, scale: f64) -> Result<(), CauchyError> { @@ -192,7 +212,7 @@ impl_display!(Cauchy); macro_rules! impl_traits { ($kind:ty) => { - impl Rv<$kind> for Cauchy { + impl HasDensity<$kind> for Cauchy { fn ln_f(&self, x: &$kind) -> f64 { let ln_scale = self.scale.ln(); let term = 2.0_f64.mul_add( @@ -202,7 +222,9 @@ macro_rules! impl_traits { // TODO: make a logaddexp method for two floats -logsumexp(&[ln_scale, term]) - LN_PI } + } + impl Sampleable<$kind> for Cauchy { fn draw(&self, rng: &mut R) -> $kind { let cauchy = RCauchy::new(self.loc, self.scale).unwrap(); rng.sample(cauchy) as $kind @@ -289,7 +311,7 @@ mod tests { const KS_PVAL: f64 = 0.2; const N_TRIES: usize = 5; - test_basic_impls!([continuous] Cauchy::default()); + test_basic_impls!(f64, Cauchy); #[test] fn ln_pdf_loc_zero() { diff --git a/src/dist/chi_squared.rs b/src/dist/chi_squared.rs index bf186ab..970f714 100644 --- a/src/dist/chi_squared.rs +++ b/src/dist/chi_squared.rs @@ -1,4 +1,4 @@ -//! Χ2 over x in (0, ∞) +//! Χ2 over x in (0, ∞) #[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; @@ -28,6 +28,18 @@ pub struct ChiSquared { k: f64, } +impl Parameterized for ChiSquared { + type Parameters = f64; + + fn emit_params(&self) -> Self::Parameters { + self.k() + } + + fn from_params(k: Self::Parameters) -> Self { + Self::new_unchecked(k) + } +} + #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] @@ -95,8 +107,8 @@ impl ChiSquared { /// assert!(x2.set_k(2.2).is_ok()); /// assert!(x2.set_k(0.0).is_err()); /// assert!(x2.set_k(-1.0).is_err()); - /// assert!(x2.set_k(std::f64::NAN).is_err()); - /// assert!(x2.set_k(std::f64::INFINITY).is_err()); + /// assert!(x2.set_k(f64::NAN).is_err()); + /// assert!(x2.set_k(f64::INFINITY).is_err()); /// ``` #[inline] pub fn set_k(&mut self, k: f64) -> Result<(), ChiSquaredError> { @@ -126,7 +138,7 @@ impl_display!(ChiSquared); macro_rules! impl_traits { ($kind:ty) => { - impl Rv<$kind> for ChiSquared { + impl HasDensity<$kind> for ChiSquared { fn ln_f(&self, x: &$kind) -> f64 { let k2 = self.k / 2.0; let xf = f64::from(*x); @@ -134,7 +146,9 @@ macro_rules! 
impl_traits { k2.mul_add(-LN_2, (k2 - 1.0).mul_add(xf.ln(), -xf / 2.0)) - ln_gammafn(k2) } + } + impl Sampleable<$kind> for ChiSquared { fn draw(&self, rng: &mut R) -> $kind { let x2 = rand_distr::ChiSquared::new(self.k).unwrap(); rng.sample(x2) as $kind @@ -219,7 +233,7 @@ mod tests { const KS_PVAL: f64 = 0.2; const N_TRIES: usize = 5; - test_basic_impls!([continuous] ChiSquared::new(3.2).unwrap()); + test_basic_impls!(f64, ChiSquared, ChiSquared::new(3.2).unwrap()); #[test] fn new() { diff --git a/src/dist/crp.rs b/src/dist/crp.rs index 656d528..ff12a47 100644 --- a/src/dist/crp.rs +++ b/src/dist/crp.rs @@ -113,9 +113,9 @@ impl Crp { /// assert!(crp.set_alpha(0.5).is_ok()); /// assert!(crp.set_alpha(0.0).is_err()); /// assert!(crp.set_alpha(-1.0).is_err()); - /// assert!(crp.set_alpha(std::f64::INFINITY).is_err()); - /// assert!(crp.set_alpha(std::f64::NEG_INFINITY).is_err()); - /// assert!(crp.set_alpha(std::f64::NAN).is_err()); + /// assert!(crp.set_alpha(f64::INFINITY).is_err()); + /// assert!(crp.set_alpha(f64::NEG_INFINITY).is_err()); + /// assert!(crp.set_alpha(f64::NAN).is_err()); /// ``` #[inline] pub fn set_alpha(&mut self, alpha: f64) -> Result<(), CrpError> { @@ -195,7 +195,7 @@ impl From<&Crp> for String { impl_display!(Crp); -impl Rv for Crp { +impl HasDensity for Crp { fn ln_f(&self, x: &Partition) -> f64 { let gsum = x .counts() @@ -206,7 +206,9 @@ impl Rv for Crp { (x.k() as f64).mul_add(self.alpha.ln(), gsum) + ln_gammafn(self.alpha) - ln_gammafn(x.len() as f64 + self.alpha) } +} +impl Sampleable for Crp { fn draw(&self, rng: &mut R) -> Partition { let mut k = 1; let mut weights: Vec = vec![1.0]; @@ -261,14 +263,14 @@ impl fmt::Display for CrpError { #[cfg(test)] mod tests { use super::*; - use crate::test_basic_impls; + // use crate::test_basic_impls; const TOL: f64 = 1E-12; - test_basic_impls!( - Crp::new(1.0, 10).unwrap(), - Partition::new_unchecked(vec![0; 10], vec![10]) - ); + // test_basic_impls!( + // Crp::new(1.0, 10).unwrap(), + // Partition::new_unchecked(vec![0; 10], vec![10]) + // ); #[test] fn new() { diff --git a/src/dist/dirichlet.rs b/src/dist/dirichlet.rs index 1c84fb8..b47d761 100644 --- a/src/dist/dirichlet.rs +++ b/src/dist/dirichlet.rs @@ -30,6 +30,26 @@ pub struct SymmetricDirichlet { ln_gamma_alpha: OnceLock, } +pub struct SymmetricDirichletParameters { + pub alpha: f64, + pub k: usize, +} + +impl Parameterized for SymmetricDirichlet { + type Parameters = SymmetricDirichletParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + alpha: self.alpha(), + k: self.k(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.alpha, params.k) + } +} + impl PartialEq for SymmetricDirichlet { fn eq(&self, other: &Self) -> bool { self.alpha == other.alpha && self.k == other.k @@ -71,7 +91,7 @@ impl SymmetricDirichlet { } } - /// Create a new SymmetricDirichlet without checking whether the parmaeters + /// Create a new SymmetricDirichlet without checking whether the parameters /// are valid. 
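These hunks show the pattern used throughout this diff for the `Rv` split: `HasDensity` keeps `f`/`ln_f`, while `Sampleable` keeps `draw`/`sample`/`sample_stream`. A minimal sketch of generic code written against the new bounds; `draw_and_score` is illustrative only, not an `rv` function:

```rust
use rand::Rng;
use rv::dist::Gaussian;
use rv::traits::*;

// Generic code now states exactly the capability it needs. This helper needs
// both halves of the old `Rv`: it draws a value and then scores it.
fn draw_and_score<X, Fx, R>(fx: &Fx, rng: &mut R) -> (X, f64)
where
    Fx: Sampleable<X> + HasDensity<X>,
    R: Rng,
{
    let x: X = fx.draw(rng);
    let ln_fx = fx.ln_f(&x);
    (x, ln_fx)
}

fn main() {
    let mut rng = rand::thread_rng();
    let g = Gaussian::new(0.0, 1.0).unwrap();
    let (x, ln_fx): (f64, f64) = draw_and_score(&g, &mut rng);
    println!("x = {x}, ln f(x) = {ln_fx}");
}
```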
#[inline] pub fn new_unchecked(alpha: f64, k: usize) -> Self { @@ -138,9 +158,9 @@ impl SymmetricDirichlet { /// assert!(symdir.set_alpha(0.5).is_ok()); /// assert!(symdir.set_alpha(0.0).is_err()); /// assert!(symdir.set_alpha(-1.0).is_err()); - /// assert!(symdir.set_alpha(std::f64::INFINITY).is_err()); - /// assert!(symdir.set_alpha(std::f64::NEG_INFINITY).is_err()); - /// assert!(symdir.set_alpha(std::f64::NAN).is_err()); + /// assert!(symdir.set_alpha(f64::INFINITY).is_err()); + /// assert!(symdir.set_alpha(f64::NEG_INFINITY).is_err()); + /// assert!(symdir.set_alpha(f64::NAN).is_err()); /// ``` #[inline] pub fn set_alpha( @@ -193,7 +213,7 @@ impl From<&SymmetricDirichlet> for String { impl_display!(SymmetricDirichlet); -impl Rv> for SymmetricDirichlet { +impl Sampleable> for SymmetricDirichlet { fn draw(&self, rng: &mut R) -> Vec { let g = RGamma::new(self.alpha, 1.0).unwrap(); let mut xs: Vec = (0..self.k).map(|_| rng.sample(g)).collect(); @@ -201,7 +221,9 @@ impl Rv> for SymmetricDirichlet { xs.iter_mut().for_each(|x| *x /= z); xs } +} +impl HasDensity> for SymmetricDirichlet { fn ln_f(&self, x: &Vec) -> f64 { let kf = self.k as f64; let sum_ln_gamma = self.ln_gamma_alpha() * kf; @@ -238,6 +260,24 @@ pub struct Dirichlet { pub(crate) alphas: Vec, } +pub struct DirichletParameters { + pub alphas: Vec, +} + +impl Parameterized for Dirichlet { + type Parameters = DirichletParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + alphas: self.alphas().clone(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.alphas) + } +} + impl From for Dirichlet { fn from(symdir: SymmetricDirichlet) -> Self { Dirichlet::new_unchecked(vec![symdir.alpha; symdir.k]) @@ -288,7 +328,7 @@ impl Dirichlet { /// /// ``` /// # use rv::dist::{Dirichlet, SymmetricDirichlet}; - /// # use rv::traits::Rv; + /// # use rv::traits::*; /// let dir = Dirichlet::symmetric(1.0, 4).unwrap(); /// assert_eq!(*dir.alphas(), vec![1.0, 1.0, 1.0, 1.0]); /// @@ -324,7 +364,7 @@ impl Dirichlet { /// ``` /// # use rv::dist::Dirichlet; /// # use rv::dist::SymmetricDirichlet; - /// # use rv::traits::Rv; + /// # use rv::traits::*; /// let dir = Dirichlet::jeffreys(3).unwrap(); /// assert_eq!(*dir.alphas(), vec![0.5, 0.5, 0.5]); /// @@ -376,7 +416,7 @@ impl Support> for SymmetricDirichlet { } } -impl Rv> for Dirichlet { +impl Sampleable> for Dirichlet { fn draw(&self, rng: &mut R) -> Vec { let gammas: Vec> = self .alphas @@ -388,7 +428,9 @@ impl Rv> for Dirichlet { xs.iter_mut().for_each(|x| *x /= z); xs } +} +impl HasDensity> for Dirichlet { fn ln_f(&self, x: &Vec) -> f64 { // XXX: could cache all ln_gamma(alpha) let sum_ln_gamma: f64 = self @@ -464,7 +506,7 @@ mod tests { mod dir { use super::*; - test_basic_impls!(Dirichlet::jeffreys(4).unwrap(), vec![0.25_f64; 4]); + test_basic_impls!(Vec, Dirichlet, Dirichlet::jeffreys(4).unwrap()); #[test] fn properly_sized_points_on_simplex_should_be_in_support() { @@ -491,7 +533,7 @@ mod tests { fn draws_should_be_in_support() { let mut rng = rand::thread_rng(); // Small alphas gives us more variability in the simplex, and more - // variability gives us a beter test. + // variability gives us a better test. 
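The conjugate-prior impls in this diff all switch from `LnMCache`/`LnPpCache` to the shorter `MCache`/`PpCache` names. A small sketch against the Beta–Bernoulli pair from `bernoulli_prior.rs` above; the expected posterior follows the `alpha + k`, `beta + (n - k)` update shown there:

```rust
use rv::data::DataOrSuffStat;
use rv::dist::{Bernoulli, Beta};
use rv::traits::*;

fn main() {
    // Beta(1, 1) prior over the Bernoulli success probability.
    let prior = Beta::new(1.0, 1.0).unwrap();
    let flips: Vec<bool> = vec![true, true, false, true];
    let xs: DataOrSuffStat<bool, Bernoulli> = DataOrSuffStat::Data(&flips);

    // Posterior after 3 successes and 1 failure: Beta(1 + 3, 1 + 1).
    let post = prior.posterior(&xs);
    assert_eq!(post.alpha(), 4.0);
    assert_eq!(post.beta(), 2.0);

    // The renamed cache types back the `*_with_cache` methods.
    let cache = <Beta as ConjugatePrior<bool, Bernoulli>>::ln_m_cache(&prior);
    let ln_m = prior.ln_m_with_cache(&cache, &xs);
    assert!((ln_m - prior.ln_m(&xs)).abs() < 1e-12);

    // Posterior predictive probability of another `true`.
    println!("p(true | data) = {}", prior.pp(&true, &xs));
}
```

With the `DataOrSuffStat::None` variant removed in this release, an old "no data" call can pass an empty data slice or a fresh suff stat instead; the posterior then reduces to the prior.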
let dir = Dirichlet::jeffreys(10).unwrap(); for _ in 0..100 { let x = dir.draw(&mut rng); @@ -544,8 +586,9 @@ mod tests { use super::*; test_basic_impls!( - SymmetricDirichlet::jeffreys(4).unwrap(), - vec![0.25_f64; 4] + Vec, + SymmetricDirichlet, + SymmetricDirichlet::jeffreys(4).unwrap() ); #[test] @@ -577,7 +620,7 @@ mod tests { fn draws_should_be_in_support() { let mut rng = rand::thread_rng(); // Small alphas gives us more variability in the simplex, and more - // variability gives us a beter test. + // variability gives us a better test. let symdir = SymmetricDirichlet::jeffreys(10).unwrap(); for _ in 0..100 { let x: Vec = symdir.draw(&mut rng); diff --git a/src/dist/dirichlet/categorical_prior.rs b/src/dist/dirichlet/categorical_prior.rs index 1d87f1a..380d29a 100644 --- a/src/dist/dirichlet/categorical_prior.rs +++ b/src/dist/dirichlet/categorical_prior.rs @@ -6,11 +6,13 @@ use crate::misc::ln_gammafn; use crate::prelude::CategoricalData; use crate::traits::*; -impl Rv for SymmetricDirichlet { +impl HasDensity for SymmetricDirichlet { fn ln_f(&self, x: &Categorical) -> f64 { self.ln_f(&x.weights()) } +} +impl Sampleable for SymmetricDirichlet { fn draw(&self, mut rng: &mut R) -> Categorical { let weights: Vec = self.draw(&mut rng); Categorical::new(&weights).expect("Invalid draw") @@ -21,8 +23,8 @@ impl ConjugatePrior for SymmetricDirichlet { type Posterior = Dirichlet; - type LnMCache = f64; - type LnPpCache = (Vec, f64); + type MCache = f64; + type PpCache = (Vec, f64); fn posterior(&self, x: &CategoricalData) -> Self::Posterior { extract_stat_then( @@ -38,7 +40,7 @@ impl ConjugatePrior } #[inline] - fn ln_m_cache(&self) -> Self::LnMCache { + fn ln_m_cache(&self) -> Self::MCache { let sum_alpha = self.alpha() * self.k() as f64; let a = ln_gammafn(sum_alpha); let d = ln_gammafn(self.alpha()) * self.k() as f64; @@ -47,7 +49,7 @@ impl ConjugatePrior fn ln_m_with_cache( &self, - cache: &Self::LnMCache, + cache: &Self::MCache, x: &CategoricalData, ) -> f64 { let sum_alpha = self.alpha() * self.k() as f64; @@ -69,23 +71,25 @@ impl ConjugatePrior } #[inline] - fn ln_pp_cache(&self, x: &CategoricalData) -> Self::LnPpCache { + fn ln_pp_cache(&self, x: &CategoricalData) -> Self::PpCache { let post = self.posterior(x); let norm = post.alphas().iter().fold(0.0, |acc, &a| acc + a); (post.alphas, norm.ln()) } - fn ln_pp_with_cache(&self, cache: &Self::LnPpCache, y: &X) -> f64 { + fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &X) -> f64 { let ix = y.into_usize(); cache.0[ix].ln() - cache.1 } } -impl Rv for Dirichlet { +impl HasDensity for Dirichlet { fn ln_f(&self, x: &Categorical) -> f64 { self.ln_f(&x.weights()) } +} +impl Sampleable for Dirichlet { fn draw(&self, mut rng: &mut R) -> Categorical { let weights: Vec = self.draw(&mut rng); Categorical::new(&weights).expect("Invalid draw") @@ -94,8 +98,8 @@ impl Rv for Dirichlet { impl ConjugatePrior for Dirichlet { type Posterior = Self; - type LnMCache = (f64, f64); - type LnPpCache = (Vec, f64); + type MCache = (f64, f64); + type PpCache = (Vec, f64); fn posterior(&self, x: &CategoricalData) -> Self::Posterior { extract_stat_then( @@ -115,7 +119,7 @@ impl ConjugatePrior for Dirichlet { } #[inline] - fn ln_m_cache(&self) -> Self::LnMCache { + fn ln_m_cache(&self) -> Self::MCache { let sum_alpha = self.alphas().iter().fold(0.0, |acc, &a| acc + a); let a = ln_gammafn(sum_alpha); let d = self @@ -127,7 +131,7 @@ impl ConjugatePrior for Dirichlet { fn ln_m_with_cache( &self, - cache: &Self::LnMCache, + cache: &Self::MCache, x: 
&CategoricalData, ) -> f64 { let (sum_alpha, ln_norm) = cache; @@ -141,7 +145,8 @@ impl ConjugatePrior for Dirichlet { .alphas() .iter() .zip(stat.counts().iter()) - .fold(0.0, |acc, (&a, &ct)| ln_gammafn(acc + (a + ct))); + .map(|(&a, &ct)| ln_gammafn(a + ct)) + .sum::(); -b + c + ln_norm }, @@ -149,13 +154,13 @@ impl ConjugatePrior for Dirichlet { } #[inline] - fn ln_pp_cache(&self, x: &CategoricalData) -> Self::LnPpCache { + fn ln_pp_cache(&self, x: &CategoricalData) -> Self::PpCache { let post = self.posterior(x); let norm = post.alphas().iter().fold(0.0, |acc, &a| acc + a); (post.alphas, norm.ln()) } - fn ln_pp_with_cache(&self, cache: &Self::LnPpCache, y: &X) -> f64 { + fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &X) -> f64 { let ix = y.into_usize(); cache.0[ix].ln() - cache.1 } @@ -165,14 +170,35 @@ impl ConjugatePrior for Dirichlet { mod test { use super::*; use crate::data::DataOrSuffStat; + use crate::test_conjugate_prior; const TOL: f64 = 1E-12; type CategoricalData<'a, X> = DataOrSuffStat<'a, X, Categorical>; + mod dir { + use super::*; + + test_conjugate_prior!( + u8, + Categorical, + Dirichlet, + Dirichlet::new(vec![1.0, 2.0]).unwrap(), + n = 1_000_000 + ); + } + mod symmetric { use super::*; + test_conjugate_prior!( + u8, + Categorical, + SymmetricDirichlet, + SymmetricDirichlet::jeffreys(2).unwrap(), + n = 1_000_000 + ); + #[test] fn marginal_likelihood_u8_1() { let alpha = 1.0; diff --git a/src/dist/discrete_uniform.rs b/src/dist/discrete_uniform.rs index 63a7c8a..3bc1c11 100644 --- a/src/dist/discrete_uniform.rs +++ b/src/dist/discrete_uniform.rs @@ -30,6 +30,26 @@ pub enum DiscreteUniformError { InvalidInterval, } +pub struct DiscreteUniformParameters { + pub a: T, + pub b: T, +} + +impl Parameterized for DiscreteUniform { + type Parameters = DiscreteUniformParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + a: self.a(), + b: self.b(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.a, params.b) + } +} + impl DiscreteUniform { /// Create a new discreet uniform distribution /// @@ -99,7 +119,7 @@ where } } -impl Rv for DiscreteUniform +impl HasDensity for DiscreteUniform where T: DuParam + SampleUniform + Copy, X: Integer + From, @@ -111,7 +131,13 @@ where f64::NEG_INFINITY } } +} +impl Sampleable for DiscreteUniform +where + T: DuParam + SampleUniform + Copy, + X: Integer + From, +{ fn draw(&self, rng: &mut R) -> X { let d = rand::distributions::Uniform::new_inclusive(self.a, self.b); X::from(rng.sample(d)) @@ -247,7 +273,11 @@ mod tests { const KS_PVAL: f64 = 0.2; const N_TRIES: usize = 5; - test_basic_impls!([count] DiscreteUniform::new(0_u32, 10_u32).unwrap()); + test_basic_impls!( + u32, + DiscreteUniform, + DiscreteUniform::new(0_u32, 10_u32).unwrap() + ); #[test] fn new() { diff --git a/src/dist/distribution.rs b/src/dist/distribution.rs deleted file mode 100644 index 77d8ff4..0000000 --- a/src/dist/distribution.rs +++ /dev/null @@ -1,379 +0,0 @@ -#[cfg(feature = "serde1")] -use serde::{Deserialize, Serialize}; - -use crate::data::Datum; -use crate::traits::Rv; - -/// Represents any distribution -#[non_exhaustive] -#[derive(Clone, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub enum Distribution { - Bernoulli(super::Bernoulli), - Beta(super::Beta), - BetaBinomial(super::BetaBinomial), - Binomial(super::Binomial), - Categorical(super::Categorical), - Cauchy(super::Cauchy), - ChiSquared(super::ChiSquared), - Dirichlet(super::Dirichlet), - 
SymmetricDirichlet(super::SymmetricDirichlet), - Exponential(super::Exponential), - Gamma(super::Gamma), - Gaussian(super::Gaussian), - Geometric(super::Geometric), - Gev(super::Gev), - InvChiSquared(super::InvChiSquared), - InvGamma(super::InvGamma), - InvGaussian(super::InvGaussian), - KsTwoAsymptotic(super::KsTwoAsymptotic), - Kumaraswamy(super::Kumaraswamy), - Laplace(super::Laplace), - LogNormal(super::LogNormal), - #[cfg(feature = "arraydist")] - MvGaussian(super::MvGaussian), - NegBinomial(super::NegBinomial), - Pareto(super::Pareto), - Poisson(super::Poisson), - Product(super::ProductDistribution), - ScaledInvChiSquared(super::ScaledInvChiSquared), - Skellam(super::Skellam), - StudentsT(super::StudentsT), - Uniform(super::Uniform), - VonMises(super::VonMises), - #[cfg(feature = "arraydist")] - InvWishart(super::InvWishart), -} - -impl Rv for Distribution { - fn f(&self, x: &Datum) -> f64 { - match self { - Distribution::Bernoulli(inner) => inner.f(x), - Distribution::Beta(inner) => inner.f(x), - Distribution::BetaBinomial(inner) => inner.f(x), - Distribution::Binomial(inner) => inner.f(x), - Distribution::Categorical(inner) => inner.f(x), - Distribution::Cauchy(inner) => inner.f(x), - Distribution::ChiSquared(inner) => inner.f(x), - Distribution::Dirichlet(inner) => inner.f(x), - Distribution::SymmetricDirichlet(inner) => inner.f(x), - Distribution::Exponential(inner) => inner.f(x), - Distribution::Gamma(inner) => inner.f(x), - Distribution::Gaussian(inner) => inner.f(x), - Distribution::Geometric(inner) => inner.f(x), - Distribution::Gev(inner) => inner.f(x), - Distribution::InvChiSquared(inner) => inner.f(x), - Distribution::InvGamma(inner) => inner.f(x), - Distribution::InvGaussian(inner) => inner.f(x), - Distribution::KsTwoAsymptotic(inner) => inner.f(x), - Distribution::Kumaraswamy(inner) => inner.f(x), - Distribution::Laplace(inner) => inner.f(x), - Distribution::LogNormal(inner) => inner.f(x), - #[cfg(feature = "arraydist")] - Distribution::MvGaussian(inner) => inner.f(x), - Distribution::NegBinomial(inner) => inner.f(x), - Distribution::Pareto(inner) => inner.f(x), - Distribution::Poisson(inner) => inner.f(x), - Distribution::Product(inner) => inner.f(x), - Distribution::ScaledInvChiSquared(inner) => inner.f(x), - Distribution::Skellam(inner) => inner.f(x), - Distribution::StudentsT(inner) => inner.f(x), - Distribution::Uniform(inner) => inner.f(x), - Distribution::VonMises(inner) => inner.f(x), - #[cfg(feature = "arraydist")] - Distribution::InvWishart(inner) => inner.f(x), - } - } - - fn ln_f(&self, x: &Datum) -> f64 { - match self { - Distribution::Bernoulli(inner) => inner.ln_f(x), - Distribution::Beta(inner) => inner.ln_f(x), - Distribution::BetaBinomial(inner) => inner.ln_f(x), - Distribution::Binomial(inner) => inner.ln_f(x), - Distribution::Categorical(inner) => inner.ln_f(x), - Distribution::Cauchy(inner) => inner.ln_f(x), - Distribution::ChiSquared(inner) => inner.ln_f(x), - Distribution::Dirichlet(inner) => inner.ln_f(x), - Distribution::SymmetricDirichlet(inner) => inner.ln_f(x), - Distribution::Exponential(inner) => inner.ln_f(x), - Distribution::Gamma(inner) => inner.ln_f(x), - Distribution::Gaussian(inner) => inner.ln_f(x), - Distribution::Geometric(inner) => inner.ln_f(x), - Distribution::Gev(inner) => inner.ln_f(x), - Distribution::InvChiSquared(inner) => inner.ln_f(x), - Distribution::InvGamma(inner) => inner.ln_f(x), - Distribution::InvGaussian(inner) => inner.ln_f(x), - Distribution::KsTwoAsymptotic(inner) => inner.ln_f(x), - 
Distribution::Kumaraswamy(inner) => inner.ln_f(x), - Distribution::Laplace(inner) => inner.ln_f(x), - Distribution::LogNormal(inner) => inner.ln_f(x), - #[cfg(feature = "arraydist")] - Distribution::MvGaussian(inner) => inner.ln_f(x), - Distribution::NegBinomial(inner) => inner.ln_f(x), - Distribution::Pareto(inner) => inner.ln_f(x), - Distribution::Poisson(inner) => inner.ln_f(x), - Distribution::Product(inner) => inner.ln_f(x), - Distribution::ScaledInvChiSquared(inner) => inner.ln_f(x), - Distribution::Skellam(inner) => inner.ln_f(x), - Distribution::StudentsT(inner) => inner.ln_f(x), - Distribution::Uniform(inner) => inner.ln_f(x), - Distribution::VonMises(inner) => inner.ln_f(x), - #[cfg(feature = "arraydist")] - Distribution::InvWishart(inner) => inner.ln_f(x), - } - } - - fn draw(&self, rng: &mut R) -> Datum { - match self { - Distribution::Bernoulli(inner) => inner.draw(rng), - Distribution::Beta(inner) => inner.draw(rng), - Distribution::BetaBinomial(inner) => inner.draw(rng), - Distribution::Binomial(inner) => inner.draw(rng), - Distribution::Categorical(inner) => inner.draw(rng), - Distribution::Cauchy(inner) => inner.draw(rng), - Distribution::ChiSquared(inner) => inner.draw(rng), - Distribution::Dirichlet(inner) => inner.draw(rng), - Distribution::SymmetricDirichlet(inner) => inner.draw(rng), - Distribution::Exponential(inner) => inner.draw(rng), - Distribution::Gamma(inner) => inner.draw(rng), - Distribution::Gaussian(inner) => inner.draw(rng), - Distribution::Geometric(inner) => inner.draw(rng), - Distribution::Gev(inner) => inner.draw(rng), - Distribution::InvChiSquared(inner) => inner.draw(rng), - Distribution::InvGamma(inner) => inner.draw(rng), - Distribution::InvGaussian(inner) => inner.draw(rng), - Distribution::KsTwoAsymptotic(inner) => inner.draw(rng), - Distribution::Kumaraswamy(inner) => inner.draw(rng), - Distribution::Laplace(inner) => inner.draw(rng), - Distribution::LogNormal(inner) => inner.draw(rng), - #[cfg(feature = "arraydist")] - Distribution::MvGaussian(inner) => inner.draw(rng), - Distribution::NegBinomial(inner) => inner.draw(rng), - Distribution::Pareto(inner) => inner.draw(rng), - Distribution::Poisson(inner) => inner.draw(rng), - Distribution::Product(inner) => inner.draw(rng), - Distribution::ScaledInvChiSquared(inner) => inner.draw(rng), - Distribution::Skellam(inner) => inner.draw(rng), - Distribution::StudentsT(inner) => inner.draw(rng), - Distribution::Uniform(inner) => inner.draw(rng), - Distribution::VonMises(inner) => inner.draw(rng), - #[cfg(feature = "arraydist")] - Distribution::InvWishart(inner) => inner.draw(rng), - } - } - - fn sample(&self, n: usize, rng: &mut R) -> Vec { - match self { - Distribution::Bernoulli(inner) => inner.sample(n, rng), - Distribution::Beta(inner) => inner.sample(n, rng), - Distribution::BetaBinomial(inner) => inner.sample(n, rng), - Distribution::Binomial(inner) => inner.sample(n, rng), - Distribution::Categorical(inner) => inner.sample(n, rng), - Distribution::Cauchy(inner) => inner.sample(n, rng), - Distribution::ChiSquared(inner) => inner.sample(n, rng), - Distribution::Dirichlet(inner) => inner.sample(n, rng), - Distribution::SymmetricDirichlet(inner) => inner.sample(n, rng), - Distribution::Exponential(inner) => inner.sample(n, rng), - Distribution::Gamma(inner) => inner.sample(n, rng), - Distribution::Gaussian(inner) => inner.sample(n, rng), - Distribution::Geometric(inner) => inner.sample(n, rng), - Distribution::Gev(inner) => inner.sample(n, rng), - Distribution::InvChiSquared(inner) => inner.sample(n, 
rng), - Distribution::InvGamma(inner) => inner.sample(n, rng), - Distribution::InvGaussian(inner) => inner.sample(n, rng), - Distribution::KsTwoAsymptotic(inner) => inner.sample(n, rng), - Distribution::Kumaraswamy(inner) => inner.sample(n, rng), - Distribution::Laplace(inner) => inner.sample(n, rng), - Distribution::LogNormal(inner) => inner.sample(n, rng), - #[cfg(feature = "arraydist")] - Distribution::MvGaussian(inner) => inner.sample(n, rng), - Distribution::NegBinomial(inner) => inner.sample(n, rng), - Distribution::Pareto(inner) => inner.sample(n, rng), - Distribution::Poisson(inner) => inner.sample(n, rng), - Distribution::Product(inner) => inner.sample(n, rng), - Distribution::ScaledInvChiSquared(inner) => inner.sample(n, rng), - Distribution::Skellam(inner) => inner.sample(n, rng), - Distribution::StudentsT(inner) => inner.sample(n, rng), - Distribution::Uniform(inner) => inner.sample(n, rng), - Distribution::VonMises(inner) => inner.sample(n, rng), - #[cfg(feature = "arraydist")] - Distribution::InvWishart(inner) => inner.sample(n, rng), - } - } - - fn sample_stream<'r, R: rand::Rng>( - &'r self, - rng: &'r mut R, - ) -> Box + 'r> { - match self { - Distribution::Bernoulli(inner) => inner.sample_stream(rng), - Distribution::Beta(inner) => inner.sample_stream(rng), - Distribution::BetaBinomial(inner) => inner.sample_stream(rng), - Distribution::Binomial(inner) => inner.sample_stream(rng), - Distribution::Categorical(inner) => inner.sample_stream(rng), - Distribution::Cauchy(inner) => inner.sample_stream(rng), - Distribution::ChiSquared(inner) => inner.sample_stream(rng), - Distribution::Dirichlet(inner) => inner.sample_stream(rng), - Distribution::SymmetricDirichlet(inner) => inner.sample_stream(rng), - Distribution::Exponential(inner) => inner.sample_stream(rng), - Distribution::Gamma(inner) => inner.sample_stream(rng), - Distribution::Gaussian(inner) => inner.sample_stream(rng), - Distribution::Geometric(inner) => inner.sample_stream(rng), - Distribution::Gev(inner) => inner.sample_stream(rng), - Distribution::InvChiSquared(inner) => inner.sample_stream(rng), - Distribution::InvGamma(inner) => inner.sample_stream(rng), - Distribution::InvGaussian(inner) => inner.sample_stream(rng), - Distribution::KsTwoAsymptotic(inner) => inner.sample_stream(rng), - Distribution::Kumaraswamy(inner) => inner.sample_stream(rng), - Distribution::Laplace(inner) => inner.sample_stream(rng), - Distribution::LogNormal(inner) => inner.sample_stream(rng), - #[cfg(feature = "arraydist")] - Distribution::MvGaussian(inner) => inner.sample_stream(rng), - Distribution::NegBinomial(inner) => inner.sample_stream(rng), - Distribution::Pareto(inner) => inner.sample_stream(rng), - Distribution::Poisson(inner) => inner.sample_stream(rng), - Distribution::Product(inner) => inner.sample_stream(rng), - Distribution::ScaledInvChiSquared(inner) => { - inner.sample_stream(rng) - } - Distribution::Skellam(inner) => inner.sample_stream(rng), - Distribution::StudentsT(inner) => inner.sample_stream(rng), - Distribution::Uniform(inner) => inner.sample_stream(rng), - Distribution::VonMises(inner) => inner.sample_stream(rng), - #[cfg(feature = "arraydist")] - Distribution::InvWishart(inner) => inner.sample_stream(rng), - } - } -} - -impl Rv for super::Mixture> { - fn ln_f(&self, x: &Datum) -> f64 { - if let Datum::Compound(xs) = x { - assert_eq!(xs.len(), self.components()[0].len()); - let ln_fs: Vec = self - .weights() - .iter() - .zip(self.components().iter()) - .map(|(&w, cpnts)| { - w.ln() - + xs.iter() - 
.zip(cpnts.iter()) - .map(|(x, cpnt)| cpnt.ln_f(x)) - .sum::() - }) - .collect(); - crate::misc::logsumexp(&ln_fs) - } else { - panic!("Mixture of Vec accepts Datum::Compound") - } - } - - fn sample(&self, n: usize, rng: &mut R) -> Vec { - let cpnt_ixs = crate::misc::pflip(self.weights(), n, rng); - cpnt_ixs - .iter() - .map(|&ix| { - let data = self.components()[ix] - .iter() - .map(|cpnt| cpnt.draw(rng)) - .collect(); - Datum::Compound(data) - }) - .collect() - } - - fn draw(&self, rng: &mut R) -> Datum { - self.sample(1, rng).pop().unwrap() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::test_basic_impls; - - test_basic_impls!( - Distribution::Bernoulli(crate::dist::Bernoulli::uniform()), - Datum::Bool(true), - bernoulli - ); - - test_basic_impls!( - Distribution::Beta(crate::dist::Beta::jeffreys()), - Datum::F64(0.5), - beta - ); - - test_basic_impls!( - Distribution::BetaBinomial( - crate::dist::BetaBinomial::new(10, 0.5, 1.2).unwrap() - ), - Datum::U32(3), - beta_binom - ); - - test_basic_impls!( - Distribution::Binomial(crate::dist::Binomial::new(10, 0.5).unwrap()), - Datum::U32(3), - binom - ); - - test_basic_impls!( - Distribution::Categorical(crate::dist::Categorical::uniform(4)), - Datum::U8(2), - categorical - ); - - test_basic_impls!( - Distribution::Cauchy(crate::dist::Cauchy::new(0.5, 1.0).unwrap()), - Datum::F64(2.0), - cauchy - ); - - test_basic_impls!( - Distribution::ChiSquared(crate::dist::ChiSquared::new(0.5).unwrap()), - Datum::F64(2.0), - chi_squared - ); - - test_basic_impls!( - Distribution::Dirichlet( - crate::dist::Dirichlet::new(vec![5.0, 2.0, 0.5]).unwrap() - ), - Datum::Vec(vec![0.2, 0.1, 0.7]), - dirichlet - ); - - test_basic_impls!( - Distribution::SymmetricDirichlet( - crate::dist::SymmetricDirichlet::new(0.5, 3).unwrap() - ), - Datum::Vec(vec![0.2, 0.1, 0.7]), - symmetric_dirichlet - ); - - test_basic_impls!( - Distribution::Exponential(crate::dist::Exponential::new(0.5).unwrap()), - Datum::F64(2.0), - exponential - ); - - test_basic_impls!( - Distribution::Gamma(crate::dist::Gamma::new(0.5, 1.0).unwrap()), - Datum::F64(2.0), - gamma - ); - - test_basic_impls!( - Distribution::Gaussian(crate::dist::Gaussian::standard()), - Datum::F64(0.5), - gaussian - ); - - test_basic_impls!( - Distribution::Geometric(crate::dist::Geometric::new(0.5).unwrap()), - Datum::U16(2), - geometric - ); -} diff --git a/src/dist/empirical.rs b/src/dist/empirical.rs index 87ae89f..54d8d8d 100644 --- a/src/dist/empirical.rs +++ b/src/dist/empirical.rs @@ -1,7 +1,7 @@ #[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; -use crate::traits::{Cdf, Mean, Rv, Variance}; +use crate::traits::*; use rand::Rng; /// An empirical distribution derived from samples. 
@@ -11,7 +11,7 @@ use rand::Rng; /// /// ```rust /// use rv::dist::{Gaussian, Empirical}; -/// use rv::prelude::Rv; +/// use rv::prelude::*; /// use rv::misc::linspace; /// use rand_xoshiro::Xoshiro256Plus; /// use rand::SeedableRng; @@ -44,6 +44,24 @@ enum Pos { Absent(usize), } +pub struct EmpiricalParameters { + pub xs: Vec, +} + +impl Parameterized for Empirical { + type Parameters = EmpiricalParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + xs: self.xs.clone(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new(params.xs) + } +} + impl Empirical { /// Create a new Empirical distribution with the given observed values pub fn new(mut xs: Vec) -> Self { @@ -121,7 +139,7 @@ impl Empirical { } } -impl Rv for Empirical { +impl HasDensity for Empirical { fn f(&self, x: &f64) -> f64 { eprintln!("WARNING: empirical.f is unstable. You probably don't want to use it."); match self.pos(*x) { @@ -148,7 +166,9 @@ impl Rv for Empirical { fn ln_f(&self, x: &f64) -> f64 { self.f(x).ln() } +} +impl Sampleable for Empirical { fn draw(&self, rng: &mut R) -> f64 { let n = self.xs.len(); let ix: usize = rng.gen_range(0..n); diff --git a/src/dist/exponential.rs b/src/dist/exponential.rs index 5971024..dcf4932 100644 --- a/src/dist/exponential.rs +++ b/src/dist/exponential.rs @@ -31,6 +31,24 @@ pub struct Exponential { rate: f64, } +impl Default for Exponential { + fn default() -> Self { + Self::new_unchecked(1.0) + } +} + +impl Parameterized for Exponential { + type Parameters = f64; + + fn emit_params(&self) -> Self::Parameters { + self.rate() + } + + fn from_params(rate: Self::Parameters) -> Self { + Self::new_unchecked(rate) + } +} + #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] @@ -99,9 +117,9 @@ impl Exponential { /// assert!(expon.set_rate(0.1).is_ok()); /// assert!(expon.set_rate(0.0).is_err()); /// assert!(expon.set_rate(-1.0).is_err()); - /// assert!(expon.set_rate(std::f64::INFINITY).is_err()); - /// assert!(expon.set_rate(std::f64::NEG_INFINITY).is_err()); - /// assert!(expon.set_rate(std::f64::NAN).is_err()); + /// assert!(expon.set_rate(f64::INFINITY).is_err()); + /// assert!(expon.set_rate(f64::NEG_INFINITY).is_err()); + /// assert!(expon.set_rate(f64::NAN).is_err()); /// ``` #[inline] pub fn set_rate(&mut self, rate: f64) -> Result<(), ExponentialError> { @@ -132,7 +150,7 @@ impl_display!(Exponential); macro_rules! impl_traits { ($kind:ty) => { - impl Rv<$kind> for Exponential { + impl HasDensity<$kind> for Exponential { fn ln_f(&self, x: &$kind) -> f64 { // TODO: could cache ln(rate) if x < &0.0 { @@ -141,7 +159,9 @@ macro_rules! 
impl_traits { self.rate.mul_add(-f64::from(*x), self.rate.ln()) } } + } + impl Sampleable<$kind> for Exponential { fn draw(&self, rng: &mut R) -> $kind { let expdist = Exp::new(self.rate).unwrap(); rng.sample(expdist) as $kind @@ -248,13 +268,12 @@ mod tests { use super::*; use crate::misc::ks_test; use crate::test_basic_impls; - use std::f64; const TOL: f64 = 1E-12; const KS_PVAL: f64 = 0.2; const N_TRIES: usize = 5; - test_basic_impls!([continuous] Exponential::new(1.0).unwrap()); + test_basic_impls!(f64, Exponential); #[test] fn new() { diff --git a/src/dist/gamma.rs b/src/dist/gamma.rs index f2dce43..07b791c 100644 --- a/src/dist/gamma.rs +++ b/src/dist/gamma.rs @@ -37,6 +37,26 @@ pub struct Gamma { ln_rate: OnceLock, } +pub struct GammaParameters { + pub shape: f64, + pub rate: f64, +} + +impl Parameterized for Gamma { + type Parameters = GammaParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + shape: self.shape(), + rate: self.rate(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.shape, params.rate) + } +} + impl PartialEq for Gamma { fn eq(&self, other: &Gamma) -> bool { self.shape == other.shape && self.rate == other.rate @@ -131,9 +151,9 @@ impl Gamma { /// assert!(gam.set_shape(1.1).is_ok()); /// assert!(gam.set_shape(0.0).is_err()); /// assert!(gam.set_shape(-1.0).is_err()); - /// assert!(gam.set_shape(std::f64::INFINITY).is_err()); - /// assert!(gam.set_shape(std::f64::NEG_INFINITY).is_err()); - /// assert!(gam.set_shape(std::f64::NAN).is_err()); + /// assert!(gam.set_shape(f64::INFINITY).is_err()); + /// assert!(gam.set_shape(f64::NEG_INFINITY).is_err()); + /// assert!(gam.set_shape(f64::NAN).is_err()); /// ``` #[inline] pub fn set_shape(&mut self, shape: f64) -> Result<(), GammaError> { @@ -189,9 +209,9 @@ impl Gamma { /// assert!(gam.set_rate(1.1).is_ok()); /// assert!(gam.set_rate(0.0).is_err()); /// assert!(gam.set_rate(-1.0).is_err()); - /// assert!(gam.set_rate(std::f64::INFINITY).is_err()); - /// assert!(gam.set_rate(std::f64::NEG_INFINITY).is_err()); - /// assert!(gam.set_rate(std::f64::NAN).is_err()); + /// assert!(gam.set_rate(f64::INFINITY).is_err()); + /// assert!(gam.set_rate(f64::NEG_INFINITY).is_err()); + /// assert!(gam.set_rate(f64::NAN).is_err()); /// ``` #[inline] pub fn set_rate(&mut self, rate: f64) -> Result<(), GammaError> { @@ -229,7 +249,7 @@ impl_display!(Gamma); macro_rules! impl_traits { ($kind:ty) => { - impl Rv<$kind> for Gamma { + impl HasDensity<$kind> for Gamma { fn ln_f(&self, x: &$kind) -> f64 { self.shape.mul_add(self.ln_rate(), -self.ln_gamma_shape()) + (self.shape - 1.0).mul_add( @@ -237,7 +257,9 @@ macro_rules! 
impl_traits { -(self.rate * f64::from(*x)), ) } + } + impl Sampleable<$kind> for Gamma { fn draw(&self, rng: &mut R) -> $kind { let g = rand_distr::Gamma::new(self.shape, 1.0 / self.rate) .unwrap(); @@ -345,7 +367,7 @@ mod tests { const KS_PVAL: f64 = 0.2; const N_TRIES: usize = 5; - test_basic_impls!([continuous] Gamma::default()); + test_basic_impls!(f64, Gamma, Gamma::new_unchecked(1.0, 2.0)); #[test] fn new() { @@ -379,7 +401,7 @@ mod tests { } #[test] - fn ln_pdf_hight_value() { + fn ln_pdf_high_value() { let gam = Gamma::new(1.2, 3.4).unwrap(); assert::close( gam.ln_pdf(&0.352_941_176_470_588_26_f64), diff --git a/src/dist/gamma/poisson_prior.rs b/src/dist/gamma/poisson_prior.rs index 17440c5..fffa9f8 100644 --- a/src/dist/gamma/poisson_prior.rs +++ b/src/dist/gamma/poisson_prior.rs @@ -1,27 +1,27 @@ -use std::f64::EPSILON; - use rand::Rng; -use crate::data::{DataOrSuffStat, PoissonSuffStat}; +use crate::data::PoissonSuffStat; use crate::dist::poisson::PoissonError; use crate::dist::{Gamma, Poisson}; use crate::misc::ln_binom; use crate::traits::*; -impl Rv for Gamma { +impl HasDensity for Gamma { fn ln_f(&self, x: &Poisson) -> f64 { match x.mean() { Some(mean) => self.ln_f(&mean), - None => std::f64::NEG_INFINITY, + None => f64::NEG_INFINITY, } } +} +impl Sampleable for Gamma { fn draw(&self, mut rng: &mut R) -> Poisson { let mean: f64 = self.draw(&mut rng); match Poisson::new(mean) { Ok(pois) => pois, Err(PoissonError::RateTooLow { .. }) => { - Poisson::new_unchecked(EPSILON) + Poisson::new_unchecked(f64::EPSILON) } Err(err) => panic!("Failed to draw Possion: {}", err), } @@ -43,8 +43,8 @@ macro_rules! impl_traits { ($kind: ty) => { impl ConjugatePrior<$kind, Poisson> for Gamma { type Posterior = Self; - type LnMCache = f64; - type LnPpCache = (f64, f64, f64); + type MCache = f64; + type PpCache = (f64, f64, f64); fn posterior(&self, x: &DataOrSuffStat<$kind, Poisson>) -> Self { let (n, sum) = match x { @@ -56,7 +56,6 @@ macro_rules! impl_traits { DataOrSuffStat::SuffStat(ref stat) => { (stat.n(), stat.sum()) } - DataOrSuffStat::None => (0, 0.0), }; let a = self.shape() + sum; @@ -65,7 +64,7 @@ macro_rules! impl_traits { } #[inline] - fn ln_m_cache(&self) -> Self::LnMCache { + fn ln_m_cache(&self) -> Self::MCache { let z0 = self .shape() .mul_add(-self.ln_rate(), self.ln_gamma_shape()); @@ -74,7 +73,7 @@ macro_rules! impl_traits { fn ln_m_with_cache( &self, - cache: &Self::LnMCache, + cache: &Self::MCache, x: &DataOrSuffStat<$kind, Poisson>, ) -> f64 { let stat: PoissonSuffStat = match x { @@ -84,7 +83,6 @@ macro_rules! impl_traits { stat } DataOrSuffStat::SuffStat(ref stat) => (*stat).clone(), - DataOrSuffStat::None => PoissonSuffStat::new(), }; let data_or_suff: DataOrSuffStat<$kind, Poisson> = @@ -102,7 +100,7 @@ macro_rules! impl_traits { fn ln_pp_cache( &self, x: &DataOrSuffStat<$kind, Poisson>, - ) -> Self::LnPpCache { + ) -> Self::PpCache { let post = self.posterior(x); let r = post.shape(); let p = 1.0 / (1.0 + post.rate()); @@ -111,7 +109,7 @@ macro_rules! 
impl_traits { fn ln_pp_with_cache( &self, - cache: &Self::LnPpCache, + cache: &Self::PpCache, y: &$kind, ) -> f64 { let (r, p, ln_p) = cache; @@ -130,8 +128,12 @@ impl_traits!(u32); #[cfg(test)] mod tests { use super::*; + use crate::test_conjugate_prior; + const TOL: f64 = 1E-12; + test_conjugate_prior!(u32, Poisson, Gamma, Gamma::new(2.0, 1.2).unwrap()); + #[test] fn posterior_from_data() { let data: Vec = vec![1, 2, 3, 4, 5]; @@ -145,7 +147,8 @@ mod tests { #[test] fn ln_m_no_data() { let dist = Gamma::new(1.0, 1.0).unwrap(); - let data: DataOrSuffStat = DataOrSuffStat::None; + let new_vec = Vec::new(); + let data: DataOrSuffStat = DataOrSuffStat::from(&new_vec); assert::close(dist.ln_m(&data), 0.0, TOL); } @@ -195,7 +198,7 @@ mod tests { for i in 0..inputs.len() { assert::close( - dist.ln_pp(&inputs[i], &DataOrSuffStat::None), + dist.ln_pp(&inputs[i], &DataOrSuffStat::from(&vec![])), expected[i], TOL, ) @@ -229,7 +232,8 @@ mod tests { fn cannot_draw_zero_rate() { let mut rng = rand::thread_rng(); let dist = Gamma::new(1.0, 1e-10).unwrap(); - let stream = >::sample_stream(&dist, &mut rng); + let stream = + >::sample_stream(&dist, &mut rng); assert!(stream.take(10_000).all(|pois| pois.rate() > 0.0)); } } diff --git a/src/dist/gaussian.rs b/src/dist/gaussian.rs index c8586df..7e53d51 100644 --- a/src/dist/gaussian.rs +++ b/src/dist/gaussian.rs @@ -55,6 +55,26 @@ impl PartialEq for Gaussian { } } +pub struct GaussianParameters { + pub mu: f64, + pub sigma: f64, +} + +impl Parameterized for Gaussian { + type Parameters = GaussianParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + mu: self.mu(), + sigma: self.sigma(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.mu, params.sigma) + } +} + #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] @@ -70,7 +90,7 @@ pub enum GaussianError { impl Gaussian { /// Create a new Gaussian distribution /// - /// # Aruments + /// # Arguments /// - mu: mean /// - sigma: standard deviation pub fn new(mu: f64, sigma: f64) -> Result { @@ -153,9 +173,9 @@ impl Gaussian { /// # use rv::dist::Gaussian; /// # let mut gauss = Gaussian::new(2.0, 1.5).unwrap(); /// assert!(gauss.set_mu(1.3).is_ok()); - /// assert!(gauss.set_mu(std::f64::NEG_INFINITY).is_err()); - /// assert!(gauss.set_mu(std::f64::INFINITY).is_err()); - /// assert!(gauss.set_mu(std::f64::NAN).is_err()); + /// assert!(gauss.set_mu(f64::NEG_INFINITY).is_err()); + /// assert!(gauss.set_mu(f64::INFINITY).is_err()); + /// assert!(gauss.set_mu(f64::NAN).is_err()); /// ``` #[inline] pub fn set_mu(&mut self, mu: f64) -> Result<(), GaussianError> { @@ -209,9 +229,9 @@ impl Gaussian { /// assert!(gauss.set_sigma(2.3).is_ok()); /// assert!(gauss.set_sigma(0.0).is_err()); /// assert!(gauss.set_sigma(-1.0).is_err()); - /// assert!(gauss.set_sigma(std::f64::INFINITY).is_err()); - /// assert!(gauss.set_sigma(std::f64::NEG_INFINITY).is_err()); - /// assert!(gauss.set_sigma(std::f64::NAN).is_err()); + /// assert!(gauss.set_sigma(f64::INFINITY).is_err()); + /// assert!(gauss.set_sigma(f64::NEG_INFINITY).is_err()); + /// assert!(gauss.set_sigma(f64::NAN).is_err()); /// ``` #[inline] pub fn set_sigma(&mut self, sigma: f64) -> Result<(), GaussianError> { @@ -255,12 +275,14 @@ impl_display!(Gaussian); macro_rules! 
impl_traits { ($kind:ty) => { - impl Rv<$kind> for Gaussian { + impl HasDensity<$kind> for Gaussian { fn ln_f(&self, x: &$kind) -> f64 { let k = (f64::from(*x) - self.mu) / self.sigma; (0.5 * k).mul_add(-k, -self.ln_sigma()) - HALF_LN_2PI } + } + impl Sampleable<$kind> for Gaussian { fn draw(&self, rng: &mut R) -> $kind { let g = Normal::new(self.mu, self.sigma).unwrap(); rng.sample(g) as $kind @@ -411,7 +433,7 @@ mod tests { const TOL: f64 = 1E-12; - test_basic_impls!([continuous] Gaussian::standard()); + test_basic_impls!(f64, Gaussian); #[test] fn new() { @@ -618,7 +640,7 @@ mod tests { } #[test] - fn kl_of_idential_dsitrbutions_should_be_zero() { + fn kl_of_identical_dsitrbutions_should_be_zero() { let gauss = Gaussian::new(1.2, 3.4).unwrap(); assert::close(gauss.kl(&gauss), 0.0, TOL); } diff --git a/src/dist/geometric.rs b/src/dist/geometric.rs index 37ea7b1..cb57857 100644 --- a/src/dist/geometric.rs +++ b/src/dist/geometric.rs @@ -1,4 +1,4 @@ -//! Possion distribution on unisgned integers +//! Possion distribution on unsigned integers #[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; @@ -39,6 +39,18 @@ pub struct Geometric { ln_1mp: OnceLock, } +impl Parameterized for Geometric { + type Parameters = f64; + + fn emit_params(&self) -> Self::Parameters { + self.p() + } + + fn from_params(p: Self::Parameters) -> Self { + Self::new_unchecked(p) + } +} + #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] @@ -115,9 +127,9 @@ impl Geometric { /// assert!(geom.set_p(0.0).is_err()); /// assert!(geom.set_p(-1.0).is_err()); /// assert!(geom.set_p(1.1).is_err()); - /// assert!(geom.set_p(std::f64::INFINITY).is_err()); - /// assert!(geom.set_p(std::f64::NEG_INFINITY).is_err()); - /// assert!(geom.set_p(std::f64::NAN).is_err()); + /// assert!(geom.set_p(f64::INFINITY).is_err()); + /// assert!(geom.set_p(f64::NEG_INFINITY).is_err()); + /// assert!(geom.set_p(f64::NAN).is_err()); /// ``` #[inline] pub fn set_p(&mut self, p: f64) -> Result<(), GeometricError> { @@ -204,7 +216,7 @@ impl From<&Geometric> for String { impl_display!(Geometric); -impl Rv for Geometric +impl HasDensity for Geometric where X: Unsigned + Integer + FromPrimitive + ToPrimitive + Saturating + Bounded, { @@ -213,7 +225,12 @@ where kf.mul_add(self.ln_1mp(), self.ln_p()) // kf.mul_add((1.0 - self.p).ln(), self.p.ln()) } +} +impl Sampleable for Geometric +where + X: Unsigned + Integer + FromPrimitive + ToPrimitive + Saturating + Bounded, +{ fn draw(&self, rng: &mut R) -> X { // Follows the same pattern as // https://github.com/numpy/numpy/blob/7c41164f5340dc998ea1c04d2061f7d246894955/numpy/random/mtrand/distributions.c#L777 @@ -309,7 +326,7 @@ mod tests { const N_TRIES: usize = 5; const X2_PVAL: f64 = 0.2; - test_basic_impls!([count] Geometric::default()); + test_basic_impls!(u32, Geometric); #[test] fn new() { diff --git a/src/dist/gev.rs b/src/dist/gev.rs index 7f84a67..e589a3f 100644 --- a/src/dist/gev.rs +++ b/src/dist/gev.rs @@ -32,6 +32,28 @@ pub struct Gev { shape: f64, } +pub struct GevParameters { + pub loc: f64, + pub scale: f64, + pub shape: f64, +} + +impl Parameterized for Gev { + type Parameters = GevParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + loc: self.loc(), + scale: self.scale(), + shape: self.shape(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.loc, params.scale, params.shape) + } +} + #[derive(Debug, 
Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] @@ -102,9 +124,9 @@ impl Gev { /// # use rv::dist::Gev; /// # let mut gev = Gev::new(1.2, 2.3, 3.4).unwrap(); /// assert!(gev.set_loc(2.8).is_ok()); - /// assert!(gev.set_loc(std::f64::INFINITY).is_err()); - /// assert!(gev.set_loc(std::f64::NEG_INFINITY).is_err()); - /// assert!(gev.set_loc(std::f64::NAN).is_err()); + /// assert!(gev.set_loc(f64::INFINITY).is_err()); + /// assert!(gev.set_loc(f64::NEG_INFINITY).is_err()); + /// assert!(gev.set_loc(f64::NAN).is_err()); /// ``` #[inline] pub fn set_loc(&mut self, loc: f64) -> Result<(), GevError> { @@ -156,9 +178,9 @@ impl Gev { /// # use rv::dist::Gev; /// # let mut gev = Gev::new(1.2, 2.3, 3.4).unwrap(); /// assert!(gev.set_shape(2.8).is_ok()); - /// assert!(gev.set_shape(std::f64::INFINITY).is_err()); - /// assert!(gev.set_shape(std::f64::NEG_INFINITY).is_err()); - /// assert!(gev.set_shape(std::f64::NAN).is_err()); + /// assert!(gev.set_shape(f64::INFINITY).is_err()); + /// assert!(gev.set_shape(f64::NEG_INFINITY).is_err()); + /// assert!(gev.set_shape(f64::NAN).is_err()); /// ``` #[inline] pub fn set_shape(&mut self, shape: f64) -> Result<(), GevError> { @@ -212,9 +234,9 @@ impl Gev { /// assert!(gev.set_scale(2.8).is_ok()); /// assert!(gev.set_scale(0.0).is_err()); /// assert!(gev.set_scale(-1.0).is_err()); - /// assert!(gev.set_scale(std::f64::INFINITY).is_err()); - /// assert!(gev.set_scale(std::f64::NEG_INFINITY).is_err()); - /// assert!(gev.set_scale(std::f64::NAN).is_err()); + /// assert!(gev.set_scale(f64::INFINITY).is_err()); + /// assert!(gev.set_scale(f64::NEG_INFINITY).is_err()); + /// assert!(gev.set_scale(f64::NAN).is_err()); /// ``` #[inline] pub fn set_scale(&mut self, scale: f64) -> Result<(), GevError> { @@ -253,13 +275,15 @@ impl_display!(Gev); macro_rules! impl_traits { ($kind: ty) => { - impl Rv<$kind> for Gev { + impl HasDensity<$kind> for Gev { fn ln_f(&self, x: &$kind) -> f64 { // TODO: could cache ln(scale) let tv = t(self.loc, self.shape, self.scale, f64::from(*x)); (self.shape + 1.0).mul_add(tv.ln(), -self.scale.ln()) - tv } + } + impl Sampleable<$kind> for Gev { fn draw(&self, rng: &mut R) -> $kind { let uni = rand_distr::Open01; let u: f64 = rng.sample(uni); @@ -399,13 +423,12 @@ mod tests { use crate::misc::ks_test; use crate::misc::linspace; use crate::test_basic_impls; - use std::f64; const TOL: f64 = 1E-12; const KS_PVAL: f64 = 0.2; const N_TRIES: usize = 5; - test_basic_impls!([continuous] Gev::new(0.0, 1.0, 2.0).unwrap()); + test_basic_impls!(f64, Gev, Gev::new(0.0, 1.0, 2.0).unwrap()); #[test] fn new() { diff --git a/src/dist/inv_chi_squared.rs b/src/dist/inv_chi_squared.rs index acf0b17..f46fb7e 100644 --- a/src/dist/inv_chi_squared.rs +++ b/src/dist/inv_chi_squared.rs @@ -1,4 +1,4 @@ -//! Χ-2 over x in (0, ∞) +//! 
Χ-2 over x in (0, ∞) #[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; @@ -32,6 +32,18 @@ pub struct InvChiSquared { ln_f_const: OnceLock, } +impl Parameterized for InvChiSquared { + type Parameters = f64; + + fn emit_params(&self) -> Self::Parameters { + self.v() + } + + fn from_params(v: Self::Parameters) -> Self { + Self::new_unchecked(v) + } +} + impl PartialEq for InvChiSquared { fn eq(&self, other: &InvChiSquared) -> bool { self.v == other.v @@ -111,8 +123,8 @@ impl InvChiSquared { /// assert!(ix2.set_v(2.2).is_ok()); /// assert!(ix2.set_v(0.0).is_err()); /// assert!(ix2.set_v(-1.0).is_err()); - /// assert!(ix2.set_v(std::f64::NAN).is_err()); - /// assert!(ix2.set_v(std::f64::INFINITY).is_err()); + /// assert!(ix2.set_v(f64::NAN).is_err()); + /// assert!(ix2.set_v(f64::INFINITY).is_err()); /// ``` #[inline] pub fn set_v(&mut self, v: f64) -> Result<(), InvChiSquaredError> { @@ -152,13 +164,15 @@ impl_display!(InvChiSquared); macro_rules! impl_traits { ($kind:ty) => { - impl Rv<$kind> for InvChiSquared { + impl HasDensity<$kind> for InvChiSquared { fn ln_f(&self, x: &$kind) -> f64 { let x64 = f64::from(*x); let z = self.ln_f_const(); (-self.v / 2.0 - 1.0).mul_add(x64.ln(), z) - (2.0 * x64).recip() } + } + impl Sampleable<$kind> for InvChiSquared { fn draw(&self, rng: &mut R) -> $kind { let x2 = rand_distr::ChiSquared::new(self.v).unwrap(); let x_inv: f64 = rng.sample(x2); @@ -262,7 +276,7 @@ mod test { const KS_PVAL: f64 = 0.2; const N_TRIES: usize = 5; - test_basic_impls!([continuous] InvChiSquared::new(3.2).unwrap()); + test_basic_impls!(f64, InvChiSquared, InvChiSquared::new(3.2).unwrap()); #[test] fn new() { diff --git a/src/dist/invgamma.rs b/src/dist/invgamma.rs index af1938d..b5a6bd2 100644 --- a/src/dist/invgamma.rs +++ b/src/dist/invgamma.rs @@ -28,6 +28,26 @@ pub struct InvGamma { scale: f64, } +pub struct InvGammaParameters { + pub shape: f64, + pub scale: f64, +} + +impl Parameterized for InvGamma { + type Parameters = InvGammaParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + shape: self.shape(), + scale: self.scale(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.shape, params.scale) + } +} + #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] @@ -69,7 +89,7 @@ impl InvGamma { InvGamma { shape, scale } } - /// Get the shape paramter + /// Get the shape parameter /// /// # Example /// @@ -104,9 +124,9 @@ impl InvGamma { /// assert!(ig.set_shape(1.1).is_ok()); /// assert!(ig.set_shape(0.0).is_err()); /// assert!(ig.set_shape(-1.0).is_err()); - /// assert!(ig.set_shape(std::f64::INFINITY).is_err()); - /// assert!(ig.set_shape(std::f64::NEG_INFINITY).is_err()); - /// assert!(ig.set_shape(std::f64::NAN).is_err()); + /// assert!(ig.set_shape(f64::INFINITY).is_err()); + /// assert!(ig.set_shape(f64::NEG_INFINITY).is_err()); + /// assert!(ig.set_shape(f64::NAN).is_err()); /// ``` #[inline] pub fn set_shape(&mut self, shape: f64) -> Result<(), InvGammaError> { @@ -161,9 +181,9 @@ impl InvGamma { /// assert!(ig.set_scale(1.1).is_ok()); /// assert!(ig.set_scale(0.0).is_err()); /// assert!(ig.set_scale(-1.0).is_err()); - /// assert!(ig.set_scale(std::f64::INFINITY).is_err()); - /// assert!(ig.set_scale(std::f64::NEG_INFINITY).is_err()); - /// assert!(ig.set_scale(std::f64::NAN).is_err()); + /// assert!(ig.set_scale(f64::INFINITY).is_err()); + /// 
assert!(ig.set_scale(f64::NEG_INFINITY).is_err()); + /// assert!(ig.set_scale(f64::NAN).is_err()); /// ``` #[inline] pub fn set_scale(&mut self, scale: f64) -> Result<(), InvGammaError> { @@ -203,7 +223,7 @@ impl_display!(InvGamma); macro_rules! impl_traits { ($kind:ty) => { - impl Rv<$kind> for InvGamma { + impl HasDensity<$kind> for InvGamma { fn ln_f(&self, x: &$kind) -> f64 { // TODO: could cache ln(scale) and ln_gamma(shape) let xf = f64::from(*x); @@ -213,7 +233,9 @@ macro_rules! impl_traits { .mul_add(self.scale.ln(), -ln_gammafn(self.shape)), ) - (self.scale / xf) } + } + impl Sampleable<$kind> for InvGamma { fn draw(&self, rng: &mut R) -> $kind { let g = rand_distr::Gamma::new(self.shape, self.scale.recip()) .unwrap(); @@ -356,7 +378,7 @@ mod tests { const KS_PVAL: f64 = 0.2; const N_TRIES: usize = 5; - test_basic_impls!([continuous] InvGamma::default()); + test_basic_impls!(f64, InvGamma, InvGamma::new_unchecked(1.5, 2.3)); #[test] fn new() { @@ -456,7 +478,7 @@ mod tests { } #[test] - fn ln_pdf_at_mode_should_be_higest() { + fn ln_pdf_at_mode_should_be_highest() { let ig = InvGamma::new(3.0, 2.0).unwrap(); let x: f64 = ig.mode().unwrap(); let delta = 1E-6; diff --git a/src/dist/invgaussian.rs b/src/dist/invgaussian.rs index c74bbca..295148f 100644 --- a/src/dist/invgaussian.rs +++ b/src/dist/invgaussian.rs @@ -27,6 +27,26 @@ pub struct InvGaussian { ln_lambda: OnceLock, } +pub struct InvGaussianParameters { + pub mu: f64, + pub lambda: f64, +} + +impl Parameterized for InvGaussian { + type Parameters = InvGaussianParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + mu: self.mu(), + lambda: self.lambda(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.mu, params.lambda) + } +} + impl PartialEq for InvGaussian { fn eq(&self, other: &InvGaussian) -> bool { self.mu == other.mu && self.lambda == other.lambda @@ -50,7 +70,7 @@ pub enum InvGaussianError { impl InvGaussian { /// Create a new Inverse Gaussian distribution /// - /// # Aruments + /// # Arguments /// - mu: mean > 0 /// - lambda: shape > 0 /// @@ -136,9 +156,9 @@ impl InvGaussian { /// assert!(ig.set_mu(1.3).is_ok()); /// assert!(ig.set_mu(0.0).is_err()); /// assert!(ig.set_mu(-1.0).is_err()); - /// assert!(ig.set_mu(std::f64::NEG_INFINITY).is_err()); - /// assert!(ig.set_mu(std::f64::INFINITY).is_err()); - /// assert!(ig.set_mu(std::f64::NAN).is_err()); + /// assert!(ig.set_mu(f64::NEG_INFINITY).is_err()); + /// assert!(ig.set_mu(f64::INFINITY).is_err()); + /// assert!(ig.set_mu(f64::NAN).is_err()); /// ``` #[inline] pub fn set_mu(&mut self, mu: f64) -> Result<(), InvGaussianError> { @@ -194,9 +214,9 @@ impl InvGaussian { /// assert!(ig.set_lambda(2.3).is_ok()); /// assert!(ig.set_lambda(0.0).is_err()); /// assert!(ig.set_lambda(-1.0).is_err()); - /// assert!(ig.set_lambda(std::f64::INFINITY).is_err()); - /// assert!(ig.set_lambda(std::f64::NEG_INFINITY).is_err()); - /// assert!(ig.set_lambda(std::f64::NAN).is_err()); + /// assert!(ig.set_lambda(f64::INFINITY).is_err()); + /// assert!(ig.set_lambda(f64::NEG_INFINITY).is_err()); + /// assert!(ig.set_lambda(f64::NAN).is_err()); /// ``` #[inline] pub fn set_lambda(&mut self, lambda: f64) -> Result<(), InvGaussianError> { @@ -217,12 +237,6 @@ impl InvGaussian { self.lambda = lambda; } - /// Return (mu, lambda) - #[inline] - pub fn params(&self) -> (f64, f64) { - (self.mu, self.lambda) - } - #[inline] fn ln_lambda(&self) -> f64 { *self.ln_lambda.get_or_init(|| self.lambda.ln()) @@ -239,19 +253,21 @@ 
impl_display!(InvGaussian); macro_rules! impl_traits { ($kind:ty) => { - impl Rv<$kind> for InvGaussian { + impl HasDensity<$kind> for InvGaussian { fn ln_f(&self, x: &$kind) -> f64 { - let (mu, lambda) = self.params(); + let InvGaussianParameters { mu, lambda } = self.emit_params(); let xf = f64::from(*x); let z = self.ln_lambda() - xf.ln().mul_add(3.0, LN_2PI); let err = xf - mu; let term = lambda * err * err / (2.0 * mu * mu * xf); z.mul_add(0.5, -term) } + } + impl Sampleable<$kind> for InvGaussian { // https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution#Sampling_from_an_inverse-Gaussian_distribution fn draw(&self, rng: &mut R) -> $kind { - let (mu, lambda) = self.params(); + let InvGaussianParameters { mu, lambda } = self.emit_params(); let g = Normal::new(0.0, 1.0).unwrap(); let v: f64 = rng.sample(g); let y = v * v; @@ -284,7 +300,7 @@ macro_rules! impl_traits { impl Cdf<$kind> for InvGaussian { fn cdf(&self, x: &$kind) -> f64 { let xf = f64::from(*x); - let (mu, lambda) = self.params(); + let InvGaussianParameters { mu, lambda } = self.emit_params(); let gauss = crate::dist::Gaussian::standard(); let z = (lambda / xf).sqrt(); let a = z * (xf / mu - 1.0); @@ -302,7 +318,7 @@ macro_rules! impl_traits { impl Mode<$kind> for InvGaussian { fn mode(&self) -> Option<$kind> { - let (mu, lambda) = self.params(); + let InvGaussianParameters { mu, lambda } = self.emit_params(); let a = (1.0 + 0.25 * 9.0 * mu * mu / (lambda * lambda)).sqrt(); let b = 0.5 * 3.0 * mu / lambda; let mode = mu * (a - b); @@ -381,8 +397,14 @@ mod tests { const N_TRIES: usize = 10; const KS_PVAL: f64 = 0.2; + crate::test_basic_impls!( + f64, + InvGaussian, + InvGaussian::new(1.0, 2.3).unwrap() + ); + #[test] - fn mode_is_higest_point() { + fn mode_is_highest_point() { let mut rng = rand::thread_rng(); let mu_prior = crate::dist::InvGamma::new_unchecked(2.0, 2.0); let lambda_prior = crate::dist::InvGamma::new_unchecked(2.0, 2.0); diff --git a/src/dist/ks.rs b/src/dist/ks.rs index 2af0f41..b75a9bc 100644 --- a/src/dist/ks.rs +++ b/src/dist/ks.rs @@ -1,17 +1,13 @@ //! Kolmogorow-Smirnov two-sided test for large values of N. //! Heavily inspired by SciPy's implementation which can be found here: -//! https://github.com/scipy/scipy/blob/a767030252ba3f7c8e2924847dffa7024171657b/scipy/special/cephes/kolmogorov.c#L153 - +//! #[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; use crate::impl_display; use crate::traits::*; use rand::Rng; -use std::f64::{ - consts::{PI, SQRT_2}, - EPSILON, -}; +use std::f64::consts::{PI, SQRT_2}; #[inline] fn within_tol(x: f64, y: f64, atol: f64, rtol: f64) -> bool { @@ -19,9 +15,12 @@ fn within_tol(x: f64, y: f64, atol: f64, rtol: f64) -> bool { diff <= rtol.mul_add(y.abs(), atol) } -/// Kolmogorov-Smirnov distribution where the number of samples, $N$, is assumed to be large -/// This is the distribution of $\sqrt{N} D_n$ where $D_n = \sup_x |F_n(x) - F(x)|$ where $F$ -/// is the true CDF and $F_n$ the emperical CDF. +/// Kolmogorov-Smirnov distribution where the number of samples, $N$, is +/// assumed to be large. +/// +/// This is the distribution of $\sqrt{N} D_n$ where +/// $D_n = \sup_x |F_n(x) - F(x)|$ where $F$ is the true CDF and $F_n$ the +/// empirical CDF. 
/// /// # Example /// @@ -41,6 +40,16 @@ fn within_tol(x: f64, y: f64, atol: f64, rtol: f64) -> bool { #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] pub struct KsTwoAsymptotic {} +impl Parameterized for KsTwoAsymptotic { + type Parameters = (); + + fn emit_params(&self) -> Self::Parameters {} + + fn from_params(_params: Self::Parameters) -> Self { + Self {} + } +} + struct CdfPdf { cdf: f64, pdf: f64, @@ -97,7 +106,7 @@ impl KsTwoAsymptotic { } CdfPdf { - cdf: p.max(0.0).min(1.0), + cdf: p.clamp(0.0, 1.0), pdf: d.max(0.0), } } else { @@ -133,7 +142,7 @@ impl KsTwoAsymptotic { p *= 2.0 * v; d *= 8.0 * v * x; p = p.max(0.0); - let cdf = (1.0 - p).max(0.0).min(1.0); + let cdf = (1.0 - p).clamp(0.0, 1.0); let pdf = d.max(0.0); CdfPdf { cdf, pdf } } @@ -145,13 +154,13 @@ impl KsTwoAsymptotic { #[allow(clippy::many_single_char_names)] fn inverse(sf: f64, cdf: f64) -> f64 { if !(sf >= 0.0 && cdf >= 0.0 && sf <= 1.0 && cdf <= 1.0) - || (1.0 - cdf - sf).abs() > 4.0 * EPSILON + || (1.0 - cdf - sf).abs() > 4.0 * f64::EPSILON { - std::f64::NAN + f64::NAN } else if cdf == 0.0 { 0.0 } else if sf == 0.0 { - std::f64::INFINITY + f64::INFINITY } else { let mut x: f64; let mut a: f64; @@ -177,7 +186,7 @@ impl KsTwoAsymptotic { * (-(logcdf + b.ln() - log_sqrt_2pi)).sqrt()); x = (a + b) / 2.0; } else { - const JITTERB: f64 = EPSILON * 256.0; + const JITTERB: f64 = f64::EPSILON * 256.0; let pba = sf / (2.0 * (1.0 - (-4.0_f64).exp())); let pbb = sf * (1.0 - JITTERB) / 2.0; @@ -230,7 +239,7 @@ impl KsTwoAsymptotic { } let dfdx = -c.pdf; - if dfdx.abs() <= EPSILON { + if dfdx.abs() <= f64::EPSILON { x = (a + b) / 2.0; } else { let t = df / dfdx; @@ -238,18 +247,21 @@ impl KsTwoAsymptotic { } if x >= a && x <= b { - if within_tol(x, x0, EPSILON, EPSILON * 2.0) { + if within_tol(x, x0, f64::EPSILON, f64::EPSILON * 2.0) { break; - } else if (x - a).abs() < EPSILON || (x - b).abs() < EPSILON + } else if (x - a).abs() < f64::EPSILON + || (x - b).abs() < f64::EPSILON { x = (a + b) / 2.0; - if (x - a).abs() > EPSILON || (x - b).abs() < EPSILON { + if (x - a).abs() > f64::EPSILON + || (x - b).abs() < f64::EPSILON + { break; } } } else { x = (a + b) / 2.0; - if within_tol(x, x0, EPSILON, EPSILON * 2.0) { + if within_tol(x, x0, f64::EPSILON, f64::EPSILON * 2.0) { break; } } @@ -270,11 +282,13 @@ impl_display!(KsTwoAsymptotic); macro_rules! 
impl_traits { ($kind:ty) => { - impl Rv<$kind> for KsTwoAsymptotic { + impl HasDensity<$kind> for KsTwoAsymptotic { fn ln_f(&self, x: &$kind) -> f64 { Self::compute(*x as f64).pdf.ln() } + } + impl Sampleable<$kind> for KsTwoAsymptotic { fn draw(&self, rng: &mut R) -> $kind { let p: f64 = rng.gen(); self.invcdf(p) diff --git a/src/dist/kumaraswamy.rs b/src/dist/kumaraswamy.rs index 4048ed5..e87d199 100644 --- a/src/dist/kumaraswamy.rs +++ b/src/dist/kumaraswamy.rs @@ -30,7 +30,8 @@ use std::sync::OnceLock; /// assert::close(x, y, 1E-10); /// ``` /// -/// Kumaraswamy(a, 1) is equivalent to Beta(a, 1) and Kumaraswamy(1, b) is equivalent to Beta(1, b) +/// Kumaraswamy(a, 1) is equivalent to Beta(a, 1) and Kumaraswamy(1, b) is +/// equivalent to Beta(1, b) /// /// ``` /// # use rv::prelude::*; @@ -54,6 +55,26 @@ pub struct Kumaraswamy { ab_ln: OnceLock, } +pub struct KumaraswamyParameters { + pub a: f64, + pub b: f64, +} + +impl Parameterized for Kumaraswamy { + type Parameters = KumaraswamyParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + a: self.a(), + b: self.b(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.a, params.b) + } +} + impl PartialEq for Kumaraswamy { fn eq(&self, other: &Kumaraswamy) -> bool { self.a == other.a && self.b == other.b @@ -162,7 +183,7 @@ impl Kumaraswamy { /// /// ```rust /// # use rv::dist::Kumaraswamy; - /// # use rv::traits::{Rv, Cdf, Median}; + /// # use rv::traits::*; /// // Bowl-shaped /// let kuma_1 = Kumaraswamy::centered(0.5).unwrap(); /// let median_1: f64 = kuma_1.median().unwrap(); @@ -182,7 +203,7 @@ impl Kumaraswamy { /// /// ```rust /// # use rv::dist::Kumaraswamy; - /// # use rv::traits::{Rv, Cdf}; + /// # use rv::traits::*; /// fn absolute_error(a: f64, b: f64) -> f64 { /// (a - b).abs() /// } @@ -253,9 +274,9 @@ impl Kumaraswamy { /// # let mut kuma = Kumaraswamy::new(1.0, 5.0).unwrap(); /// assert!(kuma.set_a(2.3).is_ok()); /// assert!(kuma.set_a(0.0).is_err()); - /// assert!(kuma.set_a(std::f64::INFINITY).is_err()); - /// assert!(kuma.set_a(std::f64::NEG_INFINITY).is_err()); - /// assert!(kuma.set_a(std::f64::NAN).is_err()); + /// assert!(kuma.set_a(f64::INFINITY).is_err()); + /// assert!(kuma.set_a(f64::NEG_INFINITY).is_err()); + /// assert!(kuma.set_a(f64::NAN).is_err()); /// ``` #[inline] pub fn set_a(&mut self, a: f64) -> Result<(), KumaraswamyError> { @@ -295,9 +316,9 @@ impl Kumaraswamy { /// # let mut kuma = Kumaraswamy::new(1.0, 5.0).unwrap(); /// assert!(kuma.set_b(2.3).is_ok()); /// assert!(kuma.set_b(0.0).is_err()); - /// assert!(kuma.set_b(std::f64::INFINITY).is_err()); - /// assert!(kuma.set_b(std::f64::NEG_INFINITY).is_err()); - /// assert!(kuma.set_b(std::f64::NAN).is_err()); + /// assert!(kuma.set_b(f64::INFINITY).is_err()); + /// assert!(kuma.set_b(f64::NEG_INFINITY).is_err()); + /// assert!(kuma.set_b(f64::NAN).is_err()); /// ``` #[inline] pub fn set_b(&mut self, b: f64) -> Result<(), KumaraswamyError> { @@ -332,7 +353,7 @@ fn invcdf(p: f64, a: f64, b: f64) -> f64 { macro_rules! impl_kumaraswamy { ($kind: ty) => { - impl Rv<$kind> for Kumaraswamy { + impl HasDensity<$kind> for Kumaraswamy { fn ln_f(&self, x: &$kind) -> f64 { let xf = *x as f64; let a = self.a; @@ -342,7 +363,9 @@ macro_rules! 
impl_kumaraswamy { (a - 1.0).mul_add(xf.ln(), self.ab_ln()), ) } + } + impl Sampleable<$kind> for Kumaraswamy { fn draw(&self, rng: &mut R) -> $kind { let p: f64 = rng.gen(); invcdf(p, self.a, self.b) as $kind @@ -444,7 +467,7 @@ mod tests { const KS_PVAL: f64 = 0.2; const N_TRIES: usize = 5; - test_basic_impls!([continuous] Kumaraswamy::centered(1.2).unwrap()); + test_basic_impls!(f64, Kumaraswamy, Kumaraswamy::centered(1.2).unwrap()); #[test] fn cdf_uniform_midpoint() { @@ -453,7 +476,7 @@ mod tests { } #[test] - fn draw_should_resturn_values_within_0_to_1() { + fn draw_should_return_values_within_0_to_1() { let mut rng = rand::thread_rng(); let kuma = Kumaraswamy::default(); for _ in 0..100 { diff --git a/src/dist/laplace.rs b/src/dist/laplace.rs index 58909aa..1c21216 100644 --- a/src/dist/laplace.rs +++ b/src/dist/laplace.rs @@ -33,6 +33,26 @@ pub struct Laplace { b: f64, } +pub struct LaplaceParameters { + pub mu: f64, + pub b: f64, +} + +impl Parameterized for Laplace { + type Parameters = LaplaceParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + mu: self.mu(), + b: self.b(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.mu, params.b) + } +} + #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] @@ -98,9 +118,9 @@ impl Laplace { /// # use rv::dist::Laplace; /// # let mut laplace = Laplace::new(-1.0, 2.0).unwrap(); /// assert!(laplace.set_mu(0.0).is_ok()); - /// assert!(laplace.set_mu(std::f64::INFINITY).is_err()); - /// assert!(laplace.set_mu(std::f64::NEG_INFINITY).is_err()); - /// assert!(laplace.set_mu(std::f64::NAN).is_err()); + /// assert!(laplace.set_mu(f64::INFINITY).is_err()); + /// assert!(laplace.set_mu(f64::NEG_INFINITY).is_err()); + /// assert!(laplace.set_mu(f64::NAN).is_err()); /// ``` #[inline] pub fn set_mu(&mut self, mu: f64) -> Result<(), LaplaceError> { @@ -151,9 +171,9 @@ impl Laplace { /// # let mut laplace = Laplace::new(-1.0, 2.0).unwrap(); /// assert!(laplace.set_b(2.3).is_ok()); /// assert!(laplace.set_b(0.0).is_err()); - /// assert!(laplace.set_b(std::f64::INFINITY).is_err()); - /// assert!(laplace.set_b(std::f64::NEG_INFINITY).is_err()); - /// assert!(laplace.set_b(std::f64::NAN).is_err()); + /// assert!(laplace.set_b(f64::INFINITY).is_err()); + /// assert!(laplace.set_b(f64::NEG_INFINITY).is_err()); + /// assert!(laplace.set_b(f64::NAN).is_err()); /// ``` #[inline] pub fn set_b(&mut self, b: f64) -> Result<(), LaplaceError> { @@ -197,12 +217,14 @@ fn laplace_partial_draw(u: f64) -> f64 { macro_rules! 
impl_traits { ($kind:ty) => { - impl Rv<$kind> for Laplace { + impl HasDensity<$kind> for Laplace { fn ln_f(&self, x: &$kind) -> f64 { // TODO: could cache ln(b) -(f64::from(*x) - self.mu).abs() / self.b - self.b.ln() - LN_2 } + } + impl Sampleable<$kind> for Laplace { fn draw(&self, rng: &mut R) -> $kind { let u = rng.sample(rand_distr::OpenClosed01); self.b.mul_add(-laplace_partial_draw(u), self.mu) as $kind @@ -300,7 +322,7 @@ mod tests { const KS_PVAL: f64 = 0.2; const N_TRIES: usize = 5; - test_basic_impls!([continuous] Laplace::default()); + test_basic_impls!(f64, Laplace); #[test] fn new() { diff --git a/src/dist/lognormal.rs b/src/dist/lognormal.rs index 086f558..f74c93e 100644 --- a/src/dist/lognormal.rs +++ b/src/dist/lognormal.rs @@ -21,6 +21,26 @@ pub struct LogNormal { sigma: f64, } +pub struct LogNormalParameters { + pub mu: f64, + pub sigma: f64, +} + +impl Parameterized for LogNormal { + type Parameters = LogNormalParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + mu: self.mu(), + sigma: self.sigma(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.mu, params.sigma) + } +} + #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] @@ -59,7 +79,7 @@ impl LogNormal { LogNormal { mu, sigma } } - /// LogNorma(0, 1) + /// LogNormal(0, 1) /// /// # Example /// @@ -109,9 +129,9 @@ impl LogNormal { /// # use rv::dist::LogNormal; /// # let mut lognormal = LogNormal::new(2.0, 1.5).unwrap(); /// assert!(lognormal.set_mu(1.3).is_ok()); - /// assert!(lognormal.set_mu(std::f64::NEG_INFINITY).is_err()); - /// assert!(lognormal.set_mu(std::f64::INFINITY).is_err()); - /// assert!(lognormal.set_mu(std::f64::NAN).is_err()); + /// assert!(lognormal.set_mu(f64::NEG_INFINITY).is_err()); + /// assert!(lognormal.set_mu(f64::INFINITY).is_err()); + /// assert!(lognormal.set_mu(f64::NAN).is_err()); /// ``` #[inline] pub fn set_mu(&mut self, mu: f64) -> Result<(), LogNormalError> { @@ -164,9 +184,9 @@ impl LogNormal { /// assert!(lognormal.set_sigma(2.3).is_ok()); /// assert!(lognormal.set_sigma(0.0).is_err()); /// assert!(lognormal.set_sigma(-1.0).is_err()); - /// assert!(lognormal.set_sigma(std::f64::INFINITY).is_err()); - /// assert!(lognormal.set_sigma(std::f64::NEG_INFINITY).is_err()); - /// assert!(lognormal.set_sigma(std::f64::NAN).is_err()); + /// assert!(lognormal.set_sigma(f64::INFINITY).is_err()); + /// assert!(lognormal.set_sigma(f64::NEG_INFINITY).is_err()); + /// assert!(lognormal.set_sigma(f64::NAN).is_err()); /// ``` #[inline] pub fn set_sigma(&mut self, sigma: f64) -> Result<(), LogNormalError> { @@ -203,7 +223,7 @@ impl_display!(LogNormal); macro_rules! impl_traits { ($kind: ty) => { - impl Rv<$kind> for LogNormal { + impl HasDensity<$kind> for LogNormal { fn ln_f(&self, x: &$kind) -> f64 { // TODO: cache ln(sigma) let xk = f64::from(*x); @@ -211,7 +231,9 @@ macro_rules! 
impl_traits { let d = (xk_ln - self.mu) / self.sigma; (0.5 * d).mul_add(-d, -xk_ln - self.sigma.ln() - HALF_LN_2PI) } + } + impl Sampleable<$kind> for LogNormal { fn draw(&self, rng: &mut R) -> $kind { let g = rand_distr::LogNormal::new(self.mu, self.sigma).unwrap(); @@ -331,11 +353,12 @@ impl fmt::Display for LogNormalError { mod tests { use super::*; use crate::test_basic_impls; + use proptest::prelude::*; use std::f64; const TOL: f64 = 1E-12; - test_basic_impls!([continuous] LogNormal::default()); + test_basic_impls!(f64, LogNormal); #[test] fn new() { @@ -466,17 +489,16 @@ mod tests { assert::close(lognorm.cdf(&2.0_f64), 0.755_891_404_214_417_3, TOL); } - #[test] - fn quantile_agree_with_cdf() { - let mut rng = rand::thread_rng(); - let lognorm = LogNormal::standard(); - let xs: Vec = lognorm.sample(100, &mut rng); - - xs.iter().for_each(|x| { - let p = lognorm.cdf(x); + proptest! { + #[test] + fn quantile_agree_with_cdf(p in 0.0..1.0) { + prop_assume!(p > 0.0); + prop_assume!(p < 1.0); + let lognorm = LogNormal::standard(); let y: f64 = lognorm.quantile(p); - assert::close(y, *x, TOL); - }) + let p2 = lognorm.cdf(&y); + assert::close(p, p2, TOL); + } } #[test] diff --git a/src/dist/mixture.rs b/src/dist/mixture.rs index 2e40be6..be41714 100644 --- a/src/dist/mixture.rs +++ b/src/dist/mixture.rs @@ -9,7 +9,6 @@ use crate::dist::{Categorical, Gaussian, Poisson}; use crate::misc::{logsumexp, pflip}; use crate::traits::*; use rand::Rng; -use std::convert::TryFrom; use std::fmt; use std::sync::OnceLock; @@ -42,6 +41,37 @@ pub struct Mixture { ln_weights: OnceLock>, } +pub struct MixtureParameters { + pub component_params: Vec, + pub weights: Vec, +} + +impl Parameterized for Mixture { + type Parameters = MixtureParameters; + + fn emit_params(&self) -> Self::Parameters { + let component_params = self + .components() + .iter() + .map(|cpnt| cpnt.emit_params()) + .collect(); + + Self::Parameters { + component_params, + weights: self.weights().clone(), + } + } + + fn from_params(mut params: Self::Parameters) -> Self { + let components = params + .component_params + .drain(..) + .map(|p| Fx::from_params(p)) + .collect(); + Self::new_unchecked(params.weights, components) + } +} + #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] @@ -353,7 +383,7 @@ impl From> for Vec<(f64, Fx)> { } } -impl Rv for Mixture +impl HasDensity for Mixture where Fx: Rv, { @@ -374,7 +404,12 @@ where .zip(self.components.iter()) .fold(0.0, |acc, (&w, cpnt)| cpnt.f(x).mul_add(w, acc)) } +} +impl Sampleable for Mixture +where + Fx: Rv, +{ fn draw(&self, mut rng: &mut R) -> X { let k: usize = pflip(&self.weights, 1, &mut rng)[0]; self.components[k].draw(&mut rng) @@ -773,8 +808,6 @@ fn continuous_mixture_quad_points(mm: &Mixture) -> Vec where Fx: Mode + Variance, { - use std::f64::INFINITY; - let mut state = (None, None); mm.components() @@ -785,7 +818,8 @@ where match (&state, (mode, std)) { ((Some(m1), s1), (Some(m2), s2)) => { if (m2 - *m1) - > s1.unwrap_or(INFINITY).min(s2.unwrap_or(INFINITY)) + > s1.unwrap_or(f64::INFINITY) + .min(s2.unwrap_or(f64::INFINITY)) { state = (mode, std); Some(m2) @@ -903,7 +937,7 @@ macro_rules! 
ds_discrete_quad_bounds { }; } -ds_discrete_quad_bounds!(Mixture, u32, 0, u32::max_value()); +ds_discrete_quad_bounds!(Mixture, u32, 0, u32::MAX); #[cfg(test)] mod tests { @@ -1345,7 +1379,6 @@ mod tests { #[test] fn gauss_mixture_quad_bounds_have_zero_pdf() { use crate::dist::{InvGamma, Poisson}; - use crate::traits::Rv; let mut rng = rand::thread_rng(); let pois = Poisson::new(7.0).unwrap(); diff --git a/src/dist/mod.rs b/src/dist/mod.rs index bec5146..50d9134 100644 --- a/src/dist/mod.rs +++ b/src/dist/mod.rs @@ -17,8 +17,6 @@ mod chi_squared; mod crp; mod dirichlet; mod discrete_uniform; -#[cfg(feature = "datum")] -mod distribution; mod empirical; mod exponential; mod gamma; @@ -43,8 +41,6 @@ mod normal_inv_chi_squared; mod normal_inv_gamma; mod pareto; mod poisson; -#[cfg(feature = "datum")] -mod product; mod scaled_inv_chi_squared; mod skellam; mod students_t; @@ -64,8 +60,6 @@ pub use chi_squared::{ChiSquared, ChiSquaredError}; pub use crp::{Crp, CrpError}; pub use dirichlet::{Dirichlet, DirichletError, SymmetricDirichlet}; pub use discrete_uniform::{DiscreteUniform, DiscreteUniformError}; -#[cfg(feature = "datum")] -pub use distribution::Distribution; pub use empirical::Empirical; pub use exponential::{Exponential, ExponentialError}; pub use gamma::{Gamma, GammaError}; @@ -92,8 +86,6 @@ pub use normal_inv_chi_squared::{ pub use normal_inv_gamma::{NormalInvGamma, NormalInvGammaError}; pub use pareto::{Pareto, ParetoError}; pub use poisson::{Poisson, PoissonError}; -#[cfg(feature = "datum")] -pub use product::ProductDistribution; pub use scaled_inv_chi_squared::{ ScaledInvChiSquared, ScaledInvChiSquaredError, }; diff --git a/src/dist/mvg.rs b/src/dist/mvg.rs index 417e88b..2d78651 100644 --- a/src/dist/mvg.rs +++ b/src/dist/mvg.rs @@ -103,6 +103,27 @@ pub struct MvGaussian { cache: OnceLock, } +pub struct MvGaussianParameters { + pub mu: DVector, + // Covariance Matrix + pub cov: DMatrix, +} + +impl Parameterized for MvGaussian { + type Parameters = MvGaussianParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + mu: self.mu().clone_owned(), + cov: self.cov().clone_owned(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.mu, params.cov) + } +} + #[allow(dead_code)] #[cfg(feature = "serde1")] fn default_cache_none() -> OnceLock { @@ -395,7 +416,7 @@ impl From<&MvGaussian> for String { impl_display!(MvGaussian); -impl Rv> for MvGaussian { +impl HasDensity> for MvGaussian { fn ln_f(&self, x: &DVector) -> f64 { let diff = x - &self.mu; let det_sqrt: f64 = self @@ -411,7 +432,9 @@ impl Rv> for MvGaussian { let term: f64 = (diff.transpose() * inv * &diff)[0]; -0.5 * (det.ln() + (diff.nrows() as f64).mul_add(LN_2PI, term)) } +} +impl Sampleable> for MvGaussian { fn draw(&self, rng: &mut R) -> DVector { let dims = self.mu.len(); let norm = rand_distr::StandardNormal; @@ -537,7 +560,11 @@ mod tests { const KS_PVAL: f64 = 0.2; const MARDIA_PVAL: f64 = 0.2; - test_basic_impls!(MvGaussian::standard(3).unwrap(), DVector::zeros(3)); + test_basic_impls!( + DVector, + MvGaussian, + MvGaussian::standard(2).unwrap() + ); #[test] fn new() { diff --git a/src/dist/neg_binom.rs b/src/dist/neg_binom.rs index a4e6c26..92f3269 100644 --- a/src/dist/neg_binom.rs +++ b/src/dist/neg_binom.rs @@ -47,6 +47,26 @@ pub struct NegBinomial { r_ln_p: OnceLock, } +pub struct NegBinomialParameters { + pub r: f64, + pub p: f64, +} + +impl Parameterized for NegBinomial { + type Parameters = NegBinomialParameters; + + fn emit_params(&self) -> 
Self::Parameters { + Self::Parameters { + r: self.r(), + p: self.p(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.r, params.p) + } +} + impl PartialEq for NegBinomial { fn eq(&self, other: &NegBinomial) -> bool { self.r == other.r && self.p == other.p @@ -114,9 +134,9 @@ impl NegBinomial { /// // r must be >= 1.0 /// assert!(nbin.set_r(0.99).is_err()); /// - /// assert!(nbin.set_r(std::f64::INFINITY).is_err()); - /// assert!(nbin.set_r(std::f64::NEG_INFINITY).is_err()); - /// assert!(nbin.set_r(std::f64::NAN).is_err()); + /// assert!(nbin.set_r(f64::INFINITY).is_err()); + /// assert!(nbin.set_r(f64::NEG_INFINITY).is_err()); + /// assert!(nbin.set_r(f64::NAN).is_err()); /// ``` #[inline] pub fn set_r(&mut self, r: f64) -> Result<(), NegBinomialError> { @@ -175,9 +195,9 @@ impl NegBinomial { /// // Too high, not in [0, 1] /// assert!(nbin.set_p(-1.1).is_err()); /// - /// assert!(nbin.set_p(std::f64::INFINITY).is_err()); - /// assert!(nbin.set_p(std::f64::NEG_INFINITY).is_err()); - /// assert!(nbin.set_p(std::f64::NAN).is_err()); + /// assert!(nbin.set_p(f64::INFINITY).is_err()); + /// assert!(nbin.set_p(f64::NEG_INFINITY).is_err()); + /// assert!(nbin.set_p(f64::NAN).is_err()); /// ``` #[inline] pub fn set_p(&mut self, p: f64) -> Result<(), NegBinomialError> { @@ -212,13 +232,15 @@ impl NegBinomial { macro_rules! impl_traits { ($kind:ty) => { - impl Rv<$kind> for NegBinomial { + impl HasDensity<$kind> for NegBinomial { fn ln_f(&self, x: &$kind) -> f64 { let xf = (*x) as f64; ln_binom(xf + self.r - 1.0, self.r - 1.0) + xf.mul_add(self.ln_1mp(), self.r_ln_p()) } + } + impl Sampleable<$kind> for NegBinomial { fn draw(&self, mut rng: &mut R) -> $kind { let q = 1.0 - self.p; let scale = q / (1.0 - q); @@ -316,7 +338,7 @@ mod tests { const TOL: f64 = 1E-10; - test_basic_impls!([count] NegBinomial::new(2.1, 0.6).unwrap()); + test_basic_impls!(u32, NegBinomial, NegBinomial::new(2.1, 0.6).unwrap()); #[test] fn new_with_good_params() { @@ -580,7 +602,7 @@ mod tests { // How many bins do we need? let k: usize = (0..100) - .position(|x| nbin.pmf(&(x as u32)) < std::f64::EPSILON) + .position(|x| nbin.pmf(&(x as u32)) < f64::EPSILON) .unwrap_or(99) + 1; @@ -612,7 +634,7 @@ mod tests { // How many bins do we need? let k: usize = (0..100) - .position(|x| nbin.pmf(&(x as u32)) < std::f64::EPSILON) + .position(|x| nbin.pmf(&(x as u32)) < f64::EPSILON) .unwrap_or(99) + 1; diff --git a/src/dist/niw.rs b/src/dist/niw.rs index b815f05..9adf716 100644 --- a/src/dist/niw.rs +++ b/src/dist/niw.rs @@ -50,6 +50,30 @@ pub struct NormalInvWishart { scale: DMatrix, } +pub struct NormalInvWishartParameters { + pub mu: DVector, + pub k: f64, + pub df: usize, + pub scale: DMatrix, +} + +impl Parameterized for NormalInvWishart { + type Parameters = NormalInvWishartParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + mu: self.mu().clone_owned(), + k: self.k(), + df: self.df(), + scale: self.scale().clone_owned(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.mu, params.k, params.df, params.scale) + } +} + #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] @@ -242,7 +266,7 @@ impl_display!(NormalInvWishart); // TODO: We might be able to make things faster by storing the InvWishart // because each time we create it, it clones and validates the parameters. 
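The hunks throughout this changeset apply one mechanical pattern: each `impl Rv<X>` is split so that the density side (`ln_f`, `f`) lives in `HasDensity<X>` and the sampling side (`draw`, `sample`) lives in `Sampleable<X>`. As a rough sketch of what the split looks like for an implementor outside the crate — assuming the traits are exported from `rv::traits` as the others are, and that `f` and `sample` keep default implementations in terms of `ln_f` and `draw` — consider the invented `AlwaysZero` type below; it is not part of this diff.

```rust
use rand::Rng;
use rv::traits::{HasDensity, Sampleable};

// Toy one-point distribution over usize, only here to show the shape of the
// split impls; it places all mass on 0.
struct AlwaysZero;

impl HasDensity<usize> for AlwaysZero {
    // ln f(0) = 0, i.e. f(0) = 1; everything else has zero density.
    fn ln_f(&self, x: &usize) -> f64 {
        if *x == 0 {
            0.0
        } else {
            f64::NEG_INFINITY
        }
    }
}

impl Sampleable<usize> for AlwaysZero {
    // Sampling no longer requires a density impl on the same trait.
    fn draw<R: Rng>(&self, _rng: &mut R) -> usize {
        0
    }
}
```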
-impl Rv for NormalInvWishart { +impl HasDensity for NormalInvWishart { fn ln_f(&self, x: &MvGaussian) -> f64 { let m = self.mu.clone(); let sigma = x.cov().clone() / self.k; @@ -251,7 +275,9 @@ impl Rv for NormalInvWishart { let iw = InvWishart::new_unchecked(self.scale.clone(), self.df); mvg.ln_f(x.mu()) + iw.ln_f(x.cov()) } +} +impl Sampleable for NormalInvWishart { fn draw(&self, mut rng: &mut R) -> MvGaussian { let iw = InvWishart::new_unchecked(self.scale.clone(), self.df); let sigma: DMatrix = iw.draw(&mut rng); @@ -306,6 +332,19 @@ impl fmt::Display for NormalInvWishartError { #[cfg(test)] mod tests { use super::*; + use crate::test_basic_impls; + + test_basic_impls!( + MvGaussian, + NormalInvWishart, + NormalInvWishart::new( + DVector::zeros(2), + 1.0, + 2, + DMatrix::identity(2, 2), + ) + .unwrap() + ); #[test] fn disallow_zero_k() { diff --git a/src/dist/niw/mvg_prior.rs b/src/dist/niw/mvg_prior.rs index 33094fc..b93a24f 100644 --- a/src/dist/niw/mvg_prior.rs +++ b/src/dist/niw/mvg_prior.rs @@ -2,7 +2,8 @@ use crate::consts::LN_2PI; use crate::data::{extract_stat_then, DataOrSuffStat, MvGaussianSuffStat}; use crate::dist::{MvGaussian, NormalInvWishart}; use crate::misc::lnmv_gamma; -use crate::traits::{ConjugatePrior, SuffStat}; +use crate::traits::ConjugatePrior; +use crate::traits::SuffStat; use nalgebra::{DMatrix, DVector}; use std::f64::consts::{LN_2, PI}; @@ -21,8 +22,8 @@ fn ln_z(k: f64, df: usize, scale: &DMatrix) -> f64 { impl ConjugatePrior, MvGaussian> for NormalInvWishart { type Posterior = Self; - type LnMCache = f64; - type LnPpCache = (Self, f64); + type MCache = f64; + type PpCache = (Self, f64); fn posterior(&self, x: &MvgData) -> NormalInvWishart { if x.n() == 0 { @@ -66,7 +67,7 @@ impl ConjugatePrior, MvGaussian> for NormalInvWishart { ln_z(self.k(), self.df(), self.scale()) } - fn ln_m_with_cache(&self, cache: &Self::LnMCache, x: &MvgData) -> f64 { + fn ln_m_with_cache(&self, cache: &Self::MCache, x: &MvgData) -> f64 { let z0 = cache; let post = self.posterior(x); let zn = ln_z(post.k(), post.df(), post.scale()); @@ -76,17 +77,13 @@ impl ConjugatePrior, MvGaussian> for NormalInvWishart { } #[inline] - fn ln_pp_cache(&self, x: &MvgData) -> Self::LnPpCache { + fn ln_pp_cache(&self, x: &MvgData) -> Self::PpCache { let post = self.posterior(x); let zn = ln_z(post.k(), post.df(), post.scale()); (post, zn) } - fn ln_pp_with_cache( - &self, - cache: &Self::LnPpCache, - y: &DVector, - ) -> f64 { + fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &DVector) -> f64 { let post = &cache.0; let zn = cache.1; @@ -106,12 +103,26 @@ impl ConjugatePrior, MvGaussian> for NormalInvWishart { #[cfg(test)] mod tests { - use nalgebra::{dmatrix, dvector}; - use super::*; + use crate::test_conjugate_prior; + use crate::traits::*; + use nalgebra::{dmatrix, dvector}; const TOL: f64 = 1E-12; + test_conjugate_prior!( + DVector, + MvGaussian, + NormalInvWishart, + NormalInvWishart::new( + DVector::zeros(2), + 1.0, + 2, + DMatrix::identity(2, 2), + ) + .unwrap() + ); + fn obs_fxtr() -> MvGaussianSuffStat { let x0v = vec![3.578_396_939_725_76, 0.725_404_224_946_106]; let x1v = vec![2.769_437_029_884_88, -0.063_054_873_189_656_2]; @@ -159,7 +170,7 @@ mod tests { #[test] fn posterior() { // This checks this implementation against the one from - // Kevin Murphey + // Kevin Murphy // Found here: https://github.com/probml/probml-utils/blob/983e107875d550957d6c046b5c1af0fbae4badff/probml_utils/dp_mixgauss_utils.py#L206-L225 let niw = NormalInvWishart::new( diff --git a/src/dist/normal_gamma.rs 
b/src/dist/normal_gamma.rs index 92183e7..af272c8 100644 --- a/src/dist/normal_gamma.rs +++ b/src/dist/normal_gamma.rs @@ -24,6 +24,30 @@ pub struct NormalGamma { v: f64, } +pub struct NormalGammaParameters { + pub m: f64, + pub r: f64, + pub s: f64, + pub v: f64, +} + +impl Parameterized for NormalGamma { + type Parameters = NormalGammaParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + m: self.m(), + r: self.r(), + s: self.s(), + v: self.v(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.m, params.r, params.s, params.v) + } +} + #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] @@ -110,9 +134,9 @@ impl NormalGamma { /// # use rv::dist::NormalGamma; /// # let mut ng = NormalGamma::new(0.0, 1.2, 2.3, 3.4).unwrap(); /// assert!(ng.set_m(-1.1).is_ok()); - /// assert!(ng.set_m(std::f64::INFINITY).is_err()); - /// assert!(ng.set_m(std::f64::NEG_INFINITY).is_err()); - /// assert!(ng.set_m(std::f64::NAN).is_err()); + /// assert!(ng.set_m(f64::INFINITY).is_err()); + /// assert!(ng.set_m(f64::NEG_INFINITY).is_err()); + /// assert!(ng.set_m(f64::NAN).is_err()); /// ``` #[inline] pub fn set_m(&mut self, m: f64) -> Result<(), NormalGammaError> { @@ -162,9 +186,9 @@ impl NormalGamma { /// assert!(ng.set_r(-1.0).is_err()); /// /// - /// assert!(ng.set_r(std::f64::INFINITY).is_err()); - /// assert!(ng.set_r(std::f64::NEG_INFINITY).is_err()); - /// assert!(ng.set_r(std::f64::NAN).is_err()); + /// assert!(ng.set_r(f64::INFINITY).is_err()); + /// assert!(ng.set_r(f64::NEG_INFINITY).is_err()); + /// assert!(ng.set_r(f64::NAN).is_err()); /// ``` #[inline] pub fn set_r(&mut self, r: f64) -> Result<(), NormalGammaError> { @@ -216,9 +240,9 @@ impl NormalGamma { /// assert!(ng.set_s(-1.0).is_err()); /// /// - /// assert!(ng.set_s(std::f64::INFINITY).is_err()); - /// assert!(ng.set_s(std::f64::NEG_INFINITY).is_err()); - /// assert!(ng.set_s(std::f64::NAN).is_err()); + /// assert!(ng.set_s(f64::INFINITY).is_err()); + /// assert!(ng.set_s(f64::NEG_INFINITY).is_err()); + /// assert!(ng.set_s(f64::NAN).is_err()); /// ``` #[inline] pub fn set_s(&mut self, s: f64) -> Result<(), NormalGammaError> { @@ -270,9 +294,9 @@ impl NormalGamma { /// assert!(ng.set_v(-1.0).is_err()); /// /// - /// assert!(ng.set_v(std::f64::INFINITY).is_err()); - /// assert!(ng.set_v(std::f64::NEG_INFINITY).is_err()); - /// assert!(ng.set_v(std::f64::NAN).is_err()); + /// assert!(ng.set_v(f64::INFINITY).is_err()); + /// assert!(ng.set_v(f64::NEG_INFINITY).is_err()); + /// assert!(ng.set_v(f64::NAN).is_err()); /// ``` #[inline] pub fn set_v(&mut self, v: f64) -> Result<(), NormalGammaError> { @@ -291,12 +315,6 @@ impl NormalGamma { pub fn set_v_unchecked(&mut self, v: f64) { self.v = v; } - - /// Return (m, r, s, v) - #[inline] - pub fn params(&self) -> (f64, f64, f64, f64) { - (self.m, self.r, self.s, self.v) - } } impl From<&NormalGamma> for String { @@ -310,7 +328,7 @@ impl From<&NormalGamma> for String { impl_display!(NormalGamma); -impl Rv for NormalGamma { +impl HasDensity for NormalGamma { fn ln_f(&self, x: &Gaussian) -> f64 { // TODO: could cache the gamma and Gaussian distributions let rho = (x.sigma() * x.sigma()).recip(); @@ -320,7 +338,9 @@ impl Rv for NormalGamma { let lnf_mu = Gaussian::new_unchecked(self.m, prior_sigma).ln_f(&x.mu()); lnf_rho + lnf_mu } +} +impl Sampleable for NormalGamma { fn draw(&self, mut rng: &mut R) -> Gaussian { // NOTE: 
The parameter errors in this fn shouldn't happen if the prior // parameters are valid. @@ -339,7 +359,7 @@ impl Rv for NormalGamma { .draw(&mut rng); let sigma = if rho.is_infinite() { - std::f64::EPSILON + f64::EPSILON } else { rho.recip().sqrt() }; diff --git a/src/dist/normal_gamma/gaussian_prior.rs b/src/dist/normal_gamma/gaussian_prior.rs index dfbc408..c769bc0 100644 --- a/src/dist/normal_gamma/gaussian_prior.rs +++ b/src/dist/normal_gamma/gaussian_prior.rs @@ -2,9 +2,7 @@ use std::collections::BTreeMap; use std::f64::consts::LN_2; use crate::consts::*; -use crate::data::{ - extract_stat, extract_stat_then, DataOrSuffStat, GaussianSuffStat, -}; +use crate::data::{extract_stat, extract_stat_then, GaussianSuffStat}; use crate::dist::{Gaussian, NormalGamma}; use crate::gaussian_prior_geweke_testable; use crate::misc::ln_gammafn; @@ -37,8 +35,8 @@ fn posterior_from_stat( impl ConjugatePrior for NormalGamma { type Posterior = Self; - type LnMCache = f64; - type LnPpCache = (GaussianSuffStat, f64); + type MCache = f64; + type PpCache = (GaussianSuffStat, f64); fn posterior(&self, x: &DataOrSuffStat) -> Self { extract_stat_then(x, GaussianSuffStat::new, |stat: GaussianSuffStat| { @@ -47,13 +45,13 @@ impl ConjugatePrior for NormalGamma { } #[inline] - fn ln_m_cache(&self) -> Self::LnMCache { + fn ln_m_cache(&self) -> Self::MCache { ln_z(self.r(), self.s, self.v) } fn ln_m_with_cache( &self, - cache: &Self::LnMCache, + cache: &Self::MCache, x: &DataOrSuffStat, ) -> f64 { extract_stat_then(x, GaussianSuffStat::new, |stat: GaussianSuffStat| { @@ -64,17 +62,14 @@ impl ConjugatePrior for NormalGamma { } #[inline] - fn ln_pp_cache( - &self, - x: &DataOrSuffStat, - ) -> Self::LnPpCache { + fn ln_pp_cache(&self, x: &DataOrSuffStat) -> Self::PpCache { let stat = extract_stat(x, GaussianSuffStat::new); let post_n = posterior_from_stat(self, &stat); let lnz_n = ln_z(post_n.r, post_n.s, post_n.v); (stat, lnz_n) } - fn ln_pp_with_cache(&self, cache: &Self::LnPpCache, y: &f64) -> f64 { + fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &f64) -> f64 { let mut stat = cache.0.clone(); let lnz_n = cache.1; @@ -93,9 +88,17 @@ gaussian_prior_geweke_testable!(NormalGamma, Gaussian); mod tests { use super::*; use crate::data::GaussianData; + use crate::test_conjugate_prior; const TOL: f64 = 1E-12; + test_conjugate_prior!( + f64, + Gaussian, + NormalGamma, + NormalGamma::new(0.1, 1.2, 0.5, 1.8).unwrap() + ); + #[test] fn geweke() { use crate::test::GewekeTester; diff --git a/src/dist/normal_inv_chi_squared.rs b/src/dist/normal_inv_chi_squared.rs index dd560ae..c348592 100644 --- a/src/dist/normal_inv_chi_squared.rs +++ b/src/dist/normal_inv_chi_squared.rs @@ -9,7 +9,7 @@ mod gaussian_prior; use crate::dist::{Gaussian, ScaledInvChiSquared}; use crate::impl_display; -use crate::traits::Rv; +use crate::traits::*; use rand::Rng; use std::sync::OnceLock; @@ -30,6 +30,30 @@ pub struct NormalInvChiSquared { scaled_inv_x2: OnceLock, } +pub struct NormalInvChiSquaredParameters { + pub m: f64, + pub k: f64, + pub v: f64, + pub s2: f64, +} + +impl Parameterized for NormalInvChiSquared { + type Parameters = NormalInvChiSquaredParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + m: self.m(), + k: self.k(), + v: self.v(), + s2: self.s2(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.m, params.k, params.v, params.s2) + } +} + impl PartialEq for NormalInvChiSquared { fn eq(&self, other: &Self) -> bool { self.m == other.m @@ -145,9 +169,9 @@ impl 
NormalInvChiSquared { /// # use rv::dist::NormalInvChiSquared; /// # let mut nix = NormalInvChiSquared::new(0.0, 1.2, 2.3, 3.4).unwrap(); /// assert!(nix.set_m(-1.1).is_ok()); - /// assert!(nix.set_m(std::f64::INFINITY).is_err()); - /// assert!(nix.set_m(std::f64::NEG_INFINITY).is_err()); - /// assert!(nix.set_m(std::f64::NAN).is_err()); + /// assert!(nix.set_m(f64::INFINITY).is_err()); + /// assert!(nix.set_m(f64::NEG_INFINITY).is_err()); + /// assert!(nix.set_m(f64::NAN).is_err()); /// ``` #[inline] pub fn set_m(&mut self, m: f64) -> Result<(), NormalInvChiSquaredError> { @@ -197,9 +221,9 @@ impl NormalInvChiSquared { /// assert!(nix.set_k(-1.0).is_err()); /// /// - /// assert!(nix.set_k(std::f64::INFINITY).is_err()); - /// assert!(nix.set_k(std::f64::NEG_INFINITY).is_err()); - /// assert!(nix.set_k(std::f64::NAN).is_err()); + /// assert!(nix.set_k(f64::INFINITY).is_err()); + /// assert!(nix.set_k(f64::NEG_INFINITY).is_err()); + /// assert!(nix.set_k(f64::NAN).is_err()); /// ``` #[inline] pub fn set_k(&mut self, k: f64) -> Result<(), NormalInvChiSquaredError> { @@ -251,9 +275,9 @@ impl NormalInvChiSquared { /// assert!(nix.set_v(-1.0).is_err()); /// /// - /// assert!(nix.set_v(std::f64::INFINITY).is_err()); - /// assert!(nix.set_v(std::f64::NEG_INFINITY).is_err()); - /// assert!(nix.set_v(std::f64::NAN).is_err()); + /// assert!(nix.set_v(f64::INFINITY).is_err()); + /// assert!(nix.set_v(f64::NEG_INFINITY).is_err()); + /// assert!(nix.set_v(f64::NAN).is_err()); /// ``` #[inline] pub fn set_v(&mut self, v: f64) -> Result<(), NormalInvChiSquaredError> { @@ -307,9 +331,9 @@ impl NormalInvChiSquared { /// assert!(nix.set_s2(-1.0).is_err()); /// /// - /// assert!(nix.set_s2(std::f64::INFINITY).is_err()); - /// assert!(nix.set_s2(std::f64::NEG_INFINITY).is_err()); - /// assert!(nix.set_s2(std::f64::NAN).is_err()); + /// assert!(nix.set_s2(f64::INFINITY).is_err()); + /// assert!(nix.set_s2(f64::NEG_INFINITY).is_err()); + /// assert!(nix.set_s2(f64::NAN).is_err()); /// ``` #[inline] pub fn set_s2(&mut self, s2: f64) -> Result<(), NormalInvChiSquaredError> { @@ -349,22 +373,20 @@ impl From<&NormalInvChiSquared> for String { impl_display!(NormalInvChiSquared); -impl Rv for NormalInvChiSquared { +impl HasDensity for NormalInvChiSquared { fn ln_f(&self, x: &Gaussian) -> f64 { let lnf_sigma = self.scaled_inv_x2().ln_f(&(x.sigma() * x.sigma())); let prior_sigma = x.sigma() / self.k.sqrt(); let lnf_mu = Gaussian::new_unchecked(self.m, prior_sigma).ln_f(&x.mu()); lnf_sigma + lnf_mu } +} +impl Sampleable for NormalInvChiSquared { fn draw(&self, mut rng: &mut R) -> Gaussian { let var: f64 = self.scaled_inv_x2().draw(&mut rng); - let sigma = if var <= 0.0 { - std::f64::EPSILON - } else { - var.sqrt() - }; + let sigma = if var <= 0.0 { f64::EPSILON } else { var.sqrt() }; let post_sigma: f64 = sigma / self.k.sqrt(); let mu: f64 = Gaussian::new(self.m, post_sigma) @@ -406,8 +428,9 @@ mod test { use crate::{test_basic_impls, verify_cache_resets}; test_basic_impls!( - NormalInvChiSquared::new(0.1, 1.2, 2.3, 3.4).unwrap(), - Gaussian::new(-1.2, 0.4).unwrap() + Gaussian, + NormalInvChiSquared, + NormalInvChiSquared::new(0.1, 1.2, 2.3, 3.4).unwrap() ); verify_cache_resets!( diff --git a/src/dist/normal_inv_chi_squared/gaussian_prior.rs b/src/dist/normal_inv_chi_squared/gaussian_prior.rs index a55bea2..69ffbe9 100644 --- a/src/dist/normal_inv_chi_squared/gaussian_prior.rs +++ b/src/dist/normal_inv_chi_squared/gaussian_prior.rs @@ -1,9 +1,7 @@ use std::collections::BTreeMap; use crate::consts::HALF_LN_PI; -use 
crate::data::{ - extract_stat, extract_stat_then, DataOrSuffStat, GaussianSuffStat, -}; +use crate::data::{extract_stat, extract_stat_then, GaussianSuffStat}; use crate::dist::{Gaussian, NormalInvChiSquared}; use crate::gaussian_prior_geweke_testable; use crate::misc::ln_gammafn; @@ -54,8 +52,8 @@ fn posterior_from_stat( impl ConjugatePrior for NormalInvChiSquared { type Posterior = Self; - type LnMCache = f64; - type LnPpCache = (GaussianSuffStat, f64); + type MCache = f64; + type PpCache = (GaussianSuffStat, f64); fn posterior(&self, x: &DataOrSuffStat) -> Self { extract_stat_then(x, GaussianSuffStat::new, |stat: GaussianSuffStat| { @@ -64,13 +62,13 @@ impl ConjugatePrior for NormalInvChiSquared { } #[inline] - fn ln_m_cache(&self) -> Self::LnMCache { + fn ln_m_cache(&self) -> Self::MCache { ln_z(self.k, self.v, self.s2) } fn ln_m_with_cache( &self, - cache: &Self::LnMCache, + cache: &Self::MCache, x: &DataOrSuffStat, ) -> f64 { extract_stat_then(x, GaussianSuffStat::new, |stat: GaussianSuffStat| { @@ -82,10 +80,7 @@ impl ConjugatePrior for NormalInvChiSquared { } #[inline] - fn ln_pp_cache( - &self, - x: &DataOrSuffStat, - ) -> Self::LnPpCache { + fn ln_pp_cache(&self, x: &DataOrSuffStat) -> Self::PpCache { let stat = extract_stat(x, GaussianSuffStat::new); let post_n = posterior_from_stat(self, &stat); let lnz_n = ln_z(post_n.k, post_n.v, post_n.s2); @@ -93,7 +88,7 @@ impl ConjugatePrior for NormalInvChiSquared { // post_n } - fn ln_pp_with_cache(&self, cache: &Self::LnPpCache, y: &f64) -> f64 { + fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &f64) -> f64 { let mut stat = cache.0.clone(); let lnz_n = cache.1; @@ -111,9 +106,17 @@ gaussian_prior_geweke_testable!(NormalInvChiSquared, Gaussian); #[cfg(test)] mod test { use super::*; + use crate::test_conjugate_prior; const TOL: f64 = 1E-12; + test_conjugate_prior!( + f64, + Gaussian, + NormalInvChiSquared, + NormalInvChiSquared::new(0.1, 1.2, 0.5, 1.8).unwrap() + ); + #[test] fn geweke() { use crate::test::GewekeTester; @@ -135,7 +138,7 @@ mod test { } fn post_params( - xs: &Vec, + xs: &[f64], m: f64, k: f64, v: f64, @@ -163,7 +166,7 @@ mod test { // examples/dpgmm.rs) words with the NormalInvGamma prior, then we should be // good to go. 
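The cache-type renames in these prior impls (`LnMCache` to `MCache`, `LnPpCache` to `PpCache`) leave the call pattern unchanged: build the cache once, then reuse it for every dataset. A rough usage sketch with `NormalGamma`; the helper `ln_m_for_each` is invented for illustration and is not part of the crate.

```rust
use rv::data::DataOrSuffStat;
use rv::dist::NormalGamma;
use rv::traits::ConjugatePrior;

// Score several datasets under the same prior while computing the
// data-independent normalizer (the MCache) only once.
fn ln_m_for_each(ng: &NormalGamma, datasets: &[Vec<f64>]) -> Vec<f64> {
    let cache = ng.ln_m_cache(); // Self::MCache; an f64 for NormalGamma
    datasets
        .iter()
        .map(|xs| ng.ln_m_with_cache(&cache, &DataOrSuffStat::from(xs)))
        .collect()
}
```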
fn alternate_ln_marginal( - xs: &Vec, + xs: &[f64], m: f64, k: f64, v: f64, @@ -215,7 +218,7 @@ mod test { #[test] fn posterior_of_nothing_is_prior() { let prior = NormalInvChiSquared::new_unchecked(1.2, 2.3, 3.4, 4.5); - let post = prior.posterior(&DataOrSuffStat::None); + let post = prior.posterior(&DataOrSuffStat::from(&vec![])); assert_eq!(prior.m(), post.m()); assert_eq!(prior.k(), post.k()); assert_eq!(prior.v(), post.v()); @@ -305,7 +308,8 @@ mod test { let (m, k, v, s2) = (1.0, 2.2, 3.3, 4.4); let nix = NormalInvChiSquared::new(m, k, v, s2).unwrap(); - let ln_pp = nix.ln_pp(&x, &DataOrSuffStat::::None); + let ln_pp = + nix.ln_pp(&x, &DataOrSuffStat::::from(&vec![])); let mc_est = { let ln_fs: Vec = nix @@ -331,7 +335,8 @@ mod test { let (ln_pp, ln_m) = { let ys = vec![y]; - let data = DataOrSuffStat::::None; + let new_vec = Vec::new(); + let data = DataOrSuffStat::::from(&new_vec); let y_data = DataOrSuffStat::::from(&ys); (nix.ln_pp(&y, &data), nix.ln_m(&y_data)) }; diff --git a/src/dist/normal_inv_gamma.rs b/src/dist/normal_inv_gamma.rs index 3c87d12..8cdeba5 100644 --- a/src/dist/normal_inv_gamma.rs +++ b/src/dist/normal_inv_gamma.rs @@ -9,7 +9,7 @@ mod gaussian_prior; use crate::dist::{Gaussian, InvGamma}; use crate::impl_display; -use crate::traits::Rv; +use crate::traits::*; use rand::Rng; use std::fmt; @@ -27,6 +27,30 @@ pub struct NormalInvGamma { b: f64, } +pub struct NormalInvGammaParameters { + pub m: f64, + pub v: f64, + pub a: f64, + pub b: f64, +} + +impl Parameterized for NormalInvGamma { + type Parameters = NormalInvGammaParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + m: self.m(), + v: self.v(), + a: self.a(), + b: self.b(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.m, params.v, params.a, params.b) + } +} + #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] @@ -87,12 +111,6 @@ impl NormalInvGamma { NormalInvGamma { m, v, a, b } } - /// Returns (m, v, a, b) - #[inline(always)] - pub fn params(&self) -> (f64, f64, f64, f64) { - (self.m, self.v, self.a, self.b) - } - /// Get the m parameter #[inline(always)] pub fn m(&self) -> f64 { @@ -119,9 +137,9 @@ impl NormalInvGamma { /// # use rv::dist::NormalInvGamma; /// # let mut nig = NormalInvGamma::new(0.0, 1.2, 2.3, 3.4).unwrap(); /// assert!(nig.set_m(-1.1).is_ok()); - /// assert!(nig.set_m(std::f64::INFINITY).is_err()); - /// assert!(nig.set_m(std::f64::NEG_INFINITY).is_err()); - /// assert!(nig.set_m(std::f64::NAN).is_err()); + /// assert!(nig.set_m(f64::INFINITY).is_err()); + /// assert!(nig.set_m(f64::NEG_INFINITY).is_err()); + /// assert!(nig.set_m(f64::NAN).is_err()); /// ``` #[inline] pub fn set_m(&mut self, m: f64) -> Result<(), NormalInvGammaError> { @@ -171,9 +189,9 @@ impl NormalInvGamma { /// assert!(nig.set_v(-1.0).is_err()); /// /// - /// assert!(nig.set_v(std::f64::INFINITY).is_err()); - /// assert!(nig.set_v(std::f64::NEG_INFINITY).is_err()); - /// assert!(nig.set_v(std::f64::NAN).is_err()); + /// assert!(nig.set_v(f64::INFINITY).is_err()); + /// assert!(nig.set_v(f64::NEG_INFINITY).is_err()); + /// assert!(nig.set_v(f64::NAN).is_err()); /// ``` #[inline] pub fn set_v(&mut self, v: f64) -> Result<(), NormalInvGammaError> { @@ -225,9 +243,9 @@ impl NormalInvGamma { /// assert!(nig.set_a(-1.0).is_err()); /// /// - /// assert!(nig.set_a(std::f64::INFINITY).is_err()); - /// 
assert!(nig.set_a(std::f64::NEG_INFINITY).is_err()); - /// assert!(nig.set_a(std::f64::NAN).is_err()); + /// assert!(nig.set_a(f64::INFINITY).is_err()); + /// assert!(nig.set_a(f64::NEG_INFINITY).is_err()); + /// assert!(nig.set_a(f64::NAN).is_err()); /// ``` #[inline] pub fn set_a(&mut self, a: f64) -> Result<(), NormalInvGammaError> { @@ -279,9 +297,9 @@ impl NormalInvGamma { /// assert!(nig.set_b(-1.0).is_err()); /// /// - /// assert!(nig.set_b(std::f64::INFINITY).is_err()); - /// assert!(nig.set_b(std::f64::NEG_INFINITY).is_err()); - /// assert!(nig.set_b(std::f64::NAN).is_err()); + /// assert!(nig.set_b(f64::INFINITY).is_err()); + /// assert!(nig.set_b(f64::NEG_INFINITY).is_err()); + /// assert!(nig.set_b(f64::NAN).is_err()); /// ``` #[inline] pub fn set_b(&mut self, b: f64) -> Result<(), NormalInvGammaError> { @@ -313,7 +331,7 @@ impl From<&NormalInvGamma> for String { impl_display!(NormalInvGamma); -impl Rv for NormalInvGamma { +impl HasDensity for NormalInvGamma { fn ln_f(&self, x: &Gaussian) -> f64 { // TODO: could cache the gamma and Gaussian distributions let mu = x.mu(); @@ -324,7 +342,9 @@ impl Rv for NormalInvGamma { let lnf_mu = Gaussian::new_unchecked(self.m, prior_sigma).ln_f(&mu); lnf_sigma + lnf_mu } +} +impl Sampleable for NormalInvGamma { fn draw(&self, mut rng: &mut R) -> Gaussian { // NOTE: The parameter errors in this fn shouldn't happen if the prior // parameters are valid. @@ -335,11 +355,7 @@ impl Rv for NormalInvGamma { .unwrap() .draw(&mut rng); - let sigma = if var <= 0.0 { - std::f64::EPSILON - } else { - var.sqrt() - }; + let sigma = if var <= 0.0 { f64::EPSILON } else { var.sqrt() }; let post_sigma: f64 = self.v.sqrt() * sigma; let mu: f64 = Gaussian::new(self.m, post_sigma) @@ -374,3 +390,15 @@ impl fmt::Display for NormalInvGammaError { } } } + +#[cfg(test)] +mod test { + use super::*; + use crate::test_basic_impls; + + test_basic_impls!( + Gaussian, + NormalInvGamma, + NormalInvGamma::new(0.1, 1.2, 2.3, 3.4).unwrap() + ); +} diff --git a/src/dist/normal_inv_gamma/gaussian_prior.rs b/src/dist/normal_inv_gamma/gaussian_prior.rs index e6de677..cedbe3b 100644 --- a/src/dist/normal_inv_gamma/gaussian_prior.rs +++ b/src/dist/normal_inv_gamma/gaussian_prior.rs @@ -1,9 +1,7 @@ use std::collections::BTreeMap; use crate::consts::HALF_LN_2PI; -use crate::data::{ - extract_stat, extract_stat_then, DataOrSuffStat, GaussianSuffStat, -}; +use crate::data::{extract_stat, extract_stat_then, GaussianSuffStat}; use crate::dist::{Gaussian, NormalInvGamma}; use crate::gaussian_prior_geweke_testable; use crate::misc::ln_gammafn; @@ -26,7 +24,7 @@ fn posterior_from_stat( ) -> NormalInvGamma { let n = stat.n() as f64; - let (m, v, a, b) = nig.params(); + let super::NormalInvGammaParameters { m, v, a, b } = nig.emit_params(); let v_inv = v.recip(); @@ -45,9 +43,9 @@ fn posterior_from_stat( impl ConjugatePrior for NormalInvGamma { type Posterior = Self; - type LnMCache = f64; - type LnPpCache = (GaussianSuffStat, f64); - // type LnPpCache = NormalInvGamma; + type MCache = f64; + type PpCache = (GaussianSuffStat, f64); + // type PpCache = NormalInvGamma; fn posterior(&self, x: &DataOrSuffStat) -> Self { extract_stat_then(x, GaussianSuffStat::new, |stat: GaussianSuffStat| { @@ -56,13 +54,13 @@ impl ConjugatePrior for NormalInvGamma { } #[inline] - fn ln_m_cache(&self) -> Self::LnMCache { + fn ln_m_cache(&self) -> Self::MCache { ln_z(self.v, self.a, self.b) } fn ln_m_with_cache( &self, - cache: &Self::LnMCache, + cache: &Self::MCache, x: &DataOrSuffStat, ) -> f64 { 
extract_stat_then(x, GaussianSuffStat::new, |stat: GaussianSuffStat| { @@ -75,17 +73,14 @@ impl ConjugatePrior for NormalInvGamma { } #[inline] - fn ln_pp_cache( - &self, - x: &DataOrSuffStat, - ) -> Self::LnPpCache { + fn ln_pp_cache(&self, x: &DataOrSuffStat) -> Self::PpCache { let stat = extract_stat(x, GaussianSuffStat::new); let post_n = posterior_from_stat(self, &stat); let lnz_n = ln_z(post_n.v, post_n.a, post_n.b); (stat, lnz_n) } - fn ln_pp_with_cache(&self, cache: &Self::LnPpCache, y: &f64) -> f64 { + fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &f64) -> f64 { let mut stat = cache.0.clone(); let lnz_n = cache.1; @@ -104,9 +99,18 @@ gaussian_prior_geweke_testable!(NormalInvGamma, Gaussian); mod test { use super::*; use crate::consts::LN_2PI; + use crate::dist::normal_inv_gamma::NormalInvGammaParameters; + use crate::test_conjugate_prior; const TOL: f64 = 1E-12; + test_conjugate_prior!( + f64, + Gaussian, + NormalInvGamma, + NormalInvGamma::new(0.1, 1.2, 0.5, 1.8).unwrap() + ); + #[test] fn geweke() { use crate::test::GewekeTester; @@ -126,7 +130,7 @@ mod test { // Random reference I found using the same source // https://github.com/JuliaStats/ConjugatePriors.jl/blob/master/src/normalinversegamma.jl fn ln_f_ref(gauss: &Gaussian, nig: &NormalInvGamma) -> f64 { - let (m, v, a, b) = nig.params(); + let NormalInvGammaParameters { m, v, a, b } = nig.emit_params(); let mu = gauss.mu(); let sigma = gauss.sigma(); let sig2 = sigma * sigma; @@ -142,7 +146,7 @@ mod test { } fn post_params( - xs: &Vec, + xs: &[f64], m: f64, v: f64, a: f64, @@ -172,7 +176,7 @@ mod test { // examples/dpgmm.rs) words with the NormalInvGamma prior, then we should be // good to go. fn alternate_ln_marginal( - xs: &Vec, + xs: &[f64], m: f64, v: f64, a: f64, @@ -302,7 +306,7 @@ mod test { let y: f64 = -0.3; let (m, v, a, b) = (0.0, 1.2, 2.3, 3.4); let nig = NormalInvGamma::new(m, v, a, b).unwrap(); - let ln_pp = nig.ln_pp(&y, &DataOrSuffStat::None); + let ln_pp = nig.ln_pp(&y, &DataOrSuffStat::from(&vec![])); let ln_m = nig.ln_m(&DataOrSuffStat::from(&vec![y])); assert::close(ln_pp, ln_m, TOL); } diff --git a/src/dist/pareto.rs b/src/dist/pareto.rs index 34521de..e30fa47 100644 --- a/src/dist/pareto.rs +++ b/src/dist/pareto.rs @@ -27,6 +27,26 @@ pub struct Pareto { scale: f64, } +pub struct ParetoParameters { + pub shape: f64, + pub scale: f64, +} + +impl Parameterized for Pareto { + type Parameters = ParetoParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + shape: self.shape(), + scale: self.scale(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.shape, params.scale) + } +} + #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] @@ -98,9 +118,9 @@ impl Pareto { /// assert!(pareto.set_shape(1.1).is_ok()); /// assert!(pareto.set_shape(0.0).is_err()); /// assert!(pareto.set_shape(-1.0).is_err()); - /// assert!(pareto.set_shape(std::f64::INFINITY).is_err()); - /// assert!(pareto.set_shape(std::f64::NEG_INFINITY).is_err()); - /// assert!(pareto.set_shape(std::f64::NAN).is_err()); + /// assert!(pareto.set_shape(f64::INFINITY).is_err()); + /// assert!(pareto.set_shape(f64::NEG_INFINITY).is_err()); + /// assert!(pareto.set_shape(f64::NAN).is_err()); /// ``` #[inline] pub fn set_shape(&mut self, shape: f64) -> Result<(), ParetoError> { @@ -155,9 +175,9 @@ impl Pareto { /// assert!(pareto.set_scale(1.1).is_ok()); /// 
assert!(pareto.set_scale(0.0).is_err()); /// assert!(pareto.set_scale(-1.0).is_err()); - /// assert!(pareto.set_scale(std::f64::INFINITY).is_err()); - /// assert!(pareto.set_scale(std::f64::NEG_INFINITY).is_err()); - /// assert!(pareto.set_scale(std::f64::NAN).is_err()); + /// assert!(pareto.set_scale(f64::INFINITY).is_err()); + /// assert!(pareto.set_scale(f64::NEG_INFINITY).is_err()); + /// assert!(pareto.set_scale(f64::NAN).is_err()); /// ``` #[inline] pub fn set_scale(&mut self, scale: f64) -> Result<(), ParetoError> { @@ -188,7 +208,7 @@ impl_display!(Pareto); macro_rules! impl_traits { ($kind:ty) => { - impl Rv<$kind> for Pareto { + impl HasDensity<$kind> for Pareto { fn ln_f(&self, x: &$kind) -> f64 { // TODO: cache ln(shape) and ln(scale) (self.shape + 1.0).mul_add( @@ -196,7 +216,9 @@ macro_rules! impl_traits { self.shape.mul_add(self.scale.ln(), self.shape.ln()), ) } + } + impl Sampleable<$kind> for Pareto { fn draw(&self, rng: &mut R) -> $kind { let p = rand_distr::Pareto::new(self.scale, self.shape).unwrap(); @@ -318,13 +340,12 @@ mod tests { use super::*; use crate::misc::{ks_test, linspace}; use crate::test_basic_impls; - use std::f64; const TOL: f64 = 1E-12; const KS_PVAL: f64 = 0.2; const N_TRIES: usize = 5; - test_basic_impls!([continuous] Pareto::new(1.0, 0.2).unwrap()); + test_basic_impls!(f64, Pareto, Pareto::new(1.0, 0.2).unwrap()); #[test] fn new() { diff --git a/src/dist/poisson.rs b/src/dist/poisson.rs index d59f338..6da9733 100644 --- a/src/dist/poisson.rs +++ b/src/dist/poisson.rs @@ -73,6 +73,18 @@ pub struct Poisson { ln_rate: OnceLock, } +impl Parameterized for Poisson { + type Parameters = f64; + + fn emit_params(&self) -> Self::Parameters { + self.rate() + } + + fn from_params(rate: Self::Parameters) -> Self { + Self::new_unchecked(rate) + } +} + impl PartialEq for Poisson { fn eq(&self, other: &Poisson) -> bool { self.rate == other.rate @@ -151,9 +163,9 @@ impl Poisson { /// assert!(pois.set_rate(1.1).is_ok()); /// assert!(pois.set_rate(0.0).is_err()); /// assert!(pois.set_rate(-1.0).is_err()); - /// assert!(pois.set_rate(std::f64::INFINITY).is_err()); - /// assert!(pois.set_rate(std::f64::NEG_INFINITY).is_err()); - /// assert!(pois.set_rate(std::f64::NAN).is_err()); + /// assert!(pois.set_rate(f64::INFINITY).is_err()); + /// assert!(pois.set_rate(f64::NEG_INFINITY).is_err()); + /// assert!(pois.set_rate(f64::NAN).is_err()); /// ``` #[inline] pub fn set_rate(&mut self, rate: f64) -> Result<(), PoissonError> { @@ -185,12 +197,14 @@ impl_display!(Poisson); macro_rules! impl_traits { ($kind:ty) => { - impl Rv<$kind> for Poisson { + impl HasDensity<$kind> for Poisson { fn ln_f(&self, x: &$kind) -> f64 { let kf = *x as f64; kf.mul_add(self.ln_rate(), -self.rate) - ln_fact(*x as usize) } + } + impl Sampleable<$kind> for Poisson { fn draw(&self, rng: &mut R) -> $kind { let pois = RPossion::new(self.rate).unwrap(); let x: u64 = rng.sample(pois) as u64; @@ -349,7 +363,7 @@ mod tests { .sum() } - test_basic_impls!([count] Poisson::new(0.5).unwrap()); + test_basic_impls!(u32, Poisson, Poisson::new(0.5).unwrap()); #[test] fn new() { diff --git a/src/dist/product.rs b/src/dist/product.rs deleted file mode 100644 index 8fb6a0f..0000000 --- a/src/dist/product.rs +++ /dev/null @@ -1,137 +0,0 @@ -//! Distribution over multiple data types -#[cfg(feature = "serde1")] -use serde::{Deserialize, Serialize}; - -use crate::data::Datum; -use crate::dist::Distribution; -use crate::traits::Rv; - -/// A product distribution is the distribution of independent distributions. 
-/// -/// # Notes -/// -/// The `ProductDistribution` is an abstraction around `Vec`, which allows -/// implementation of `Rv>`. -/// -/// # Example -/// -/// Create a mixture of product distributions of Categorical * Gaussian -/// -/// ``` -/// use rv::data::Datum; -/// use rv::dist::{ -/// Categorical, Gaussian, Mixture, ProductDistribution, Distribution -/// }; -/// use rv::traits::Rv; -/// -/// // NOTE: Because the ProductDistribution is an abstraction around Vec, -/// // the user must take care to get the order of distributions in each -/// // ProductDistribution correct. -/// let prod_1 = ProductDistribution::new(vec![ -/// Distribution::Categorical(Categorical::new(&[0.1, 0.9]).unwrap()), -/// Distribution::Gaussian(Gaussian::new(3.0, 1.0).unwrap()), -/// ]); -/// -/// let prod_2 = ProductDistribution::new(vec![ -/// Distribution::Categorical(Categorical::new(&[0.9, 0.1]).unwrap()), -/// Distribution::Gaussian(Gaussian::new(-3.0, 1.0).unwrap()), -/// ]); -/// -/// let prodmix = Mixture::new(vec![0.5, 0.5], vec![prod_1, prod_2]).unwrap(); -/// -/// let mut rng = rand::thread_rng(); -/// -/// let x: Datum = prodmix.draw(&mut rng); -/// let fx = prodmix.f(&x); -/// -/// println!("draw: {:?}", x); -/// println!("f(x): {}", fx); -/// ``` -#[derive(Clone, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct ProductDistribution { - dists: Vec, -} - -impl ProductDistribution { - /// Create a new product distribution - /// - /// # Example - /// - /// ``` - /// use rv::data::Datum; - /// use rv::dist::{ - /// Categorical, Gaussian, Mixture, ProductDistribution, Distribution - /// }; - /// use rv::traits::Rv; - /// - /// let prod = ProductDistribution::new(vec![ - /// Distribution::Categorical(Categorical::new(&[0.1, 0.9]).unwrap()), - /// Distribution::Gaussian(Gaussian::new(3.0, 1.0).unwrap()), - /// ]); - /// - /// let mut rng = rand::thread_rng(); - /// let x: Datum = prod.draw(&mut rng); - /// ``` - pub fn new(dists: Vec) -> Self { - Self { dists } - } -} - -impl Rv> for ProductDistribution { - fn ln_f(&self, x: &Vec) -> f64 { - self.dists - .iter() - .zip(x.iter()) - .map(|(dist, x_i)| dist.ln_f(x_i)) - .sum() - } - - fn draw(&self, rng: &mut R) -> Vec { - self.dists.iter().map(|dist| dist.draw(rng)).collect() - } -} - -impl Rv for ProductDistribution { - fn ln_f(&self, x: &Datum) -> f64 { - match x { - Datum::Compound(ref xs) => self.ln_f(xs), - _ => panic!("unsupported data type for product distribution"), - } - } - - fn draw(&self, rng: &mut R) -> Datum { - Datum::Compound(self.draw(rng)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::data::Datum; - use crate::dist::{Categorical, Distribution, Gaussian}; - - fn catgauss_mix() -> ProductDistribution { - ProductDistribution::new(vec![ - Distribution::Categorical(Categorical::new(&[0.1, 0.9]).unwrap()), - Distribution::Gaussian(Gaussian::standard()), - ]) - } - - #[test] - fn ln_f() { - let gauss = Gaussian::standard(); - let cat = Categorical::new(&[0.1, 0.9]).unwrap(); - - let x_cat = 0_u8; - let x_gauss = 1.2_f64; - - let x_prod = - Datum::Compound(vec![Datum::U8(x_cat), Datum::F64(x_gauss)]); - - let ln_f = cat.ln_f(&x_cat) + gauss.ln_f(&x_gauss); - let ln_f_prod = catgauss_mix().ln_f(&x_prod); - - assert::close(ln_f, ln_f_prod, 1e-12); - } -} diff --git a/src/dist/scaled_inv_chi_squared.rs b/src/dist/scaled_inv_chi_squared.rs index d382485..7a3de2e 100644 --- a/src/dist/scaled_inv_chi_squared.rs +++ b/src/dist/scaled_inv_chi_squared.rs @@ -1,4 +1,4 @@ -//! 
Χ-2 over x in (0, ∞) +//! Χ-2 over x in (0, ∞) #[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; @@ -36,6 +36,26 @@ pub struct ScaledInvChiSquared { ln_f_const: OnceLock, } +pub struct ScaledInvChiSquaredParameters { + pub v: f64, + pub t2: f64, +} + +impl Parameterized for ScaledInvChiSquared { + type Parameters = ScaledInvChiSquaredParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + v: self.v(), + t2: self.t2(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.v, params.t2) + } +} + impl PartialEq for ScaledInvChiSquared { fn eq(&self, other: &ScaledInvChiSquared) -> bool { self.v == other.v @@ -128,8 +148,8 @@ impl ScaledInvChiSquared { /// assert!(ix2.set_v(2.2).is_ok()); /// assert!(ix2.set_v(0.0).is_err()); /// assert!(ix2.set_v(-1.0).is_err()); - /// assert!(ix2.set_v(std::f64::NAN).is_err()); - /// assert!(ix2.set_v(std::f64::INFINITY).is_err()); + /// assert!(ix2.set_v(f64::NAN).is_err()); + /// assert!(ix2.set_v(f64::INFINITY).is_err()); /// ``` #[inline] pub fn set_v(&mut self, v: f64) -> Result<(), ScaledInvChiSquaredError> { @@ -184,8 +204,8 @@ impl ScaledInvChiSquared { /// assert!(ix2.set_t2(2.2).is_ok()); /// assert!(ix2.set_t2(0.0).is_err()); /// assert!(ix2.set_t2(-1.0).is_err()); - /// assert!(ix2.set_t2(std::f64::NAN).is_err()); - /// assert!(ix2.set_t2(std::f64::INFINITY).is_err()); + /// assert!(ix2.set_t2(f64::NAN).is_err()); + /// assert!(ix2.set_t2(f64::INFINITY).is_err()); /// ``` #[inline] pub fn set_t2(&mut self, t2: f64) -> Result<(), ScaledInvChiSquaredError> { @@ -234,14 +254,16 @@ impl_display!(ScaledInvChiSquared); macro_rules! impl_traits { ($kind:ty) => { - impl Rv<$kind> for ScaledInvChiSquared { + impl HasDensity<$kind> for ScaledInvChiSquared { fn ln_f(&self, x: &$kind) -> f64 { let x64 = f64::from(*x); let term_1 = -self.v * self.t2 / (2.0 * x64); let term_2 = self.v.mul_add(0.5, 1.0) * x64.ln(); self.ln_f_const() - self.ln_gamma_v_2() + term_1 - term_2 } + } + impl Sampleable<$kind> for ScaledInvChiSquared { fn draw(&self, rng: &mut R) -> $kind { let a = 0.5 * self.v; let b = 0.5 * self.v * self.t2; @@ -353,7 +375,11 @@ mod test { const KS_PVAL: f64 = 0.2; const N_TRIES: usize = 5; - test_basic_impls!([continuous] ScaledInvChiSquared::new(3.2, 1.4).unwrap()); + test_basic_impls!( + f64, + ScaledInvChiSquared, + ScaledInvChiSquared::new(3.2, 1.4).unwrap() + ); #[test] fn new() { diff --git a/src/dist/skellam.rs b/src/dist/skellam.rs index fdb8c45..5292d40 100644 --- a/src/dist/skellam.rs +++ b/src/dist/skellam.rs @@ -45,6 +45,26 @@ fn cache_default() -> RefCell> { RefCell::new(LruCache::new(unsafe { NonZeroUsize::new_unchecked(100) })) } +pub struct SkellamParameters { + pub mu_1: f64, + pub mu_2: f64, +} + +impl Parameterized for Skellam { + type Parameters = SkellamParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + mu_1: self.mu_1(), + mu_2: self.mu_2(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.mu_1, params.mu_2) + } +} + #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] @@ -120,9 +140,9 @@ impl Skellam { /// assert!(skel.set_mu_1(1.1).is_ok()); /// assert!(skel.set_mu_1(0.0).is_err()); /// assert!(skel.set_mu_1(-1.0).is_err()); - /// assert!(skel.set_mu_1(std::f64::INFINITY).is_err()); - /// assert!(skel.set_mu_1(std::f64::NEG_INFINITY).is_err()); - /// 
assert!(skel.set_mu_1(std::f64::NAN).is_err()); + /// assert!(skel.set_mu_1(f64::INFINITY).is_err()); + /// assert!(skel.set_mu_1(f64::NEG_INFINITY).is_err()); + /// assert!(skel.set_mu_1(f64::NAN).is_err()); /// ``` #[inline] pub fn set_mu_1(&mut self, mu_1: f64) -> Result<(), SkellamError> { @@ -177,9 +197,9 @@ impl Skellam { /// assert!(skel.set_mu_2(1.1).is_ok()); /// assert!(skel.set_mu_2(0.0).is_err()); /// assert!(skel.set_mu_2(-1.0).is_err()); - /// assert!(skel.set_mu_2(std::f64::INFINITY).is_err()); - /// assert!(skel.set_mu_2(std::f64::NEG_INFINITY).is_err()); - /// assert!(skel.set_mu_2(std::f64::NAN).is_err()); + /// assert!(skel.set_mu_2(f64::INFINITY).is_err()); + /// assert!(skel.set_mu_2(f64::NEG_INFINITY).is_err()); + /// assert!(skel.set_mu_2(f64::NAN).is_err()); /// ``` #[inline] pub fn set_mu_2(&mut self, mu_2: f64) -> Result<(), SkellamError> { @@ -222,7 +242,7 @@ impl_display!(Skellam); macro_rules! impl_traits { ($kind:ty) => { - impl Rv<$kind> for Skellam { + impl HasDensity<$kind> for Skellam { fn ln_f(&self, x: &$kind) -> f64 { let kf = f64::from(*x); let mut cache = self.bessel_iv_cache.borrow_mut(); @@ -239,13 +259,15 @@ macro_rules! impl_traits { -(self.mu_1 + self.mu_2) + (kf / 2.0).mul_add((self.mu_1 / self.mu_2).ln(), bf) } + } + impl Sampleable<$kind> for Skellam { fn draw(&self, rng: &mut R) -> $kind { let pois_1 = Poisson::new_unchecked(self.mu_1); let pois_2 = Poisson::new_unchecked(self.mu_2); let x_1: u32 = pois_1.draw(rng); let x_2: u32 = pois_2.draw(rng); - (x_1 - x_2) as $kind + (x_1 as i32 - x_2 as i32) as $kind } fn sample(&self, n: usize, rng: &mut R) -> Vec<$kind> { @@ -333,7 +355,7 @@ mod tests { const N_TRIES: usize = 5; const X2_PVAL: f64 = 0.2; - test_basic_impls!(Skellam::new(0.5, 2.0).unwrap(), 3_i32); + test_basic_impls!(i32, Skellam, Skellam::new(1.0, 2.0).unwrap()); #[test] fn new() { diff --git a/src/dist/students_t.rs b/src/dist/students_t.rs index be72c2a..125e00e 100644 --- a/src/dist/students_t.rs +++ b/src/dist/students_t.rs @@ -6,7 +6,6 @@ use crate::misc::ln_gammafn; use crate::traits::*; use rand::Rng; use std::f64::consts::PI; -use std::f64::INFINITY; use std::fmt; /// [Student's T distribution](https://en.wikipedia.org/wiki/Student%27s_t-distribution) @@ -19,6 +18,18 @@ pub struct StudentsT { v: f64, } +impl Parameterized for StudentsT { + type Parameters = f64; + + fn emit_params(&self) -> Self::Parameters { + self.v() + } + + fn from_params(v: Self::Parameters) -> Self { + Self::new_unchecked(v) + } +} + #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] @@ -30,7 +41,7 @@ pub enum StudentsTError { } impl StudentsT { - /// Create a new Student's T distribtuion with degrees of freedom, v. + /// Create a new Student's T distribution with degrees of freedom, v. #[inline] pub fn new(v: f64) -> Result { if v <= 0.0 { @@ -81,9 +92,9 @@ impl StudentsT { /// assert!(t.set_v(-1.0).is_err()); /// /// - /// assert!(t.set_v(std::f64::INFINITY).is_err()); - /// assert!(t.set_v(std::f64::NEG_INFINITY).is_err()); - /// assert!(t.set_v(std::f64::NAN).is_err()); + /// assert!(t.set_v(f64::INFINITY).is_err()); + /// assert!(t.set_v(f64::NEG_INFINITY).is_err()); + /// assert!(t.set_v(f64::NAN).is_err()); /// ``` #[inline] pub fn set_v(&mut self, v: f64) -> Result<(), StudentsTError> { @@ -120,7 +131,7 @@ impl_display!(StudentsT); macro_rules! 
impl_traits { ($kind:ty) => { - impl Rv<$kind> for StudentsT { + impl HasDensity<$kind> for StudentsT { fn ln_f(&self, x: &$kind) -> f64 { // TODO: could cache ln(pi*v) and ln_gamma(v/2) let vp1 = (self.v + 1.0) / 2.0; @@ -132,7 +143,9 @@ macro_rules! impl_traits { ); zterm + xterm } + } + impl Sampleable<$kind> for StudentsT { fn draw(&self, rng: &mut R) -> $kind { let t = rand_distr::StudentT::new(self.v).unwrap(); rng.sample(t) as $kind @@ -201,7 +214,7 @@ impl Kurtosis for StudentsT { if self.v > 4.0 { Some(6.0 / (self.v - 4.0)) } else if self.v > 2.0 { - Some(INFINITY) + Some(f64::INFINITY) } else { None } @@ -232,7 +245,7 @@ mod tests { const TOL: f64 = 1E-12; - test_basic_impls!([continuous] StudentsT::default()); + test_basic_impls!(f64, StudentsT); #[test] fn new() { diff --git a/src/dist/uniform.rs b/src/dist/uniform.rs index f3abd65..66c0c4d 100644 --- a/src/dist/uniform.rs +++ b/src/dist/uniform.rs @@ -38,6 +38,18 @@ pub struct Uniform { lnf: OnceLock, } +impl Parameterized for Uniform { + type Parameters = (f64, f64); + + fn emit_params(&self) -> Self::Parameters { + (self.a(), self.b()) + } + + fn from_params((a, b): Self::Parameters) -> Self { + Self::new_unchecked(a, b) + } +} + impl PartialEq for Uniform { fn eq(&self, other: &Uniform) -> bool { self.a == other.a && self.b == other.b @@ -168,7 +180,7 @@ impl_display!(Uniform); macro_rules! impl_traits { ($kind:ty) => { - impl Rv<$kind> for Uniform { + impl HasDensity<$kind> for Uniform { fn ln_f(&self, x: &$kind) -> f64 { let xf = f64::from(*x); if self.a <= xf && xf <= self.b { @@ -178,7 +190,9 @@ macro_rules! impl_traits { f64::NEG_INFINITY } } + } + impl Sampleable<$kind> for Uniform { fn draw(&self, rng: &mut R) -> $kind { let u = rand_distr::Uniform::new(self.a, self.b); rng.sample(u) as $kind @@ -290,7 +304,7 @@ mod tests { const KS_PVAL: f64 = 0.2; const N_TRIES: usize = 5; - test_basic_impls!([continuous] Uniform::default()); + test_basic_impls!(f64, Uniform); #[test] fn new() { diff --git a/src/dist/unit_powerlaw.rs b/src/dist/unit_powerlaw.rs index e39d967..0573727 100644 --- a/src/dist/unit_powerlaw.rs +++ b/src/dist/unit_powerlaw.rs @@ -12,6 +12,8 @@ use std::f64; use std::fmt; use std::sync::OnceLock; +pub mod bernoulli_prior; + /// UnitPowerLaw(α) over x in (0, 1). 
/// /// # Examples @@ -39,6 +41,18 @@ pub struct UnitPowerLaw { alpha_ln: OnceLock, } +impl Parameterized for UnitPowerLaw { + type Parameters = f64; + + fn emit_params(&self) -> Self::Parameters { + self.alpha() + } + + fn from_params(alpha: Self::Parameters) -> Self { + Self::new_unchecked(alpha) + } +} + impl PartialEq for UnitPowerLaw { fn eq(&self, other: &UnitPowerLaw) -> bool { self.alpha == other.alpha @@ -142,8 +156,8 @@ impl UnitPowerLaw { /// assert!(powlaw.set_alpha(0.1).is_ok()); /// assert!(powlaw.set_alpha(0.0).is_err()); /// assert!(powlaw.set_alpha(-1.0).is_err()); - /// assert!(powlaw.set_alpha(std::f64::INFINITY).is_err()); - /// assert!(powlaw.set_alpha(std::f64::NAN).is_err()); + /// assert!(powlaw.set_alpha(f64::INFINITY).is_err()); + /// assert!(powlaw.set_alpha(f64::NAN).is_err()); /// ``` #[inline] pub fn set_alpha(&mut self, alpha: f64) -> Result<(), UnitPowerLawError> { @@ -167,13 +181,13 @@ impl UnitPowerLaw { /// Evaluate or fetch cached ln(a*b) #[inline] - fn alpha_inv(&self) -> f64 { + pub fn alpha_inv(&self) -> f64 { *self.alpha_inv.get_or_init(|| self.alpha.recip()) } /// Evaluate or fetch cached ln(a*b) #[inline] - fn alpha_ln(&self) -> f64 { + pub fn alpha_ln(&self) -> f64 { *self.alpha_ln.get_or_init(|| self.alpha.ln()) } } @@ -200,11 +214,13 @@ impl_display!(UnitPowerLaw); macro_rules! impl_traits { ($kind:ty) => { - impl Rv<$kind> for UnitPowerLaw { + impl HasDensity<$kind> for UnitPowerLaw { fn ln_f(&self, x: &$kind) -> f64 { (*x as f64).ln().mul_add(self.alpha - 1.0, self.alpha_ln()) } + } + impl Sampleable<$kind> for UnitPowerLaw { fn draw(&self, rng: &mut R) -> $kind { self.invcdf(rng.gen::()) } @@ -330,13 +346,12 @@ mod tests { use super::*; use crate::misc::ks_test; use crate::test_basic_impls; - use std::f64; const TOL: f64 = 1E-12; const KS_PVAL: f64 = 0.2; const N_TRIES: usize = 5; - test_basic_impls!([continuous] UnitPowerLaw::new(1.5).unwrap()); + test_basic_impls!(f64, UnitPowerLaw, UnitPowerLaw::new(1.5).unwrap()); #[test] fn new() { @@ -413,7 +428,7 @@ mod tests { } #[test] - fn draw_should_resturn_values_within_0_to_1() { + fn draw_should_return_values_within_0_to_1() { let mut rng = rand::thread_rng(); let powlaw = UnitPowerLaw::new(2.0).unwrap(); for _ in 0..100 { diff --git a/src/dist/unit_powerlaw/bernoulli_prior.rs b/src/dist/unit_powerlaw/bernoulli_prior.rs new file mode 100644 index 0000000..a857bf7 --- /dev/null +++ b/src/dist/unit_powerlaw/bernoulli_prior.rs @@ -0,0 +1,110 @@ +use rand::Rng; +use special::Beta as SBeta; + +use crate::data::{BernoulliSuffStat, Booleable}; +use crate::dist::{Bernoulli, Beta, UnitPowerLaw}; +use crate::traits::*; + +impl HasDensity for UnitPowerLaw { + fn ln_f(&self, x: &Bernoulli) -> f64 { + self.ln_f(&x.p()) + } +} + +impl Sampleable for UnitPowerLaw { + fn draw(&self, mut rng: &mut R) -> Bernoulli { + let p: f64 = self.draw(&mut rng); + Bernoulli::new(p).expect("Failed to draw valid weight") + } +} + +impl Support for UnitPowerLaw { + fn supports(&self, x: &Bernoulli) -> bool { + 0.0 < x.p() && x.p() < 1.0 + } +} + +impl ContinuousDistr for UnitPowerLaw {} + +impl ConjugatePrior for UnitPowerLaw { + type Posterior = Beta; + type MCache = f64; + type PpCache = (f64, f64); + + #[allow(clippy::many_single_char_names)] + fn posterior(&self, x: &DataOrSuffStat) -> Beta { + let (n, k) = match x { + DataOrSuffStat::Data(xs) => { + let mut stat = BernoulliSuffStat::new(); + xs.iter().for_each(|x| stat.observe(x)); + (stat.n(), stat.k()) + } + DataOrSuffStat::SuffStat(stat) => (stat.n(), stat.k()), + }; 
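+        // `UnitPowerLaw(alpha)` has density `alpha * p^(alpha - 1)` on (0, 1),
+        // i.e. it is `Beta(alpha, 1)`. Conjugate updating with `k` successes
+        // in `n` trials therefore gives `Beta(alpha + k, 1 + (n - k))`, which
+        // is exactly what `a` and `b` compute below.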
+ + let a = self.alpha() + k as f64; + let b = (1 + (n - k)) as f64; + + Beta::new(a, b).expect("Invalid posterior parameters") + } + + #[inline] + fn ln_m_cache(&self) -> Self::MCache { + -self.alpha_ln() + } + + fn ln_m_with_cache( + &self, + cache: &Self::MCache, + x: &DataOrSuffStat, + ) -> f64 { + let post = self.posterior(x); + post.alpha().ln_beta(post.beta()) - cache + } + + #[inline] + fn ln_pp_cache(&self, x: &DataOrSuffStat) -> Self::PpCache { + // P(y=1 | xs) happens to be the posterior mean + let post = self.posterior(x); + let p: f64 = post.mean().expect("Mean undefined"); + (p.ln(), (1.0 - p).ln()) + } + + fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &X) -> f64 { + // P(y=1 | xs) happens to be the posterior mean + if y.into_bool() { + cache.0 + } else { + cache.1 + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const TOL: f64 = 1E-12; + + #[test] + fn posterior_from_data_bool() { + let data = vec![false, true, false, true, true]; + let xs = DataOrSuffStat::Data::(&data); + + let posterior = UnitPowerLaw::new(1.0).unwrap().posterior(&xs); + + assert::close(posterior.alpha(), 4.0, TOL); + assert::close(posterior.beta(), 3.0, TOL); + } + + #[test] + fn posterior_from_data_u16() { + let data: Vec = vec![0, 1, 0, 1, 1]; + let xs = DataOrSuffStat::Data::(&data); + + let posterior = UnitPowerLaw::new(1.0).unwrap().posterior(&xs); + + assert::close(posterior.alpha(), 4.0, TOL); + assert::close(posterior.beta(), 3.0, TOL); + } +} diff --git a/src/dist/vonmises.rs b/src/dist/vonmises.rs index 084fb59..15f96cd 100644 --- a/src/dist/vonmises.rs +++ b/src/dist/vonmises.rs @@ -9,7 +9,7 @@ use rand::Rng; use std::f64::consts::PI; use std::fmt; -/// [VonMises distirbution](https://en.wikipedia.org/wiki/Von_Mises_distribution) +/// [VonMises distribution](https://en.wikipedia.org/wiki/Von_Mises_distribution) /// on the circular interval (0, 2π] /// /// # Example @@ -40,6 +40,26 @@ pub struct VonMises { i0_k: f64, } +pub struct VonMisesParameters { + pub mu: f64, + pub k: f64, +} + +impl Parameterized for VonMises { + type Parameters = VonMisesParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + mu: self.mu(), + k: self.k(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.mu, params.k) + } +} + #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] @@ -118,9 +138,9 @@ impl VonMises { /// assert!(vm.set_mu(0.0 - 0.001).is_err()); /// assert!(vm.set_mu(2.0 * std::f64::consts::PI + 0.001).is_err()); /// - /// assert!(vm.set_mu(std::f64::NEG_INFINITY).is_err()); - /// assert!(vm.set_mu(std::f64::INFINITY).is_err()); - /// assert!(vm.set_mu(std::f64::NAN).is_err()); + /// assert!(vm.set_mu(f64::NEG_INFINITY).is_err()); + /// assert!(vm.set_mu(f64::INFINITY).is_err()); + /// assert!(vm.set_mu(f64::NAN).is_err()); /// ``` #[inline] pub fn set_mu(&mut self, mu: f64) -> Result<(), VonMisesError> { @@ -182,9 +202,9 @@ impl VonMises { /// assert!(vm.set_k(0.0).is_err()); /// assert!(vm.set_k(-1.0).is_err()); /// - /// assert!(vm.set_k(std::f64::INFINITY).is_err()); - /// assert!(vm.set_k(std::f64::NEG_INFINITY).is_err()); - /// assert!(vm.set_k(std::f64::NAN).is_err()); + /// assert!(vm.set_k(f64::INFINITY).is_err()); + /// assert!(vm.set_k(f64::NEG_INFINITY).is_err()); + /// assert!(vm.set_k(f64::NAN).is_err()); /// ``` #[inline] pub fn set_k(&mut self, k: f64) -> Result<(), VonMisesError> { @@ -222,13 
+242,15 @@ impl_display!(VonMises); macro_rules! impl_traits { ($kind:ty) => { - impl Rv<$kind> for VonMises { + impl HasDensity<$kind> for VonMises { fn ln_f(&self, x: &$kind) -> f64 { // TODO: could also cache ln(i0_k) let xf = f64::from(*x); self.k.mul_add((xf - self.mu).cos(), -LN_2PI) - self.i0_k.ln() } + } + impl Sampleable<$kind> for VonMises { // Best, D. J., & Fisher, N. I. (1979). Efficient simulation of the // von Mises distribution. Applied Statistics, 152-157. // https://www.researchgate.net/publication/246035131_Efficient_Simulation_of_the_von_Mises_Distribution @@ -349,13 +371,12 @@ mod tests { use super::*; use crate::misc::ks_test; use crate::test_basic_impls; - use std::f64::EPSILON; const TOL: f64 = 1E-12; const KS_PVAL: f64 = 0.2; const N_TRIES: usize = 5; - test_basic_impls!([continuous] VonMises::default()); + test_basic_impls!(f64, VonMises); #[test] fn new_should_allow_mu_in_0_2pi() { @@ -367,7 +388,7 @@ mod tests { #[test] fn new_should_not_allow_mu_outside_0_2pi() { assert!(VonMises::new(-PI, 1.0).is_err()); - assert!(VonMises::new(-EPSILON, 1.0).is_err()); + assert!(VonMises::new(-f64::EPSILON, 1.0).is_err()); assert!(VonMises::new(2.0_f64.mul_add(PI, 0.001), 1.0).is_err()); assert!(VonMises::new(100.0, 1.0).is_err()); } diff --git a/src/dist/wishart.rs b/src/dist/wishart.rs index a90941f..26866e6 100644 --- a/src/dist/wishart.rs +++ b/src/dist/wishart.rs @@ -21,6 +21,26 @@ pub struct InvWishart { df: usize, } +pub struct InvWishartParameters { + pub inv_scale: DMatrix, + pub df: usize, +} + +impl Parameterized for InvWishart { + type Parameters = InvWishartParameters; + + fn emit_params(&self) -> Self::Parameters { + Self::Parameters { + inv_scale: self.inv_scale().clone_owned(), + df: self.df(), + } + } + + fn from_params(params: Self::Parameters) -> Self { + Self::new_unchecked(params.inv_scale, params.df) + } +} + #[derive(Debug, Clone, PartialEq, Eq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] @@ -137,7 +157,7 @@ impl InvWishart { } } -impl Rv> for InvWishart { +impl HasDensity> for InvWishart { fn ln_f(&self, x: &DMatrix) -> f64 { let p = self.inv_scale.nrows(); let pf = p as f64; @@ -154,7 +174,9 @@ impl Rv> for InvWishart { det_s - denom + det_x + numer } +} +impl Sampleable> for InvWishart { // XXX: The complexity of this is O(df * dims^2). There is a O(dims^2) // algorithm, but it's more complicated to implement, so standby. // See https://www.math.wustl.edu/~sawyer/hmhandouts/Wishart.pdf for more @@ -177,8 +199,8 @@ impl Rv> for InvWishart { let p = self.inv_scale.nrows(); let scale = self.inv_scale.clone().try_inverse().unwrap(); let mvg = MvGaussian::new_unchecked(DVector::zeros(p), scale); - (0..n) - .map(|_| { + (0..) 
+ .filter_map(|_| { let xs = mvg.sample(self.df, &mut rng); let y = xs.iter().fold( DMatrix::::zeros(p, p), @@ -187,8 +209,9 @@ impl Rv> for InvWishart { acc + x * x.transpose() }, ); - y.try_inverse().unwrap() + y.try_inverse() }) + .take(n) .collect() } } @@ -246,7 +269,7 @@ mod tests { const TOL: f64 = 1E-12; - test_basic_impls!(InvWishart::identity(3), DMatrix::identity(3, 3)); + test_basic_impls!(DMatrix, InvWishart, InvWishart::identity(3)); #[test] fn new_should_reject_df_too_low() { @@ -331,7 +354,7 @@ mod tests { ]; let inv_scale: DMatrix = DMatrix::from_row_slice(4, 4, &slice); let iw = InvWishart::new(inv_scale, 5).unwrap(); - for x in >>::sample::< + for x in >>::sample::< rand::rngs::ThreadRng, >(&iw, 100, &mut rng) { diff --git a/src/experimental/mod.rs b/src/experimental/mod.rs new file mode 100644 index 0000000..56bfeb4 --- /dev/null +++ b/src/experimental/mod.rs @@ -0,0 +1 @@ +pub mod stick_breaking_process; diff --git a/src/experimental/stick_breaking_process/mod.rs b/src/experimental/stick_breaking_process/mod.rs new file mode 100644 index 0000000..8127d5a --- /dev/null +++ b/src/experimental/stick_breaking_process/mod.rs @@ -0,0 +1,11 @@ +pub mod sbd; +pub mod sbd_stat; +pub mod stick_breaking; +pub mod stick_breaking_stat; +pub mod stick_sequence; + +pub use sbd::StickBreakingDiscrete; +pub use sbd_stat::StickBreakingDiscreteSuffStat; +pub use stick_breaking::{BreakSequence, PartialWeights, StickBreaking}; +// pub use stick_breaking_stat::*; +pub use stick_sequence::StickSequence; diff --git a/src/experimental/stick_breaking_process/sbd.rs b/src/experimental/stick_breaking_process/sbd.rs new file mode 100644 index 0000000..619c12a --- /dev/null +++ b/src/experimental/stick_breaking_process/sbd.rs @@ -0,0 +1,319 @@ +use super::StickSequence; +use crate::dist::Mixture; +use crate::misc::sorted_uniforms; +use crate::misc::ConvergentSequence; +use crate::traits::*; +use rand::seq::SliceRandom; +use rand::Rng; +#[cfg(feature = "serde1")] +use serde::{Deserialize, Serialize}; + +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[derive(Clone, Debug, PartialEq)] +/// A "Stick-breaking discrete" distribution parameterized by a StickSequence. +pub struct StickBreakingDiscrete { + sticks: StickSequence, +} + +impl StickBreakingDiscrete { + /// Creates a new instance of StickBreakingDiscrete with the specified StickSequence. + /// + /// # Arguments + /// + /// * `sticks` - The StickSequence used for generating random numbers. + /// + /// # Returns + /// + /// A new instance of StickBreakingDiscrete. + pub fn new(sticks: StickSequence) -> StickBreakingDiscrete { + Self { sticks } + } + + /// Calculates the inverse complementary cumulative distribution function + /// (invccdf) for the StickBreakingDiscrete distribution. This method is preferred over the + /// traditional cumulative distribution function (cdf) as it provides higher precision in the + /// tail regions of the distribution. + /// + /// # Arguments + /// + /// * `p` - The probability value for which to calculate the invccdf. + /// + /// # Returns + /// + /// The index of the first element in the StickSequence whose cumulative probability is less + /// than `p`. + pub fn invccdf(&self, p: f64) -> usize { + debug_assert!(p > 0.0 && p < 1.0); + self.sticks.extendmap_ccdf( + |ccdf| ccdf.last().unwrap() < &p, + |ccdf| ccdf.iter().position(|q| *q < p).unwrap() - 1, + ) + } + + /// Provides a reference to the StickSequence used by the StickBreakingDiscrete distribution. 
+ /// + /// # Returns + /// + /// A reference to the StickSequence. + pub fn stick_sequence(&self) -> &StickSequence { + &self.sticks + } + + /// Calculates the inverse complementary cumulative distribution function (invccdf) for + /// multiple sorted values. This method is useful for efficiently computing the invccdf for a + /// sequence of values that are already sorted in ascending order. The returned vector contains + /// the indices of the StickSequence elements whose cumulative probabilities are less than the + /// corresponding values in `ps`. + /// + /// # Arguments + /// + /// * `ps` - A slice of probability values for which to calculate the invccdf. The values must + /// be sorted in ascending order. + /// + /// # Returns + /// + /// A vector containing the indices of the StickSequence elements whose cumulative probabilities + /// are less than the corresponding values in `ps`. + pub fn multi_invccdf_sorted(&self, ps: &[f64]) -> Vec { + let n = ps.len(); + self.sticks.extendmap_ccdf( + // Note that ccdf is decreasing, but ps is increasing + |ccdf| ccdf.last().unwrap() < ps.first().unwrap(), + |ccdf| { + let mut result: Vec = Vec::with_capacity(n); + + // Start at the end of the sorted probability values (the largest value) + let mut i: usize = n - 1; + for q in ccdf.iter().skip(1).enumerate() { + while ps[i] > *q.1 { + result.push(q.0); + if i == 0 { + break; + } else { + i -= 1; + } + } + } + result + }, + ) + } +} + +/// Implementation of the `Support` trait for `StickBreakingDiscrete`. +impl Support for StickBreakingDiscrete { + /// Checks if the given value is supported by `StickBreakingDiscrete`. + /// + /// # Arguments + /// + /// * `x` - The value to be checked. + /// + /// # Returns + /// + /// Returns `true` for all values as `StickBreakingDiscrete` supports all `usize` values, `false` otherwise. + fn supports(&self, _: &usize) -> bool { + true + } +} + +/// Implementation of the `Cdf` trait for `StickBreakingDiscrete`. +impl Cdf for StickBreakingDiscrete { + /// Calculates the survival function (SF) for a given value `x`. + /// + /// The survival function is defined as 1 minus the cumulative distribution function (CDF). + /// It represents the probability that a random variable is greater than `x`. + /// + /// # Arguments + /// + /// * `x` - The value for which to calculate the survival function. + /// + /// # Returns + /// + /// The calculated survival function value as a `f64`. + fn sf(&self, x: &usize) -> f64 { + self.sticks.ccdf(*x + 1) + } + + /// Calculates the cumulative distribution function (CDF) for a given value `x`. + /// + /// The cumulative distribution function (CDF) represents the probability that a random variable + /// is less than or equal to `x`. + /// + /// # Arguments + /// + /// * `x` - The value for which to calculate the cumulative distribution function. + /// + /// # Returns + /// + /// The calculated cumulative distribution function value as a `f64`. + fn cdf(&self, x: &usize) -> f64 { + 1.0 - self.sf(x) + } +} + +impl InverseCdf for StickBreakingDiscrete { + /// Calculates the inverse cumulative distribution function (invcdf) for a given probability `p`. + /// + /// The inverse cumulative distribution function (invcdf) represents the value below which a random variable + /// falls with probability `p`. + /// + /// # Arguments + /// + /// * `p` - The probability value for which to calculate the invcdf. + /// + /// # Returns + /// + /// The calculated invcdf value as a `usize`. 
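+    ///
+    /// # Example
+    ///
+    /// An illustrative sketch (marked `ignore`, so it is not compiled as a
+    /// doctest); it assumes the experimental API shown in this module and the
+    /// usual `rv::prelude` re-exports.
+    ///
+    /// ```ignore
+    /// use rv::experimental::stick_breaking_process::{
+    ///     StickBreakingDiscrete, StickSequence,
+    /// };
+    /// use rv::prelude::*;
+    ///
+    /// let sticks = StickSequence::new(UnitPowerLaw::new(10.0).unwrap(), None);
+    /// let sbd = StickBreakingDiscrete::new(sticks);
+    /// // By definition, invcdf(p) delegates to invccdf(1 - p).
+    /// assert_eq!(sbd.invcdf(0.3), sbd.invccdf(0.7));
+    /// ```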
+ fn invcdf(&self, p: f64) -> usize { + self.invccdf(1.0 - p) + } +} + +impl DiscreteDistr for StickBreakingDiscrete {} + +impl Mode for StickBreakingDiscrete { + /// Calculates the mode of the `StickBreakingDiscrete` distribution. + /// + /// The mode is the value that appears most frequently in a data set or probability distribution. + /// + /// # Returns + /// + /// The mode of the distribution as an `Option`. Returns `None` if the mode cannot be determined. + fn mode(&self) -> Option { + let w0 = self.sticks.weight(0); + // Once the unallocated mass is less than that of the first stick, the + // allocated mass is guaranteed to contain the mode. + let n = self.sticks.extendmap_ccdf( + |ccdf| ccdf.last().unwrap() < &w0, + |ccdf| { + let weights: Vec = + ccdf.windows(2).map(|qs| qs[0] - qs[1]).collect(); + weights + .iter() + .enumerate() + .max_by(|a, b| a.1.partial_cmp(b.1).unwrap()) + .map(|(i, _)| i) + }, + ); + n + } +} + +/// Provides density and log-density functions for StickBreakingDiscrete. +impl HasDensity for StickBreakingDiscrete { + /// Computes the density of a given stick index. + /// + /// # Arguments + /// + /// * `n` - The index of the stick. + /// + /// # Returns + /// + /// The density of the stick at index `n`. + fn f(&self, n: &usize) -> f64 { + let sticks = &self.sticks; + sticks.weight(*n) + } + + /// Computes the natural logarithm of the density of a given stick index. + /// + /// # Arguments + /// + /// * `n` - The index of the stick. + /// + /// # Returns + /// + /// The natural logarithm of the density of the stick at index `n`. + fn ln_f(&self, n: &usize) -> f64 { + self.f(n).ln() + } +} + +/// Enables sampling from StickBreakingDiscrete. +impl Sampleable for StickBreakingDiscrete { + /// Draws a single sample from the distribution. + /// + /// # Type Parameters + /// + /// * `R` - The random number generator type. + /// + /// # Arguments + /// + /// * `rng` - A mutable reference to the random number generator. + /// + /// # Returns + /// + /// A single sample as a usize. + fn draw(&self, rng: &mut R) -> usize { + let u: f64 = rng.gen(); + self.invccdf(u) + } + + /// Draws multiple samples from the distribution and shuffles them. + /// + /// # Type Parameters + /// + /// * `R` - The random number generator type. + /// + /// # Arguments + /// + /// * `n` - The number of samples to draw. + /// * `rng` - A mutable reference to the random number generator. + /// + /// # Returns + /// + /// A vector of usize samples, shuffled. + fn sample(&self, n: usize, mut rng: &mut R) -> Vec { + let ps = sorted_uniforms(n, &mut rng); + let mut result = self.multi_invccdf_sorted(&ps); + + // At this point `result` is sorted, so we need to shuffle it. 
+ // Note that shuffling is O(n) but sorting is O(n log n) + result.shuffle(&mut rng); + result + } +} + +impl Entropy for StickBreakingDiscrete { + fn entropy(&self) -> f64 { + let probs = (0..).map(|n| self.f(&n)); + probs + .map(|p| p * p.ln()) + .scan(0.0, |state, x| { + *state -= x; + Some(*state) + }) + .limit(1e-10) + } +} + +impl Entropy for &Mixture { + fn entropy(&self) -> f64 { + let probs = (0..).map(|n| self.f(&n)); + probs + .map(|p| p * p.ln()) + .scan(0.0, |state, x| { + *state -= x; + Some(*state) + }) + .limit(1e-10) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::prelude::*; + use rand::thread_rng; + + #[test] + fn test_multi_invccdf_sorted() { + let sticks = StickSequence::new(UnitPowerLaw::new(10.0).unwrap(), None); + let sbd = StickBreakingDiscrete::new(sticks); + let ps = sorted_uniforms(5, &mut thread_rng()); + assert_eq!( + sbd.multi_invccdf_sorted(&ps), + ps.iter().rev().map(|p| sbd.invccdf(*p)).collect::>() + ) + } +} diff --git a/src/experimental/stick_breaking_process/sbd_stat.rs b/src/experimental/stick_breaking_process/sbd_stat.rs new file mode 100644 index 0000000..3942712 --- /dev/null +++ b/src/experimental/stick_breaking_process/sbd_stat.rs @@ -0,0 +1,205 @@ +use crate::experimental::stick_breaking_process::sbd::StickBreakingDiscrete; +use crate::traits::{HasSuffStat, SuffStat}; +#[cfg(feature = "serde1")] +use serde::{Deserialize, Serialize}; + +/// Represents the sufficient statistics for a Stick-Breaking Discrete distribution. +/// +/// This struct encapsulates the sufficient statistics for a Stick-Breaking Discrete distribution, +/// primarily involving a vector of counts representing the observed data. +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] +#[derive(Clone, Debug, PartialEq)] +pub struct StickBreakingDiscreteSuffStat { + /// A vector of counts for observed data. + /// + /// Each element represents the count of observations for a given category. + counts: Vec, +} + +impl StickBreakingDiscreteSuffStat { + /// Constructs a new instance. + /// + /// Initializes a new `StickBreakingDiscreteSuffStat` with an empty vector of counts. + /// + /// # Returns + /// + /// A new `StickBreakingDiscreteSuffStat` instance. + pub fn new() -> Self { + Self { counts: Vec::new() } + } + + /// Calculates break pairs for probabilities. + /// + /// Returns a vector of pairs where each pair consists of the sum of all counts after the current index and the count at the current index. + /// + /// # Returns + /// + /// A vector of `(usize, usize)` pairs for calculating probabilities. + pub fn break_pairs(&self) -> Vec<(usize, usize)> { + let mut s = self.counts.iter().sum(); + self.counts + .iter() + .map(|&x| { + s -= x; + (s, x) + }) + .collect() + } + + /// Provides read-only access to counts. + /// + /// # Returns + /// + /// A reference to the vector of counts. + pub fn counts(&self) -> &Vec { + &self.counts + } +} + +impl From<&[usize]> for StickBreakingDiscreteSuffStat { + /// Constructs from a slice of counts. + /// + /// Allows creation from a slice of counts, converting raw observation data into a sufficient statistic. + /// + /// # Arguments + /// + /// * `data` - A slice of counts. + /// + /// # Returns + /// + /// A new `StickBreakingDiscreteSuffStat` instance. 
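+    ///
+    /// # Example
+    ///
+    /// A hedged sketch (marked `ignore`, so it is not compiled as a doctest).
+    /// Note that each element of `data` is treated as an observed category
+    /// index, not as a pre-tallied count.
+    ///
+    /// ```ignore
+    /// use rv::experimental::stick_breaking_process::StickBreakingDiscreteSuffStat;
+    /// use rv::traits::SuffStat;
+    ///
+    /// // Observations 0, 0, 2 are tallied into the counts [2, 0, 1].
+    /// let stat = StickBreakingDiscreteSuffStat::from(&[0_usize, 0, 2][..]);
+    /// assert_eq!(stat.counts(), &vec![2, 0, 1]);
+    /// assert_eq!(stat.n(), 3);
+    /// ```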
+ fn from(data: &[usize]) -> Self { + let mut stat = StickBreakingDiscreteSuffStat::new(); + stat.observe_many(data); + stat + } +} + +impl Default for StickBreakingDiscreteSuffStat { + /// Returns a default instance. + /// + /// Equivalent to `new()`, for APIs requiring a default constructor. + /// + /// # Returns + /// + /// A default `StickBreakingDiscreteSuffStat` instance. + fn default() -> Self { + Self::new() + } +} + +impl HasSuffStat for StickBreakingDiscrete { + type Stat = StickBreakingDiscreteSuffStat; + + /// Initializes an empty sufficient statistic. + /// + /// # Returns + /// + /// An empty `StickBreakingDiscreteSuffStat`. + fn empty_suffstat(&self) -> Self::Stat { + Self::Stat::new() + } + + /// Calculates the log probability density of observed data. + /// + /// # Arguments + /// + /// * `stat` - A reference to the sufficient statistic. + /// + /// # Returns + /// + /// The natural logarithm of the probability of the observed data. + fn ln_f_stat(&self, stat: &Self::Stat) -> f64 { + self.stick_sequence() + .weights(stat.counts.len()) + .0 + .iter() + .zip(stat.counts.iter()) + .map(|(w, c)| (*c as f64) * w.ln()) + .sum() + } +} + +impl SuffStat for StickBreakingDiscreteSuffStat { + /// Returns the total count of observations. + /// + /// # Returns + /// + /// The total count of all observed data. + fn n(&self) -> usize { + self.counts.iter().sum() + } + + /// Updates the statistic with a new observation. + /// + /// # Arguments + /// + /// * `i` - The index at which to increment the count. + fn observe(&mut self, i: &usize) { + if self.counts.len() < *i + 1 { + self.counts.resize(*i + 1, 0) + } + self.counts[*i] += 1; + } + + /// Removes a previously observed data point. + /// + /// # Arguments + /// + /// * `i` - The index at which to decrement the count. + /// + /// # Panics + /// + /// Panics if there are no observations of the specified category to forget. 
+ fn forget(&mut self, i: &usize) { + assert!(self.counts[*i] > 0, "No observations of {i} to forget."); + self.counts[*i] -= 1; + } +} +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_break_pairs() { + let suff_stat = StickBreakingDiscreteSuffStat { + counts: vec![1, 2, 3], + }; + + let pairs = suff_stat.break_pairs(); + assert_eq!(pairs, vec![(5, 1), (3, 2), (0, 3)]); + } + + // #[test] + // fn test_ln_f_stat() { + // let sbd = StickBreakingDiscrete::new(); + // let suff_stat = StickBreakingDiscreteSuffStat { + // counts: vec![1, 2, 3], + // }; + + // let ln_f_stat = sbd.ln_f_stat(&suff_stat); + // assert_eq!(ln_f_stat, 2.1972245773362196); // Replace with the expected value + // } + + #[test] + fn test_observe_and_forget() { + let mut suff_stat = StickBreakingDiscreteSuffStat::new(); + + suff_stat.observe(&1); + suff_stat.observe(&2); + suff_stat.observe(&2); + suff_stat.forget(&2); + + assert_eq!(suff_stat.counts, vec![0, 1, 1]); + assert_eq!(suff_stat.n(), 2); + } + + #[test] + fn test_new_is_default() { + assert!( + StickBreakingDiscreteSuffStat::new() + == StickBreakingDiscreteSuffStat::default() + ); + } +} diff --git a/src/experimental/stick_breaking_process/stick_breaking.rs b/src/experimental/stick_breaking_process/stick_breaking.rs new file mode 100644 index 0000000..b5456d1 --- /dev/null +++ b/src/experimental/stick_breaking_process/stick_breaking.rs @@ -0,0 +1,580 @@ +use crate::experimental::stick_breaking_process::StickBreakingDiscrete; +use crate::experimental::stick_breaking_process::StickBreakingDiscreteSuffStat; +// use crate::experimental::stick_breaking_process::StickBreakingSuffStat; +use crate::experimental::stick_breaking_process::StickSequence; +use crate::prelude::*; +use crate::traits::*; +use itertools::Either; +use itertools::EitherOrBoth::{Both, Left, Right}; +use itertools::Itertools; +use rand::Rng; +use special::Beta as BetaFn; + +#[cfg(feature = "serde1")] +use serde::{Deserialize, Serialize}; + +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] +#[derive(Clone, Debug, PartialEq)] +/// Represents a stick-breaking process. +pub struct StickBreaking { + break_prefix: Vec, + break_tail: UnitPowerLaw, +} + +/// Implementation of the `StickBreaking` struct. +impl StickBreaking { + /// Creates a new instance of `StickBreaking` with the given `breaker`. + /// + /// # Arguments + /// + /// * `breaker` - The `UnitPowerLaw` used for stick breaking. + /// + /// # Returns + /// + /// A new instance of `StickBreaking`. + /// + /// # Example + /// ``` + /// use rv::prelude::*; + /// use rv::experimental::stick_breaking_process::StickBreaking; + /// + /// let alpha = 5.0; + /// let stick_breaking = StickBreaking::new(UnitPowerLaw::new(alpha).unwrap()); + /// ``` + pub fn new(breaker: UnitPowerLaw) -> Self { + let break_prefix = Vec::new(); + Self { + break_prefix, + break_tail: breaker, + } + } + + pub fn from_alpha(alpha: f64) -> Result { + let breaker = UnitPowerLaw::new(alpha)?; + Ok(Self::new(breaker)) + } + + /// Sets the alpha parameter for both the break_tail and all Beta distributions in break_prefix. + /// + /// # Arguments + /// + /// * `alpha` - The new alpha value to set. + /// + /// # Returns + /// + /// A result indicating success or containing a `UnitPowerLawError` if setting alpha on `break_tail` fails, + /// or a `BetaError` if setting alpha on any `Beta` distribution in `break_prefix` fails. 
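+    ///
+    /// # Example
+    ///
+    /// A rough sketch of the intended behaviour (marked `ignore`): the tail
+    /// gets the new alpha, and every `Beta` in the prefix is shifted by the
+    /// same difference.
+    ///
+    /// ```ignore
+    /// use rv::experimental::stick_breaking_process::StickBreaking;
+    /// use rv::prelude::UnitPowerLaw;
+    ///
+    /// let mut sb = StickBreaking::new(UnitPowerLaw::new(3.0).unwrap());
+    /// sb.set_alpha(2.0).unwrap();
+    /// assert_eq!(sb.alpha(), 2.0);
+    /// ```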
+ pub fn set_alpha(&mut self, alpha: f64) -> Result<(), BetaError> { + let old_alpha = self.alpha(); + self.break_tail.set_alpha(alpha).map_err(|e| match e { + UnitPowerLawError::AlphaNotFinite { alpha } => { + BetaError::AlphaNotFinite { alpha } + } + UnitPowerLawError::AlphaTooLow { alpha } => { + BetaError::AlphaTooLow { alpha } + } + })?; + let d_alpha = alpha - old_alpha; + for b in &mut self.break_prefix { + b.set_alpha(b.alpha() + d_alpha)?; + } + Ok(()) + } + + pub fn break_prefix(&self) -> &Vec { + &self.break_prefix + } + + pub fn break_tail(&self) -> &UnitPowerLaw { + &self.break_tail + } + + pub fn break_dists( + &self, + ) -> impl Iterator> { + self.break_prefix + .iter() + .map(Either::Left) + .chain(std::iter::repeat(Either::Right(&self.break_tail))) + } + + pub fn alpha(&self) -> f64 { + self.break_tail.alpha() + } +} + +pub struct PartialWeights(pub Vec); +pub struct BreakSequence(pub Vec); + +impl From<&BreakSequence> for PartialWeights { + fn from(bs: &BreakSequence) -> Self { + let mut remaining = 1.0; + let ws = + bs.0.iter() + .map(|b| { + debug_assert!((0.0..=1.0).contains(b)); + let w = (1.0 - b) * remaining; + debug_assert!((0.0..=1.0).contains(&w)); + remaining -= w; + debug_assert!((0.0..=1.0).contains(&remaining)); + w + }) + .collect(); + PartialWeights(ws) + } +} + +impl From<&PartialWeights> for BreakSequence { + fn from(ws: &PartialWeights) -> Self { + let mut remaining = 1.0; + let bs = + ws.0.iter() + .map(|w| { + debug_assert!((0.0..=1.0).contains(w)); + let b = 1.0 - (w / remaining); + debug_assert!((0.0..=1.0).contains(&b)); + remaining -= w; + debug_assert!((0.0..=1.0).contains(&remaining)); + b + }) + .collect(); + BreakSequence(bs) + } +} + +/// Implements the `HasDensity` trait for `StickBreaking`. +impl HasDensity for StickBreaking { + /// Calculates the natural logarithm of the density function for the given input `x`. + /// + /// # Arguments + /// + /// * `x` - A reference to a slice of `f64` values. + /// + /// # Returns + /// + /// The natural logarithm of the density function. + fn ln_f(&self, w: &PartialWeights) -> f64 { + self.break_dists() + .zip(BreakSequence::from(w).0.iter()) + .map(|(b, p)| match b { + Either::Left(beta) => beta.ln_f(p), + Either::Right(unit_powlaw) => unit_powlaw.ln_f(p), + }) + .sum() + } +} + +impl Sampleable for StickBreaking { + /// Draws a sample from the StickBreaking distribution. + /// + /// # Arguments + /// + /// * `rng` - A mutable reference to the random number generator. + /// + /// # Returns + /// + /// A `StickSequence` representing the drawn sample. + fn draw(&self, rng: &mut R) -> StickSequence { + let seed: u64 = rng.gen(); + + let seq = StickSequence::new(self.break_tail.clone(), Some(seed)); + for beta in &self.break_prefix { + let p = beta.draw(rng); + seq.push_break(p); + } + seq + } +} + +/// Implements the `Sampleable` trait for `StickBreaking`. +impl Sampleable for StickBreaking { + /// Draws a sample from the `StickBreaking` distribution. + /// + /// # Arguments + /// + /// * `rng` - A mutable reference to the random number generator. + /// + /// # Returns + /// + /// A sample from the `StickBreaking` distribution. + fn draw(&self, rng: &mut R) -> StickBreakingDiscrete { + StickBreakingDiscrete::new(self.draw(rng)) + } +} + +/// Implementation of the `ConjugatePrior` trait for the `StickBreaking` struct. 
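+///
+/// Conjugacy here means the posterior is again a `StickBreaking` process in
+/// which each break point gets a `Beta` distribution updated by the observed
+/// counts. A hedged sketch of the update, mirroring the
+/// `sb_posterior_obs_one` test below (marked `ignore`, not a compiled
+/// doctest):
+///
+/// ```ignore
+/// use rv::experimental::stick_breaking_process::StickBreaking;
+/// use rv::prelude::*;
+///
+/// let sb = StickBreaking::new(UnitPowerLaw::new(3.0).unwrap());
+/// // To land in category 2 the draw passes sticks 0 and 1 and stops at 2,
+/// // so the first two Betas gain alpha mass and the third gains beta mass.
+/// let post = sb.posterior(&DataOrSuffStat::Data(&[2usize]));
+/// assert_eq!(post.break_prefix()[0], Beta::new(4.0, 1.0).unwrap());
+/// assert_eq!(post.break_prefix()[2], Beta::new(3.0, 2.0).unwrap());
+/// ```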
+impl ConjugatePrior for StickBreaking { + type Posterior = StickBreaking; + type MCache = (); + type PpCache = Self::Posterior; + + /// Computes the logarithm of the marginal likelihood cache. + fn ln_m_cache(&self) -> Self::MCache {} + + /// Computes the logarithm of the predictive probability cache. + fn ln_pp_cache( + &self, + x: &DataOrSuffStat, + ) -> Self::PpCache { + self.posterior(x) + } + + /// Computes the posterior distribution from the sufficient statistic. + fn posterior_from_suffstat( + &self, + stat: &StickBreakingDiscreteSuffStat, + ) -> Self::Posterior { + let pairs = stat.break_pairs(); + let new_prefix = self + .break_prefix + .iter() + .zip_longest(pairs) + .map(|pair| match pair { + Left(beta) => beta.clone(), + Right((a, b)) => Beta::new( + self.break_tail.alpha() + a as f64, + 1.0 + b as f64, + ) + .unwrap(), + Both(beta, (a, b)) => { + Beta::new(beta.alpha() + a as f64, beta.beta() + b as f64) + .unwrap() + } + }) + .collect(); + StickBreaking { + break_prefix: new_prefix, + break_tail: self.break_tail.clone(), + } + } + + fn posterior( + &self, + x: &DataOrSuffStat, + ) -> Self::Posterior { + match x { + DataOrSuffStat::Data(xs) => { + let mut stat = StickBreakingDiscreteSuffStat::new(); + stat.observe_many(xs); + self.posterior_from_suffstat(&stat) + } + DataOrSuffStat::SuffStat(stat) => { + self.posterior_from_suffstat(stat) + } + } + } + + /// Computes the logarithm of the marginal likelihood. + fn ln_m(&self, x: &DataOrSuffStat) -> f64 { + let count_pairs = match x { + DataOrSuffStat::Data(xs) => { + let mut stat = StickBreakingDiscreteSuffStat::new(); + stat.observe_many(xs); + stat.break_pairs() + } + DataOrSuffStat::SuffStat(stat) => stat.break_pairs(), + }; + let alpha = self.break_tail.alpha(); + let params = self.break_prefix.iter().map(|b| (b.alpha(), b.beta())); + count_pairs + .iter() + .zip_longest(params) + .map(|pair| match pair { + Left((yes, no)) => { + let (yes, no) = (*yes as f64, *no as f64); + + // TODO: Simplify this after everything is working + (yes + alpha).ln_beta(no + 1.0) - alpha.ln_beta(1.0) + } + Right((_a, _b)) => 0.0, + Both((yes, no), (a, b)) => { + let (yes, no) = (*yes as f64, *no as f64); + (yes + a).ln_beta(no + b) - a.ln_beta(b) + } + }) + .sum() + } + + /// Computes the logarithm of the marginal likelihood with cache. + fn ln_m_with_cache( + &self, + _cache: &Self::MCache, + x: &DataOrSuffStat, + ) -> f64 { + self.ln_m(x) + } + + /// Computes the logarithm of the predictive probability with cache. + fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &usize) -> f64 { + cache.ln_m(&DataOrSuffStat::Data(&[*y])) + } + + /// Computes the predictive probability. + fn pp( + &self, + y: &usize, + x: &DataOrSuffStat, + ) -> f64 { + let post = self.posterior(x); + post.m(&DataOrSuffStat::Data(&[*y])) + } +} + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + + use super::*; + + proptest! { + #[test] + fn partial_weights_to_break_sequence(v in prop::collection::vec(0.0..=1.0, 1..100), m in 0.0..=1.0) { + // we want the sum of ws to be in the range [0, 1] + let multiplier: f64 = m / v.iter().sum::(); + let ws = PartialWeights(v.iter().map(|w| w * multiplier).collect()); + let bs = BreakSequence::from(&ws); + assert::close(ws.0, PartialWeights::from(&bs).0, 1e-10); + } + } + + proptest! 
{ + #[test] + fn break_sequence_to_partial_weights(v in prop::collection::vec(0.0..=1.0, 1..100)) { + let bs = BreakSequence(v); + let ws = PartialWeights::from(&bs); + let bs2 = BreakSequence::from(&ws); + assert::close(bs.0, bs2.0, 1e-10); + } + } + + #[test] + fn sb_ln_m_vs_monte_carlo() { + use crate::misc::logsumexp; + + let n_samples = 1_000_000; + let xs: Vec = vec![1, 2, 3]; + + let sb = StickBreaking::new(UnitPowerLaw::new(5.0).unwrap()); + let obs = DataOrSuffStat::Data(&xs); + let ln_m = sb.ln_m(&obs); + + let mc_est = { + let ln_fs: Vec = sb + .sample_stream(&mut rand::thread_rng()) + .take(n_samples) + .map(|sbd: StickBreakingDiscrete| { + xs.iter().map(|x| sbd.ln_f(x)).sum::() + }) + .collect(); + logsumexp(&ln_fs) - (n_samples as f64).ln() + }; + // high error tolerance. MC estimation is not the most accurate... + assert::close(ln_m, mc_est, 1e-2); + } + + #[test] + fn sb_pp_posterior() { + let sb = StickBreaking::new(UnitPowerLaw::new(5.0).unwrap()); + let sb_pp = sb.pp(&3, &DataOrSuffStat::Data(&[1, 2])); + let post = sb.posterior(&DataOrSuffStat::Data(&[1, 2])); + let post_f = post.pp( + &3, + &DataOrSuffStat::SuffStat(&StickBreakingDiscreteSuffStat::new()), + ); + assert::close(sb_pp, post_f, 1e-10); + } + + #[test] + fn sb_repeated_obs_more_likely() { + let sb = StickBreaking::new(UnitPowerLaw::new(5.0).unwrap()); + let sb_m = sb.ln_m(&DataOrSuffStat::Data(&[10])); + let post = sb.posterior(&DataOrSuffStat::Data(&[10])); + let post_m = post.ln_m(&DataOrSuffStat::Data(&[10])); + assert!(post_m > sb_m); + } + + #[test] + fn sb_bayes_law() { + let mut rng = rand::thread_rng(); + + // Prior + let prior = StickBreaking::new(UnitPowerLaw::new(5.0).unwrap()); + let par: StickSequence = prior.draw(&mut rng); + let par_data = par.weights(7); + let prior_lnf = prior.ln_f(&par_data); + + // Likelihood + let lik = StickBreakingDiscrete::new(par); + let lik_data: &usize = &5; + let lik_lnf = lik.ln_f(lik_data); + + // Evidence + let ln_ev = prior.ln_m(&DataOrSuffStat::Data(&[*lik_data])); + + // Posterior + let post = prior.posterior(&DataOrSuffStat::Data(&[*lik_data])); + let post_lnf = post.ln_f(&par_data); + + // Bayes' law + assert::close(post_lnf, prior_lnf + lik_lnf - ln_ev, 1e-12); + } + + #[test] + fn sb_pp_is_quotient_of_marginals() { + // pp(x|y) = m({x, y})/m(x) + let sb = StickBreaking::new(UnitPowerLaw::new(5.0).unwrap()); + let sb_pp = sb.pp(&1, &DataOrSuffStat::Data(&[0])); + + let m_1 = sb.m(&DataOrSuffStat::Data(&[0])); + let m_1_2 = sb.m(&DataOrSuffStat::Data(&[0, 1])); + + assert::close(sb_pp, m_1_2 / m_1, 1e-12); + } + + #[test] + fn sb_big_alpha_heavy_tails() { + let sb_5 = StickBreaking::new(UnitPowerLaw::new(5.0).unwrap()); + let sb_2 = StickBreaking::new(UnitPowerLaw::new(2.0).unwrap()); + let sb_pt5 = StickBreaking::new(UnitPowerLaw::new(0.5).unwrap()); + + let m_pt5_10 = sb_pt5.m(&DataOrSuffStat::Data(&[10])); + let m_2_10 = sb_2.m(&DataOrSuffStat::Data(&[10])); + let m_5_10 = sb_5.m(&DataOrSuffStat::Data(&[10])); + + assert!(m_pt5_10 < m_2_10); + assert!(m_2_10 < m_5_10); + } + + #[test] + fn sb_marginal_zero() { + let sb = StickBreaking::new(UnitPowerLaw::new(3.0).unwrap()); + let m_0 = sb.m(&DataOrSuffStat::Data(&[0])); + let bern = Bernoulli::new(3.0 / 4.0).unwrap(); + assert::close(m_0, bern.f(&0), 1e-12); + } + + #[test] + fn sb_postpred_zero() { + let sb = StickBreaking::new(UnitPowerLaw::new(3.0).unwrap()); + let pp_0 = sb.pp(&0, &DataOrSuffStat::Data(&[0])); + let bern = Bernoulli::new(3.0 / 5.0).unwrap(); + assert::close(pp_0, bern.f(&0), 1e-12); 
+ } + + #[test] + fn sb_pp_zero_marginals() { + // pp(x|y) = m({x, y})/m(x) + let sb = StickBreaking::new(UnitPowerLaw::new(5.0).unwrap()); + let sb_pp = sb.pp(&0, &DataOrSuffStat::Data(&[0])); + + let m_1 = sb.m(&DataOrSuffStat::Data(&[0])); + let m_1_2 = sb.m(&DataOrSuffStat::Data(&[0, 0])); + + assert::close(sb_pp, m_1_2 / m_1, 1e-12); + } + + #[test] + fn sb_posterior_obs_one() { + let sb = StickBreaking::new(UnitPowerLaw::new(3.0).unwrap()); + let post = sb.posterior(&DataOrSuffStat::Data(&[2])); + + assert_eq!(post.break_prefix[0], Beta::new(4.0, 1.0).unwrap()); + assert_eq!(post.break_prefix[1], Beta::new(4.0, 1.0).unwrap()); + assert_eq!(post.break_prefix[2], Beta::new(3.0, 2.0).unwrap()); + } + + #[test] + fn sb_logposterior_diff() { + // Like Bayes Law, but takes a quotient to cancel evidence + + let mut rng = rand::thread_rng(); + let sb = StickBreaking::new(UnitPowerLaw::new(3.0).unwrap()); + let seq1: StickSequence = sb.draw(&mut rng); + let seq2: StickSequence = sb.draw(&mut rng); + + let w1 = seq1.weights(3); + let w2 = seq2.weights(3); + + let logprior_diff = sb.ln_f(&w1) - sb.ln_f(&w2); + + let data = [1, 2]; + let stat = StickBreakingDiscreteSuffStat::from(&data[..]); + let post = sb.posterior(&DataOrSuffStat::SuffStat(&stat)); + let logpost_diff = post.ln_f(&w1) - post.ln_f(&w2); + + let sbd1 = StickBreakingDiscrete::new(seq1); + let sbd2 = StickBreakingDiscrete::new(seq2); + let loglik_diff = sbd1.ln_f_stat(&stat) - sbd2.ln_f_stat(&stat); + + assert::close(logpost_diff, loglik_diff + logprior_diff, 1e-12); + } + + #[test] + fn sb_posterior_rejection_sampling() { + let mut rng = rand::thread_rng(); + let sb = StickBreaking::new(UnitPowerLaw::new(3.0).unwrap()); + + let num_samples = 1000; + + // Our computed posterior + let data = [10]; + let post = sb.posterior(&DataOrSuffStat::Data(&data[..])); + + // An approximation using rejection sampling + let mut stat = StickBreakingDiscreteSuffStat::new(); + let mut n = 0; + while n < num_samples { + let seq: StickSequence = sb.draw(&mut rng); + let sbd = StickBreakingDiscrete::new(seq.clone()); + if sbd.draw(&mut rng) == 10 { + stat.observe(&sbd.draw(&mut rng)); + n += 1; + } + } + + let counts = stat.counts(); + + // This would be counts.len() - 1, but the current implementation has a + // trailing zero we need to ignore + let dof = (counts.len() - 2) as f64; + + // Chi-square test is not exact, so we'll trim to only consider cases + // where expected count is at least 5. + let expected_counts = (0..) 
+            .map(|j| post.m(&DataOrSuffStat::Data(&[j])) * num_samples as f64)
+            .take_while(|x| *x > 5.0);
+
+        let ts = counts
+            .iter()
+            .zip(expected_counts)
+            .map(|(o, e)| ((*o as f64) - e).powi(2) / e);
+
+        let t: &f64 = &ts.clone().sum();
+        let p = ChiSquared::new(dof).unwrap().sf(t);
+
+        assert!(p > 0.001, "p-value = {}", p);
+    }
+
+    #[test]
+    fn test_set_alpha() {
+        // Step 1: Generate a new StickBreaking instance with alpha=3
+        let mut sb = StickBreaking::new(UnitPowerLaw::new(3.0).unwrap());
+
+        // Step 2: Set the prefix to [Beta(4, 2), Beta(3, 2), Beta(2, 2)]
+        sb.break_prefix = vec![
+            Beta::new(4.0, 2.0).unwrap(),
+            Beta::new(3.0, 2.0).unwrap(),
+            Beta::new(2.0, 2.0).unwrap(),
+        ];
+
+        // Step 3: Call set_alpha(2.0), which shifts each alpha down by 1
+        sb.set_alpha(2.0).unwrap();
+
+        // Step 4: Check that the prefix is now [Beta(3, 2), Beta(2, 2), Beta(1, 2)]
+        assert_eq!(sb.break_prefix[0], Beta::new(3.0, 2.0).unwrap());
+        assert_eq!(sb.break_prefix[1], Beta::new(2.0, 2.0).unwrap());
+        assert_eq!(sb.break_prefix[2], Beta::new(1.0, 2.0).unwrap());
+        assert_eq!(sb.break_tail, UnitPowerLaw::new(2.0).unwrap());
+    }
+} // mod tests
diff --git a/src/experimental/stick_breaking_process/stick_breaking_stat.rs b/src/experimental/stick_breaking_process/stick_breaking_stat.rs
new file mode 100644
index 0000000..1f625cc
--- /dev/null
+++ b/src/experimental/stick_breaking_process/stick_breaking_stat.rs
@@ -0,0 +1,208 @@
+use crate::experimental::stick_breaking_process::stick_breaking::StickBreaking;
+use crate::traits::{HasSuffStat, SuffStat};
+
+#[cfg(feature = "serde1")]
+use serde::{Deserialize, Serialize};
+
+/// Represents the sufficient statistics for a stick-breaking process.
+///
+/// This struct is used to accumulate statistics from a stick-breaking process,
+/// such as the number of breaks and the sum of the logarithms of the remaining stick lengths.
+///
+/// # Fields
+///
+/// * `n` - The total number of observations.
+/// * `num_breaks` - The number of breaks observed.
+/// * `sum_log_q` - The sum of the logarithms of the remaining stick lengths after each break.
+#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
+#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
+#[derive(Clone, Debug, PartialEq)]
+pub struct StickBreakingSuffStat {
+    n: usize,
+    num_breaks: usize,
+    sum_log_q: f64,
+}
+
+impl Default for StickBreakingSuffStat {
+    /// Provides a default instance of `StickBreakingSuffStat` with zeroed statistics.
+    ///
+    /// # Returns
+    ///
+    /// A new instance of `StickBreakingSuffStat` with all fields set to zero.
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl StickBreakingSuffStat {
+    /// Constructs a new `StickBreakingSuffStat`.
+    ///
+    /// Initializes a new instance of `StickBreakingSuffStat` with all fields set to zero,
+    /// representing the start of a new stick-breaking process.
+    ///
+    /// # Returns
+    ///
+    /// A new instance of `StickBreakingSuffStat`.
+    pub fn new() -> Self {
+        Self {
+            n: 0,
+            num_breaks: 0,
+            sum_log_q: 0.0,
+        }
+    }
+
+    /// Returns the number of breaks observed in the stick-breaking process.
+    pub fn num_breaks(&self) -> usize {
+        self.num_breaks
+    }
+
+    /// Returns the sum of the logarithms of the remaining stick lengths after
+    /// each break.
+    pub fn sum_log_q(&self) -> f64 {
+        self.sum_log_q
+    }
+}
+
+impl From<&&[f64]> for StickBreakingSuffStat {
+    /// Constructs a `StickBreakingSuffStat` from a slice of floating-point numbers.
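+    ///
+    /// (The double reference `&&[f64]` appears because the sufficient
+    /// statistic implements `SuffStat` with the whole `&[f64]` slice as a
+    /// single observation, so `observe` takes a reference to that slice.)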
+ /// + /// This conversion allows for directly observing a slice of stick lengths + /// and accumulating their statistics into a `StickBreakingSuffStat`. + /// + /// # Arguments + /// + /// * `x` - A reference to a slice of floating-point numbers representing stick lengths. + /// + /// # Returns + /// + /// A new instance of `StickBreakingSuffStat` with observed statistics. + fn from(x: &&[f64]) -> Self { + let mut stat = StickBreakingSuffStat::new(); + stat.observe(x); + stat + } +} + +// TODO: Generalize the above, something like +// impl From<&X> for Stat +// where Stat: SuffStat +// { +// fn from(x: &X) -> Self { +// let mut stat = Stat::new(); +// stat.observe(x); +// stat +// } +// } + +/// Computes the sufficient statistic for a UnitPowerLaw distribution from a sequence of stick lengths. +/// +/// This function processes a sequence of stick lengths resulting from a stick-breaking process +/// parameterized by a UnitPowerLaw(α), which is equivalent to a Beta(α,1) distribution. It calculates +/// the sufficient statistic for this distribution, which is necessary for further statistical analysis +/// or parameter estimation. +/// +/// # Arguments +/// +/// * `sticks` - A slice of floating-point numbers representing the lengths of the sticks. +/// +/// # Returns +/// +/// A tuple containing: +/// - The number of breaks (`usize`). +/// - The natural logarithm of the product of (1 - pᵢ) for each stick length pᵢ (`f64`). +fn stick_stat_unit_powerlaw(sticks: &[f64]) -> (usize, f64) { + // First we need to find the sequence of remaining stick lengths. Because we + // broke the sticks left-to-right, we need to reverse the sequence. + let remaining = sticks.iter().rev().scan(0.0, |acc, &x| { + *acc += x; + Some(*acc) + }); + + let qs = sticks + .iter() + // Reversing `remaining` would force us to collect the intermediate + // result e.g. into a `Vec`. Instead, we can reverse the sequence of + // stick lengths to match. + .rev() + // Now zip the sequences together and do the main computation we're interested in. + .zip(remaining) + // In theory the broken stick lengths should all be less than what was + // remaining before the break. In practice, numerical instabilities can + // cause problems. So we filter to be sure we only consider valid + // values. + .filter(|(&len, remaining)| len < *remaining) + .map(|(&len, remaining)| 1.0 - len / remaining); + + // The sufficient statistic is (n, ∑ᵢ log(1 - pᵢ)) == (n, log ∏ᵢ(1 - pᵢ)). + // First we compute `n` and `∏ᵢ(1 - pᵢ)` + let (num_breaks, prod_q) = + qs.fold((0, 1.0), |(n, prod_q), q| (n + 1, prod_q * q)); + + (num_breaks, prod_q.ln()) +} + +/// Implementation of `HasSuffStat` for `StickBreaking` with stick lengths as input. +impl HasSuffStat<&[f64]> for StickBreaking { + type Stat = StickBreakingSuffStat; + + /// Creates an empty sufficient statistic for stick breaking. + /// + /// # Returns + /// + /// A new instance of `StickBreakingSuffStat` with zeroed statistics. + fn empty_suffstat(&self) -> Self::Stat { + Self::Stat::new() + } + + /// Computes the natural logarithm of the likelihood function given the sufficient statistic. + /// + /// # Arguments + /// + /// * `stat` - A reference to the sufficient statistic of stick lengths. + /// + /// # Returns + /// + /// The natural logarithm of the likelihood function. 
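+    ///
+    /// Concretely this evaluates
+    /// `num_breaks * ln(alpha) + (alpha - 1) * sum_log_q`,
+    /// i.e. the log-likelihood of the observed break proportions under a
+    /// `UnitPowerLaw(alpha)` break distribution.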
+ fn ln_f_stat(&self, stat: &Self::Stat) -> f64 { + let alpha = self.alpha(); + let alpha_ln = self.break_tail().alpha_ln(); + (stat.num_breaks as f64) + .mul_add(alpha_ln, (alpha - 1.0) * stat.sum_log_q) + } +} + +/// Implementation of `SuffStat` for `StickBreakingSuffStat` with stick lengths as input. +impl SuffStat<&[f64]> for StickBreakingSuffStat { + /// Returns the total number of observations. + /// + /// # Returns + /// + /// The total number of observations. + fn n(&self) -> usize { + self.n + } + + /// Observes a sequence of stick lengths and updates the sufficient statistic. + /// + /// # Arguments + /// + /// * `sticks` - A reference to a slice of floating-point numbers representing stick lengths. + fn observe(&mut self, sticks: &&[f64]) { + let (num_breaks, sum_log_q) = stick_stat_unit_powerlaw(sticks); + self.n += 1; + self.num_breaks += num_breaks; + self.sum_log_q += sum_log_q; + } + + /// Reverses the observation of a sequence of stick lengths and updates the sufficient statistic. + /// + /// # Arguments + /// + /// * `sticks` - A reference to a slice of floating-point numbers representing stick lengths. + fn forget(&mut self, sticks: &&[f64]) { + let (num_breaks, sum_log_q) = stick_stat_unit_powerlaw(sticks); + self.n -= 1; + self.num_breaks -= num_breaks; + self.sum_log_q -= sum_log_q; + } +} diff --git a/src/experimental/stick_breaking_process/stick_sequence.rs b/src/experimental/stick_breaking_process/stick_sequence.rs new file mode 100644 index 0000000..9f2ce15 --- /dev/null +++ b/src/experimental/stick_breaking_process/stick_sequence.rs @@ -0,0 +1,368 @@ +use rand::SeedableRng; +use rand_xoshiro::Xoshiro256Plus; +#[cfg(feature = "serde1")] +use serde::{Deserialize, Serialize}; +use std::sync::{Arc, RwLock}; + +// use super::sticks_stat::StickBreakingSuffStat; +use crate::experimental::stick_breaking_process::stick_breaking::PartialWeights; +use crate::prelude::UnitPowerLaw; +use crate::traits::*; + +// We'd like to be able to serialize and deserialize StickSequence, but serde can't handle +// `Arc` or `RwLock`. So we use `StickSequenceFmt` as an intermediate type. +#[cfg(feature = "serde1")] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] +struct StickSequenceFmt { + breaker: UnitPowerLaw, + inner: _Inner, +} + +#[cfg(feature = "serde1")] +impl From for StickSequence { + fn from(fmt: StickSequenceFmt) -> Self { + Self { + breaker: fmt.breaker, + inner: Arc::new(RwLock::new(fmt.inner)), + } + } +} + +#[cfg(feature = "serde1")] +impl From for StickSequenceFmt { + fn from(sticks: StickSequence) -> Self { + Self { + breaker: sticks.breaker, + inner: sticks.inner.read().map(|inner| inner.clone()).unwrap(), + } + } +} + +// NOTE: We currently derive PartialEq, but this (we think) compares the +// internal state of the RNGs, which is probably not what we want. 
+#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))] +#[derive(Clone, Debug, PartialEq)] +pub struct _Inner { + rng: Xoshiro256Plus, + ccdf: Vec, +} + +impl _Inner { + fn new(seed: Option) -> _Inner { + _Inner { + rng: seed.map_or_else( + Xoshiro256Plus::from_entropy, + Xoshiro256Plus::seed_from_u64, + ), + ccdf: vec![1.0], + } + } + + fn extend + Clone>(&mut self, breaker: &B) -> f64 { + let p: f64 = breaker.draw(&mut self.rng); + let remaining_mass = self.ccdf.last().unwrap(); + let new_remaining_mass = remaining_mass * p; + self.ccdf.push(new_remaining_mass); + new_remaining_mass + } + + fn extend_until(&mut self, breaker: &B, p: F) + where + B: Rv + Clone, + F: Fn(&_Inner) -> bool, + { + while !p(self) { + self.extend(breaker); + } + } +} + +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr( + feature = "serde1", + serde( + rename_all = "snake_case", + from = "StickSequenceFmt", + into = "StickSequenceFmt" + ) +)] +#[derive(Clone, Debug)] +pub struct StickSequence { + breaker: UnitPowerLaw, + inner: Arc>, +} + +impl PartialEq for StickSequence { + fn eq(&self, other: &StickSequence) -> bool { + self.ensure_breaks(other.num_weights_unstable()); + other.ensure_breaks(self.num_weights_unstable()); + self.breaker == other.breaker + && self.with_inner(|self_inner| { + other.with_inner(|other_inner| { + self_inner.ccdf == other_inner.ccdf + && self_inner.rng == other_inner.rng + }) + }) + } +} + +impl StickSequence { + /// Creates a new StickSequence with the given breaker and optional seed. + /// + /// # Arguments + /// + /// * `breaker` - A `UnitPowerLaw` instance used as the breaker. + /// * `seed` - An optional seed for the random number generator. + /// + /// # Returns + /// + /// A new instance of `StickSequence`. + pub fn new(breaker: UnitPowerLaw, seed: Option) -> Self { + Self { + breaker, + inner: Arc::new(RwLock::new(_Inner::new(seed))), + } + } + + /// Pushes a new break to the stick sequence using a given probability `p`. + /// + /// # Arguments + /// + /// * `p` - The probability used to calculate the new remaining mass. + pub fn push_break(&self, p: f64) { + self.with_inner_mut(|inner| { + let remaining_mass = *inner.ccdf.last().unwrap(); + let new_remaining_mass = remaining_mass * p; + inner.ccdf.push(new_remaining_mass); + }); + } + + /// Pushes a new value `p` directly to the ccdf vector if `p` is less than the last element. + /// + /// # Arguments + /// + /// * `p` - The value to be pushed to the ccdf vector. + /// + /// # Panics + /// + /// Panics if `p` is not less than the last element of the ccdf vector. + pub fn push_to_ccdf(&self, p: f64) { + self.with_inner_mut(|inner| { + assert!(p < *inner.ccdf.last().unwrap()); + inner.ccdf.push(p); + }); + } + + /// Extends the ccdf vector until a condition defined by `pred` is met, then applies function `f`. + /// + /// # Type Parameters + /// + /// * `P` - A predicate function type that takes a reference to a vector of f64 and returns a bool. + /// * `F` - A function type that takes a reference to a vector of f64 and returns a value of type `Ans`. + /// * `Ans` - The return type of the function `f`. + /// + /// # Arguments + /// + /// * `pred` - A predicate function that determines when to stop extending the ccdf vector. + /// * `f` - A function to apply to the ccdf vector once the condition is met. + /// + /// # Returns + /// + /// The result of applying function `f` to the ccdf vector. 
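+    ///
+    /// # Example
+    ///
+    /// An illustrative sketch (marked `ignore`, not a compiled doctest):
+    /// lazily extend the sequence until the remaining tail mass falls below
+    /// 0.01, then report how many ccdf entries were materialized.
+    ///
+    /// ```ignore
+    /// use rv::experimental::stick_breaking_process::StickSequence;
+    /// use rv::prelude::UnitPowerLaw;
+    ///
+    /// let sticks = StickSequence::new(UnitPowerLaw::new(5.0).unwrap(), None);
+    /// let n = sticks.extendmap_ccdf(
+    ///     |ccdf| *ccdf.last().unwrap() < 0.01,
+    ///     |ccdf| ccdf.len(),
+    /// );
+    /// assert!(n > 1);
+    /// ```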
+ pub fn extendmap_ccdf(&self, pred: P, f: F) -> Ans + where + P: Fn(&Vec) -> bool, + F: Fn(&Vec) -> Ans, + { + self.extend_until(|inner| pred(&inner.ccdf)); + self.with_inner(|inner| f(&inner.ccdf)) + } + + /// Provides read access to the inner `_Inner` structure. + /// + /// # Type Parameters + /// + /// * `F` - A function type that takes a reference to `_Inner` and returns a value of type `Ans`. + /// * `Ans` - The return type of the function `f`. + /// + /// # Arguments + /// + /// * `f` - A function that is applied to the inner `_Inner` structure. + /// + /// # Returns + /// + /// The result of applying function `f` to the inner `_Inner` structure. + pub fn with_inner(&self, f: F) -> Ans + where + F: FnOnce(&_Inner) -> Ans, + { + self.inner.read().map(|inner| f(&inner)).unwrap() + } + + /// Provides write access to the inner `_Inner` structure. + /// + /// # Type Parameters + /// + /// * `F` - A function type that takes a mutable reference to `_Inner` and returns a value of type `Ans`. + /// * `Ans` - The return type of the function `f`. + /// + /// # Arguments + /// + /// * `f` - A function that is applied to the inner `_Inner` structure. + /// + /// # Returns + /// + /// The result of applying function `f` to the inner `_Inner` structure. + pub fn with_inner_mut(&self, f: F) -> Ans + where + F: FnOnce(&mut _Inner) -> Ans, + { + self.inner.write().map(|mut inner| f(&mut inner)).unwrap() + } + + /// Ensures that the ccdf vector is extended to at least `n + 1` elements. + /// + /// # Arguments + /// + /// * `n` - The minimum number of elements the ccdf vector should have. + pub fn ensure_breaks(&self, n: usize) { + self.extend_until(|inner| inner.ccdf.len() > n); + } + + /// Returns the `n`th element of the ccdf vector, ensuring the vector is long enough. + /// + /// # Arguments + /// + /// * `n` - The index of the element to retrieve from the ccdf vector. + /// + /// # Returns + /// + /// The `n`th element of the ccdf vector. + pub fn ccdf(&self, n: usize) -> f64 { + self.ensure_breaks(n); + self.with_inner(|inner| { + let ccdf = &inner.ccdf; + ccdf[n] + }) + } + + /// Returns the number of weights instantiated so far. + /// + /// # Returns + /// + /// The number of weights. This is "unstable" because it's a detail of the + /// implementation that should not be depended on. + pub fn num_weights_unstable(&self) -> usize { + self.with_inner(|inner| inner.ccdf.len() - 1) + } + + /// Returns the weight of the `n`th stick. + /// + /// # Arguments + /// + /// * `n` - The index of the stick whose weight is to be returned. + /// + /// # Returns + /// + /// The weight of the `n`th stick. + pub fn weight(&self, n: usize) -> f64 { + self.ensure_breaks(n + 1); + self.with_inner(|inner| { + let ccdf = &inner.ccdf; + ccdf[n] - ccdf[n + 1] + }) + } + + /// Returns the weights of the first `n` sticks. + /// + /// Note that this includes sticks `0..n-1`, but not `n`. + /// + /// # Arguments + /// + /// * `n` - The number of sticks for which to return the weights. + /// + /// # Returns + /// + /// A `PartialWeights` instance containing the weights of the first `n` sticks. + pub fn weights(&self, n: usize) -> PartialWeights { + self.ensure_breaks(n); + let w = self.with_inner(|inner| { + let mut last_p = 1.0; + inner + .ccdf + .iter() + .skip(1) + .map(|&p| { + let w = last_p - p; + last_p = p; + w + }) + .collect() + }); + PartialWeights(w) + } + + /// Returns a clone of the breaker used in this StickSequence. + /// + /// # Returns + /// + /// A clone of the `UnitPowerLaw` instance used as the breaker. 
+ pub fn breaker(&self) -> UnitPowerLaw { + self.breaker.clone() + } + + /// Extends the ccdf vector until a condition defined by `p` is met. + /// + /// # Type Parameters + /// + /// * `F` - A function type that takes a reference to `_Inner` and returns a bool. + /// + /// # Arguments + /// + /// * `p` - A predicate function that determines when to stop extending the ccdf vector. + pub fn extend_until(&self, p: F) + where + F: Fn(&_Inner) -> bool, + { + self.with_inner_mut(|inner| inner.extend_until(&self.breaker, p)); + } +} + +#[cfg(test)] +mod tests { + use crate::experimental::stick_breaking_process::StickSequence; + use crate::prelude::UnitPowerLaw; + + #[test] + fn test_stickseq_weights() { + // test that `weights` gives the same as `weight` for all n + let breaker = UnitPowerLaw::new(10.0).unwrap(); + let sticks = StickSequence::new(breaker, None); + let weights = sticks.weights(100); + assert_eq!(weights.0.len(), 100); + for (n, w) in weights.0.iter().enumerate() { + assert_eq!(sticks.weight(n), *w); + } + } + + #[test] + fn test_push_to_ccdf() { + let breaker = UnitPowerLaw::new(10.0).unwrap(); + let sticks = StickSequence::new(breaker, None); + sticks.push_to_ccdf(0.9); + sticks.push_to_ccdf(0.8); + assert_eq!(sticks.ccdf(1), 0.9); + assert_eq!(sticks.ccdf(2), 0.8); + } + + #[test] + fn test_push_break() { + let breaker = UnitPowerLaw::new(10.0).unwrap(); + let sticks = StickSequence::new(breaker, None); + sticks.push_break(0.9); + sticks.push_break(0.8); + assert::close(sticks.weights(2).0, vec![0.1, 0.18], 1e-10); + } +} diff --git a/src/lib.rs b/src/lib.rs index 103babc..ee40494 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,6 +15,7 @@ //! - `process`: Gives you access to Gaussian processes. //! - `arraydist`: Enables distributions and statistical tests that require the //! [nalgebra](https://crates.io/crates/nalgebra) crate. +//! - `experimental`: Enables experimental features. //! //! # Design //! @@ -93,6 +94,8 @@ doctest!("../README.md"); pub mod consts; pub mod data; pub mod dist; +#[cfg(feature = "experimental")] +pub mod experimental; pub mod misc; mod model; pub mod prelude; @@ -129,7 +132,6 @@ macro_rules! 
extract_stat { xs.iter().for_each(|y| stat.observe(y)); stat } - DataOrSuffStat::None => $stat_type::new(), } } }; diff --git a/src/misc/bessel.rs b/src/misc/bessel.rs index 2bc9600..debad9b 100644 --- a/src/misc/bessel.rs +++ b/src/misc/bessel.rs @@ -1,5 +1,3 @@ -use std::f64::EPSILON; - const MAX_ITER: usize = 500; const BESSI0_COEFFS_A: [f64; 30] = [ @@ -208,11 +206,11 @@ impl std::fmt::Display for BesselIvError { /// Modified Bessel function of the first kind of real order pub fn bessel_iv(v: f64, z: f64) -> Result { if v.is_nan() || z.is_nan() { - return Ok(std::f64::NAN); + return Ok(f64::NAN); } let (v, t) = { let t = v.floor(); - if v < 0.0 && (t - v).abs() < EPSILON { + if v < 0.0 && (t - v).abs() < f64::EPSILON { (-v, -t) } else { (v, t) @@ -221,11 +219,11 @@ pub fn bessel_iv(v: f64, z: f64) -> Result { let sign: f64 = if z < 0.0 { // Return error if v is not an integer if x < 0 - if (t - v).abs() > EPSILON { + if (t - v).abs() > f64::EPSILON { return Err(BesselIvError::OrderNotIntegerForNegativeZ); } - if 2.0_f64.mul_add(-(v / 2.0).floor(), v) > EPSILON { + if 2.0_f64.mul_add(-(v / 2.0).floor(), v) > f64::EPSILON { -1.0 } else { 1.0 @@ -645,7 +643,7 @@ fn bessel_ikv_asymptotic_uniform( i_sum += term; k_sum += if n % 2 == 0 { term } else { -term }; - if term.abs() < EPSILON { + if term.abs() < f64::EPSILON { break; } divisor *= v; @@ -654,7 +652,7 @@ fn bessel_ikv_asymptotic_uniform( // check convergence if term.abs() > 1E-3 * i_sum.abs() { Err(BesselIvError::FailedToConverge) - } else if term.abs() > EPSILON * i_sum.abs() { + } else if term.abs() > f64::EPSILON * i_sum.abs() { Err(BesselIvError::PrecisionLoss) } else { let k_value = k_prefactor * k_sum; @@ -713,7 +711,7 @@ pub(crate) fn bessel_ikv_temme( let lim = (4.0_f64.mul_add(v * v, 10.0) / (8.0 * x)).powi(3) / 24.0; - let iv = if lim < 10.0 * EPSILON && x > 100.0 { + let iv = if lim < 10.0 * f64::EPSILON && x > 100.0 { bessel_iv_asymptotic(v, x)? 
} else { let fv = cf1_ik(v, x)?; @@ -750,17 +748,17 @@ fn temme_ik_series(v: f64, x: f64) -> Result<(f64, f64), BesselIvError> { let a = (x / 2.0).ln(); let b = (v * a).exp(); let sigma = -a * v; - let c = if v.abs() < 2.0 * EPSILON { + let c = if v.abs() < 2.0 * f64::EPSILON { 1.0 } else { (PI * v).sin() / (PI * v) }; - let d = if sigma.abs() < EPSILON { + let d = if sigma.abs() < f64::EPSILON { 1.0 } else { sigma.sinh() / sigma }; - let gamma1 = if v.abs() < EPSILON { + let gamma1 = if v.abs() < f64::EPSILON { -EULER_MASCERONI } else { (0.5 / v) * (gp - gm) * c @@ -785,7 +783,7 @@ fn temme_ik_series(v: f64, x: f64) -> Result<(f64, f64), BesselIvError> { sum += coef * f; sum1 += coef * h; - if (coef * f).abs() < sum.abs() * EPSILON { + if (coef * f).abs() < sum.abs() * f64::EPSILON { return Ok((sum, 2.0 * sum1 / x)); } } @@ -832,7 +830,7 @@ fn cf2_ik(v: f64, x: f64) -> Result<(f64, f64), BesselIvError> { q += c * t; s += q * delta; - if (q * delta).abs() < s.abs() * EPSILON / 2.0 { + if (q * delta).abs() < s.abs() * f64::EPSILON / 2.0 { let kv = (PI / (2.0 * x)).sqrt() * (-x).exp() / s; let kv1 = kv * v.mul_add(v, -0.25).mul_add(f, 0.5 + v + x) / x; return Ok((kv, kv1)); @@ -855,8 +853,8 @@ fn cf1_ik(v: f64, x: f64) -> Result { * Lentz, Applied Optics, vol 15, 668 (1976) */ - const TOL: f64 = EPSILON; - let tiny: f64 = std::f64::MAX.sqrt().recip(); + const TOL: f64 = f64::EPSILON; + let tiny: f64 = f64::MAX.sqrt().recip(); let mut c = tiny; let mut f = tiny; let mut d = 0.0; @@ -900,7 +898,7 @@ fn bessel_iv_asymptotic(v: f64, x: f64) -> Result { let mut term: f64 = 1.0; let mut k: usize = 1; - while term.abs() > std::f64::EPSILON * sum.abs() { + while term.abs() > f64::EPSILON * sum.abs() { let kf = k as f64; let factor = 2.0_f64 .mul_add(kf, -1.0) diff --git a/src/misc/convergent_seq.rs b/src/misc/convergent_seq.rs new file mode 100644 index 0000000..9da9270 --- /dev/null +++ b/src/misc/convergent_seq.rs @@ -0,0 +1,95 @@ +// use core::iter::Map; +use itertools::Itertools; +use num::Zero; +// use itertools::TupleWindows; + +/// A trait for sequences that can be checked for convergence. +pub trait ConvergentSequence: Iterator + Sized { + /// Applies Aitken's Δ² process to accelerate the convergence of a sequence. + /// See https://en.wikipedia.org/wiki/Aitken%27s_delta-squared_process and + /// https://en.wikipedia.org/wiki/Shanks_transformation + /// + /// # Returns + /// + /// An iterator over the accelerated sequence. + fn aitken(self) -> impl Iterator { + self.tuple_windows::<(_, _, _)>().filter_map(|(x, y, z)| { + let dx = z - y; + let dx2 = y - x; + let ddx = dx - dx2; + + // We can't handle a segment like [2,4,6] + // But e.g. [2, 2, 2] may have already converged + if ddx.is_zero() { + if dx.is_zero() { + Some(z) + } else { + None + } + } else { + Some(z - dx.powi(2) / ddx) + } + }) + } + + /// Finds the limit of the sequence within a given tolerance using Aitken's + /// Δ² process. This should *only* be applied to sequences that are known to + /// converge. + /// + /// # Arguments + /// + /// * `tol` - The tolerance within which to find the limit. + /// + /// # Returns + /// + /// The limit of the sequence as a floating-point number. + /// + /// # Panics + /// + /// Runs forever if the sequence does not converge within the given + /// tolerance. 
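+    ///
+    /// # Example
+    ///
+    /// An illustrative sketch (marked `ignore`, not a compiled doctest): the
+    /// partial sums of the geometric series 1/2 + 1/4 + 1/8 + ... converge
+    /// to 1.
+    ///
+    /// ```ignore
+    /// use rv::misc::ConvergentSequence;
+    ///
+    /// let partial_sums = (1..).map(|n| 0.5_f64.powi(n)).scan(0.0, |acc, x| {
+    ///     *acc += x;
+    ///     Some(*acc)
+    /// });
+    /// assert!((partial_sums.limit(1e-10) - 1.0).abs() < 1e-8);
+    /// ```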
+ fn limit(self, tol: f64) -> f64 { + self.aitken() + .aitken() + .aitken() + .aitken() + .tuple_windows::<(_, _)>() + .filter_map( + |(a, b)| { + if (a - b).abs() < tol { + Some(b) + } else { + None + } + }, + ) + .next() + .unwrap() + } +} + +impl ConvergentSequence for T where T: Iterator + Sized {} + +#[cfg(test)] +mod tests { + use super::*; + use num::Integer; + + #[test] + fn test_aitken_limit() { + let seq = (0..) + .map(|n| { + let sign = if n.is_even() { 1.0 } else { -1.0 }; + let val = sign / (2 * n + 1) as f64; + dbg!(val); + val + }) + .scan(0.0, |acc, x| { + *acc += x; + Some(*acc) + }); + let limit = seq.limit(1e-10); + let pi_over_4 = std::f64::consts::PI / 4.0; + assert!((limit - pi_over_4).abs() < 1e-10, "The limit calculated using Aitken's Δ² process did not converge to π/4 within the tolerance."); + } +} diff --git a/src/misc/func.rs b/src/misc/func.rs index 3566c41..b17f5d4 100644 --- a/src/misc/func.rs +++ b/src/misc/func.rs @@ -3,7 +3,6 @@ use rand::distributions::Open01; use rand::Rng; use special::Gamma; use std::cmp::Ordering; -use std::cmp::PartialOrd; use std::fmt::Debug; use std::ops::AddAssign; @@ -24,7 +23,9 @@ pub fn vec_to_string(xs: &[T], max_entries: usize) -> String { out += "["; let n = xs.len(); xs.iter().enumerate().for_each(|(i, x)| { - let to_push = if i < max_entries - 1 { + let to_push = if i == n - 1 { + format!("{:?}]", x) + } else if i < max_entries - 1 { format!("{:?}, ", x) } else if i == (max_entries - 1) && n > max_entries { String::from("... , ") @@ -38,7 +39,7 @@ pub fn vec_to_string(xs: &[T], max_entries: usize) -> String { out } -/// Natural logarithm of binomial coefficent, ln nCk +/// Natural logarithm of binomial coefficient, ln nCk /// /// # Example /// @@ -99,16 +100,15 @@ pub fn logsumexp(xs: &[f64]) -> f64 { xs[0] } else { let (alpha, r) = - xs.iter() - .fold((std::f64::NEG_INFINITY, 0.0), |(alpha, r), &x| { - if x == std::f64::NEG_INFINITY { - (alpha, r) - } else if x <= alpha { - (alpha, r + (x - alpha).exp()) - } else { - (x, r.mul_add((alpha - x).exp(), 1.0)) - } - }); + xs.iter().fold((f64::NEG_INFINITY, 0.0), |(alpha, r), &x| { + if x == f64::NEG_INFINITY { + (alpha, r) + } else if x <= alpha { + (alpha, r + (x - alpha).exp()) + } else { + (x, r.mul_add((alpha - x).exp(), 1.0)) + } + }); r.ln() + alpha } @@ -344,6 +344,55 @@ pub fn ln_fact(n: usize) -> f64 { } } +/// Generate a vector of sorted uniform random variables. +/// +/// # Arguments +/// +/// * `n` - The number of random variables to generate. +/// +/// * `rng` - A mutable reference to the random number generator. +/// +/// # Returns +/// +/// A vector of sorted uniform random variables. 
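+///
+/// # Method
+///
+/// This uses the standard exponential-spacings construction: the running sums
+/// of `n + 1` i.i.d. exponential draws, divided by the final total, have the
+/// same joint distribution as the order statistics of `n` uniforms on (0, 1).
+/// This avoids the O(n log n) cost of drawing and then sorting.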
+/// +/// # Example +/// +/// ``` +/// use rand::thread_rng; +/// use rv::misc::sorted_uniforms; +/// +/// let mut rng = thread_rng(); +/// let n = 10000; +/// let xs = sorted_uniforms(n, &mut rng); +/// assert_eq!(xs.len(), n); +/// +/// // Result is sorted and in the unit interval +/// assert!(xs.first().map_or(false, |&first| first > 0.0)); +/// assert!(xs.last().map_or(false, |&last| last < 1.0)); +/// assert!(xs.windows(2).all(|w| w[0] <= w[1])); +/// +/// // Mean is approximately 1/2 +/// let mean = xs.iter().sum::() / n as f64; +/// assert!(mean > 0.49 && mean < 0.51); +/// +/// // Variance is approximately 1/12 +/// let var = xs.iter().map(|x| (x - 0.5).powi(2)).sum::() / n as f64; +/// assert!(var > 0.08 && var < 0.09); +/// ``` +pub fn sorted_uniforms(n: usize, rng: &mut R) -> Vec { + let mut xs: Vec<_> = (0..n) + .map(|_| -rng.gen::().ln()) + .scan(0.0, |state, x| { + *state += x; + Some(*state) + }) + .collect(); + let max = *xs.last().unwrap() - rng.gen::().ln(); + (0..n).for_each(|i| xs[i] /= max); + xs +} + const LN_FACT: [f64; 255] = [ 0.000_000_000_000_000, 0.000_000_000_000_000, @@ -605,6 +654,9 @@ const LN_FACT: [f64; 255] = [ #[cfg(test)] mod tests { use super::*; + use crate::prelude::ChiSquared; + use crate::traits::Cdf; + use rand::thread_rng; const TOL: f64 = 1E-12; @@ -666,18 +718,18 @@ mod tests { #[test] fn logsumexp_leading_neginf() { - let inf = std::f64::INFINITY; + let inf = f64::INFINITY; let weights = vec![ -inf, - -210.14873879197316, - -818.1043044601643, - -1269.0480185226445, - -2916.862476271387, + -210.148_738_791_973_16, + -818.104_304_460_164_3, + -1_269.048_018_522_644_5, + -2_916.862_476_271_387, -inf, ]; let lse = logsumexp(&weights); - assert::close(lse, -210.14873879197316, TOL); + assert::close(lse, -210.148_738_791_973_16, TOL); } #[test] @@ -724,9 +776,8 @@ mod tests { #[test] fn ln_pflip_works_with_zero_weights() { use std::f64::consts::LN_2; - use std::f64::NEG_INFINITY; - let ln_weights: Vec = vec![-LN_2, NEG_INFINITY, -LN_2]; + let ln_weights: Vec = vec![-LN_2, f64::NEG_INFINITY, -LN_2]; let xs = ln_pflip(&ln_weights, 100, true, &mut rand::thread_rng()); @@ -738,4 +789,50 @@ mod tests { assert_eq!(one_count, 0); assert!(two_count > 30); } + + #[test] + fn test_sorted_uniforms() { + let mut rng = thread_rng(); + let n = 1000; + let xs = sorted_uniforms(n, &mut rng); + assert_eq!(xs.len(), n); + + // Result is sorted and in the unit interval + assert!(&0.0 < xs.first().unwrap()); + assert!(xs.last().unwrap() < &1.0); + assert!(xs.windows(2).all(|w| w[0] <= w[1])); + + // t will aggregate our chi-squared test statistic + let mut t = 0.0; + + { + // We'll build a histogram and count the bin populations, aggregating + // the chi-squared statistic as we go + let mut next_bin = 0.01; + let mut bin_pop = 0; + + for x in xs.iter() { + bin_pop += 1; + if *x > next_bin { + let obs = bin_pop as f64; + let exp = n as f64 / 100.0; + t += (obs - exp).powi(2) / exp; + bin_pop = 0; + next_bin += 0.01; + } + } + + // The last bin + let obs = bin_pop as f64; + let exp = n as f64 / 100.0; + t += (obs - exp).powi(2) / exp; + } + + let alpha = 0.001; + + // dof = number of bins minus one + let chi2 = ChiSquared::new(99.0).unwrap(); + let p = chi2.sf(&t); + assert!(p > alpha); + } } diff --git a/src/misc/ks.rs b/src/misc/ks.rs index b523ed1..7af8869 100644 --- a/src/misc/ks.rs +++ b/src/misc/ks.rs @@ -95,7 +95,7 @@ pub enum KsError { /// Two sample Kolmogorov-Smirnov statistic on two samples. 
/// -/// Heavily inspired by https://github.com/scipy/scipy/blob/v1.4.1/scipy/stats/stats.py#L6087 +/// Heavily inspired by /// Exact computations are derived from: /// Hodges, J.L. Jr., "The Significance Probability of the Smirnov /// Two-Sample Test," Arkiv fiur Matematik, 3, No. 43 (1958), 469-86. @@ -162,7 +162,7 @@ where .iter() .zip(cdf_y.iter()) .map(|(cx, cy)| (cx - cy)) - .fold((std::f64::MAX, std::f64::MIN), |(min, max), z| { + .fold((f64::MAX, f64::MIN), |(min, max), z| { let new_min = min.min(z); let new_max = max.max(z); (new_min, new_max) @@ -191,7 +191,7 @@ where } } KsMode::Exact => { - if n_x_g > std::f64::MAX / n_y_g { + if n_x_g > f64::MAX / n_y_g { return Err(KsError::TooLongForExact); } KsMode::Exact @@ -473,7 +473,6 @@ fn ks_cdf(n: usize, d: f64) -> f64 { mod tests { use super::*; use crate::dist::Gaussian; - use crate::traits::Cdf; const TOL: f64 = 1E-12; diff --git a/src/misc/mardia.rs b/src/misc/mardia.rs index 3bf0cca..d2e5984 100644 --- a/src/misc/mardia.rs +++ b/src/misc/mardia.rs @@ -53,7 +53,7 @@ pub fn mardia(xs: &[DVector]) -> (f64, f64) { mod test { use super::*; use crate::dist::MvGaussian; - use crate::traits::Rv; + use crate::traits::*; const MARDIA_PVAL: f64 = 0.05; const NTRIES: usize = 5; diff --git a/src/misc/mod.rs b/src/misc/mod.rs index 10cd9c2..14d29b4 100644 --- a/src/misc/mod.rs +++ b/src/misc/mod.rs @@ -1,5 +1,6 @@ //! Random utilities pub mod bessel; +mod convergent_seq; pub(crate) mod entropy; mod func; mod ks; @@ -9,6 +10,7 @@ mod mardia; mod seq; mod x2; +pub use convergent_seq::*; pub use func::*; pub use ks::*; pub use legendre::*; diff --git a/src/model.rs b/src/model.rs index 44bb692..04be40c 100644 --- a/src/model.rs +++ b/src/model.rs @@ -1,4 +1,3 @@ -use crate::data::DataOrSuffStat; use crate::traits::*; use rand::Rng; use std::marker::PhantomData; @@ -112,7 +111,7 @@ where } } -impl Rv for ConjugateModel +impl HasDensity for ConjugateModel where Fx: Rv + HasSuffStat, Pr: ConjugatePrior, @@ -120,7 +119,13 @@ where fn ln_f(&self, x: &X) -> f64 { self.prior.ln_pp(x, &self.obs()) } +} +impl Sampleable for ConjugateModel +where + Fx: Rv + HasSuffStat, + Pr: ConjugatePrior, +{ fn draw(&self, mut rng: &mut R) -> X { let post = self.posterior(); let fx: Fx = post.draw(&mut rng); diff --git a/src/process/gaussian/kernel/constant_kernel.rs b/src/process/gaussian/kernel/constant_kernel.rs index 3f0c85d..0ee7170 100644 --- a/src/process/gaussian/kernel/constant_kernel.rs +++ b/src/process/gaussian/kernel/constant_kernel.rs @@ -21,7 +21,7 @@ impl ConstantKernel { Err(KernelError::ParameterOutOfBounds { name: "value".to_string(), given: value, - bounds: (0.0, std::f64::INFINITY), + bounds: (0.0, f64::INFINITY), }) } else { Ok(Self { scale: value }) diff --git a/src/process/gaussian/kernel/matern.rs b/src/process/gaussian/kernel/matern.rs index 551643d..c69a976 100644 --- a/src/process/gaussian/kernel/matern.rs +++ b/src/process/gaussian/kernel/matern.rs @@ -1,10 +1,10 @@ use crate::misc::bessel::bessel_ikv_temme; use super::{e2_norm, CovGrad, CovGradError, Kernel, KernelError}; +use crate::misc::gammafn; use nalgebra::base::constraint::{SameNumberOfColumns, ShapeConstraint}; use nalgebra::base::storage::Storage; use nalgebra::{dvector, DMatrix, DVector, Dim, Matrix}; -use peroxide::prelude::gamma; use std::f64; #[cfg(feature = "serde1")] @@ -35,13 +35,13 @@ impl MaternKernel { Err(KernelError::ParameterOutOfBounds { name: "nu".to_string(), given: nu, - bounds: (0.0, std::f64::INFINITY), + bounds: (0.0, f64::INFINITY), }) } else if length_scale 
<= 0.0 { Err(KernelError::ParameterOutOfBounds { name: "length_scale".to_string(), given: length_scale, - bounds: (0.0, std::f64::INFINITY), + bounds: (0.0, f64::INFINITY), }) } else { Ok(Self { nu, length_scale }) @@ -63,7 +63,7 @@ impl MaternKernel { let n = x.nrows(); let mut dm: DMatrix = DMatrix::zeros(n, n); - let c = (1.0 - self.nu).exp2() / gamma(self.nu); + let c = (1.0 - self.nu).exp2() / gammafn(self.nu); let sqrt_two_nu = (2.0 * self.nu).sqrt(); for i in 0..n { @@ -119,7 +119,7 @@ impl Kernel for MaternKernel { let n = x2.nrows(); let mut dm: DMatrix = DMatrix::zeros(m, n); - let c = (1.0 - self.nu).exp2() / gamma(self.nu); + let c = (1.0 - self.nu).exp2() / gammafn(self.nu); let sqrt_two_nu = (2.0 * self.nu).sqrt(); for i in 0..m { diff --git a/src/process/gaussian/kernel/rbf.rs b/src/process/gaussian/kernel/rbf.rs index 3fd6863..b6696e9 100644 --- a/src/process/gaussian/kernel/rbf.rs +++ b/src/process/gaussian/kernel/rbf.rs @@ -31,7 +31,7 @@ impl RBFKernel { Err(KernelError::ParameterOutOfBounds { name: "length_scale".to_string(), given: length_scale, - bounds: (0.0, std::f64::INFINITY), + bounds: (0.0, f64::INFINITY), }) } else { Ok(Self { length_scale }) diff --git a/src/process/gaussian/kernel/seard.rs b/src/process/gaussian/kernel/seard.rs index ed777c0..58f0ef7 100644 --- a/src/process/gaussian/kernel/seard.rs +++ b/src/process/gaussian/kernel/seard.rs @@ -36,7 +36,7 @@ impl SEardKernel { .iter() .min_by(|a, b| a.partial_cmp(b).unwrap()) .unwrap(), - bounds: (0.0, std::f64::INFINITY), + bounds: (0.0, f64::INFINITY), }) } } diff --git a/src/process/gaussian/kernel/white_kernel.rs b/src/process/gaussian/kernel/white_kernel.rs index 771929a..a233229 100644 --- a/src/process/gaussian/kernel/white_kernel.rs +++ b/src/process/gaussian/kernel/white_kernel.rs @@ -23,7 +23,7 @@ impl WhiteKernel { return Err(KernelError::ParameterOutOfBounds { name: "noise_level".to_string(), given: noise_level, - bounds: (0.0, std::f64::INFINITY), + bounds: (0.0, f64::INFINITY), }); } Ok(Self { noise_level }) diff --git a/src/process/gaussian/mod.rs b/src/process/gaussian/mod.rs index bd1f467..2817aa0 100644 --- a/src/process/gaussian/mod.rs +++ b/src/process/gaussian/mod.rs @@ -8,8 +8,9 @@ use rand::Rng; use serde::{Deserialize, Serialize}; use std::cell::OnceCell; +use crate::consts::HALF_LN_2PI; use crate::dist::MvGaussian; -use crate::{consts::HALF_LN_2PI, traits::Mean, traits::Rv, traits::Variance}; +use crate::traits::*; pub mod kernel; use kernel::{Kernel, KernelError}; @@ -152,7 +153,7 @@ where indicies: &[Self::Index], ) -> Self::SampleFunction { let n = indicies.len(); - let m = indicies.get(0).map(|i| i.len()).unwrap_or(0); + let m = indicies.first().map(|i| i.len()).unwrap_or(0); let indicies: DMatrix = DMatrix::from_iterator( n, @@ -352,14 +353,19 @@ where } } -impl Rv> for GaussianProcessPrediction +impl HasDensity> for GaussianProcessPrediction where K: Kernel, { fn ln_f(&self, x: &DVector) -> f64 { self.dist().ln_f(x) } +} +impl Sampleable> for GaussianProcessPrediction +where + K: Kernel, +{ fn draw(&self, rng: &mut R) -> DVector { self.dist().draw(rng) } diff --git a/src/process/mod.rs b/src/process/mod.rs index c867e7d..662eb46 100644 --- a/src/process/mod.rs +++ b/src/process/mod.rs @@ -84,7 +84,7 @@ where let random_params = (0..random_reinits).map(|_| self.random_params(rng)); - let mut best_cost = std::f64::INFINITY; + let mut best_cost = f64::INFINITY; let mut successes = 0; let mut last_err = None; diff --git a/src/test.rs b/src/test.rs index 78816ec..3c4a990 100644 --- 
a/src/test.rs +++ b/src/test.rs @@ -9,30 +9,20 @@ use std::collections::BTreeMap; // often happens in ln_f is called #[macro_export] macro_rules! test_basic_impls { - ([continuous] $fx: expr) => { - test_basic_impls!($fx, 0.5_f64, impls); + ($X:ty, $Fx:ty) => { + test_basic_impls!($X, $Fx, <$Fx>::default()); }; - ([categorical] $fx: expr) => { - test_basic_impls!($fx, 0_usize, impls); - }; - ([count] $fx: expr) => { - test_basic_impls!($fx, 3_u32, impls); - }; - ([binary] $fx: expr) => { - test_basic_impls!($fx, true, impls); - }; - ($fx: expr, $x: expr) => { - test_basic_impls!($fx, $x, impls); - }; - ($fx: expr, $x: expr, $mod: ident) => { - mod $mod { + ($X:ty, $Fx:ty, $fx:expr) => { + mod rv_impl { use super::*; #[test] fn should_impl_debug_clone_and_partialeq() { + let mut rng = rand::thread_rng(); // make the expression a thing. If we don't do this, calling $fx // reconstructs the distribution which means we don't do caching let fx = $fx; + let x: $X = fx.draw(&mut rng); // clone a copy of fn before any computation of cached values is // done @@ -40,16 +30,35 @@ macro_rules! test_basic_impls { assert_eq!($fx, fx2); // Computing ln_f normally initializes all cached values - let y1 = fx.ln_f(&$x); - let y2 = fx.ln_f(&$x); - assert!((y1 - y2).abs() < std::f64::EPSILON); + let y1 = fx.ln_f(&x); + let y2 = fx.ln_f(&x); + assert!((y1 - y2).abs() < f64::EPSILON); - // check the fx == fx2 despite fx having its cached values initalized + // check the fx == fx2 despite fx having its cached values + // initialized assert_eq!(fx2, fx); // Make sure Debug is implemented for fx let _s1 = format!("{:?}", fx); } + + #[test] + fn should_impl_parameterized() { + let mut rng = rand::thread_rng(); + + let fx_1 = $fx; + let params = fx_1.emit_params(); + let fx_2 = <$Fx>::from_params(params); + + for _ in 0..100 { + let x: $X = fx_1.draw(&mut rng); + + let ln_f_1 = fx_1.ln_f(&x); + let ln_f_2 = fx_2.ln_f(&x); + + assert::close(ln_f_1, ln_f_2, 1e-14); + } + } } }; } @@ -320,3 +329,179 @@ macro_rules! gaussian_prior_geweke_testable { } }; } + +#[macro_export] +macro_rules! test_conjugate_prior { + ($X: ty, $Fx: ty, $Pr: ident, $prior: expr) => { + test_conjugate_prior!( + $X, + $Fx, + $Pr, + $prior, + mctol = 1e-3, + n = 1_000_000 + ); + }; + ($X: ty, $Fx: ty, $Pr: ident, $prior: expr, n=$n: expr) => { + test_conjugate_prior!($X, $Fx, $Pr, $prior, mctol = 1e-3, n = $n); + }; + ($X: ty, $Fx: ty, $Pr: ident, $prior: expr, mctol=$tol: expr) => { + test_conjugate_prior!( + $X, + $Fx, + $Pr, + $prior, + mctol = $tol, + n = 1_000_000 + ); + }; + ($X: ty, $Fx: ty, $Pr: ident, $prior: expr, mctol=$tol: expr, n=$n: expr) => { + mod conjugate_prior { + use super::*; + + fn random_xs( + fx: &$Fx, + n: usize, + mut rng: &mut impl rand::Rng, + ) -> <$Fx as $crate::traits::HasSuffStat<$X>>::Stat { + let mut stat = + <$Fx as $crate::traits::HasSuffStat<$X>>::empty_suffstat( + &fx, + ); + let xs: Vec<$X> = fx.sample(n, &mut rng); + stat.observe_many(&xs); + stat + } + + #[test] + fn ln_p_is_ratio_of_ln_m() { + // test that p(y|x) = p(y, x) / p(x) + // If this doesn't work, one of two things could be wrong: + // 1. prior.ln_m is wrong + // 2. 
prior.ln_pp is wrong + let mut rng = rand::thread_rng(); + + let pr = $prior; + let fx: $Fx = pr.draw(&mut rng); + + let mut stat = random_xs(&fx, 3, &mut rng); + + let y: $X = fx.draw(&mut rng); + + let ln_pp = <$Pr as ConjugatePrior<$X, $Fx>>::ln_pp( + &pr, + &y, + &DataOrSuffStat::SuffStat(&stat), + ); + let ln_m_lower = <$Pr as ConjugatePrior<$X, $Fx>>::ln_m( + &pr, + &DataOrSuffStat::SuffStat(&stat), + ); + + stat.observe(&y); + + let ln_m_upper = <$Pr as ConjugatePrior<$X, $Fx>>::ln_m( + &pr, + &DataOrSuffStat::SuffStat(&stat), + ); + + assert::close(ln_pp, ln_m_upper - ln_m_lower, 1e-12); + } + + #[test] + fn bayes_law() { + // test that p(θ|x) == p(x|θ)p(θ)/p(x) + // If this doesn't work, one of the following is wrong + // 1. prior.posterior.ln_f(fx) + // 2. fx.ln_f(x) + // 3. prior.ln_f(fx) + // 4. prior.ln_m(x) + let mut rng = rand::thread_rng(); + + let pr = $prior; + let fx: $Fx = pr.draw(&mut rng); + let stat = random_xs(&fx, 3, &mut rng); + + let ln_like = + <$Fx as $crate::traits::HasSuffStat<$X>>::ln_f_stat( + &fx, &stat, + ); + let ln_prior = pr.ln_f(&fx); + let ln_m = <$Pr as ConjugatePrior<$X, $Fx>>::ln_m( + &pr, + &DataOrSuffStat::SuffStat(&stat), + ); + + let posterior = <$Pr as ConjugatePrior<$X, $Fx>>::posterior( + &pr, + &DataOrSuffStat::SuffStat(&stat), + ); + let ln_post = posterior.ln_f(&fx); + + eprintln!("bayes_law stat: {:?}", stat); + eprintln!("bayes_law prior: {pr}"); + eprintln!("bayes_law fx: {fx}"); + eprintln!("bayes_law ln_like: {ln_like}"); + eprintln!("bayes_law ln_prior: {ln_prior}"); + eprintln!("bayes_law ln_m: {ln_m}"); + eprintln!("bayes_law ln_post: {ln_post}"); + + assert::close(ln_post, ln_like + ln_prior - ln_m, 1e-12); + } + + #[test] + fn monte_carlo_ln_m() { + // tests that the Monte Carlo estimate of the evidence converges + // to m(x) + // If this doesn't work one of three things could be wrong: + // 1. prior.draw (from sample_stream) is wrong + // 2. fx.ln_f_stat is wrong + // 3. prior.m is wrong + let n_tries = 5; + let mut rng = rand::thread_rng(); + + let pr = $prior; + + let stat = random_xs(&pr.draw(&mut rng), 3, &mut rng); + + let m = <$Pr as ConjugatePrior<$X, $Fx>>::m( + &pr, + &DataOrSuffStat::SuffStat(&stat), + ); + + let mut min_err = f64::INFINITY; + + for _ in 0..n_tries { + let stream = + <$Pr as $crate::traits::Sampleable<$Fx>>::sample_stream( + &pr, &mut rng, + ); + let est = stream + .take($n) + .map(|fx| { + <$Fx as $crate::traits::HasSuffStat<$X>>::ln_f_stat( + &fx, &stat, + ) + .exp() + }) + .sum::() + / ($n as f64); + + let err = (est - m).abs(); + let close_enough = err < $tol; + + if err < min_err { + min_err = err; + } + + if close_enough { + return; + } + } + panic!( + "MC estimate of M failed under {pr}. Min err: {min_err}" + ); + } + } + }; +} diff --git a/src/traits.rs b/src/traits.rs index 227a32c..ac37fbe 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -1,43 +1,16 @@ //! Trait definitions -use crate::data::DataOrSuffStat; +pub use crate::data::DataOrSuffStat; use rand::Rng; -/// Random variable -/// -/// Contains the minimal functionality that a random object must have to be -/// useful: a function defining the un-normalized density/mass at a point, -/// and functions to draw samples from the distribution. 
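For reference, the three invariants exercised by the new `test_conjugate_prior!` macro above, written out; θ is a parameter drawn from the prior, x_{1:n} the data summarized by the sufficient statistic, and m(·) the marginal likelihood. This is my restatement of the comments in the macro, not text from the crate:

```latex
% ln_p_is_ratio_of_ln_m: the posterior predictive is a ratio of marginals
\ln p(y \mid x_{1:n}) = \ln m(x_{1:n}, y) - \ln m(x_{1:n})

% bayes_law: posterior = likelihood * prior / marginal likelihood
\ln p(\theta \mid x_{1:n}) =
    \ln p(x_{1:n} \mid \theta) + \ln p(\theta) - \ln m(x_{1:n})

% monte_carlo_ln_m: the marginal likelihood as a prior expectation
m(x_{1:n}) \approx \frac{1}{N} \sum_{i=1}^{N} p(x_{1:n} \mid \theta_i),
    \qquad \theta_i \sim p(\theta)
```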
-pub trait Rv { - /// Probability function - /// - /// # Example - /// - /// ``` - /// use rv::dist::Gaussian; - /// use rv::traits::Rv; - /// - /// let g = Gaussian::standard(); - /// assert!(g.f(&0.0_f64) > g.f(&0.1_f64)); - /// assert!(g.f(&0.0_f64) > g.f(&-0.1_f64)); - /// ``` - fn f(&self, x: &X) -> f64 { - self.ln_f(x).exp() - } +pub trait Parameterized { + type Parameters; - /// Probability function - /// - /// # Example - /// - /// ``` - /// use rv::dist::Gaussian; - /// use rv::traits::Rv; - /// - /// let g = Gaussian::standard(); - /// assert!(g.ln_f(&0.0_f64) > g.ln_f(&0.1_f64)); - /// assert!(g.ln_f(&0.0_f64) > g.ln_f(&-0.1_f64)); - /// ``` - fn ln_f(&self, x: &X) -> f64; + fn emit_params(&self) -> Self::Parameters; + fn from_params(params: Self::Parameters) -> Self; +} + +pub trait Sampleable { /// Single draw from the `Rv` /// /// # Example @@ -46,7 +19,7 @@ pub trait Rv { /// /// ``` /// use rv::dist::Bernoulli; - /// use rv::traits::Rv; + /// use rv::traits::*; /// /// let b = Bernoulli::uniform(); /// let mut rng = rand::thread_rng(); @@ -62,7 +35,7 @@ pub trait Rv { /// /// ``` /// use rv::dist::Bernoulli; - /// use rv::traits::Rv; + /// use rv::traits::*; /// /// let b = Bernoulli::uniform(); /// let mut rng = rand::thread_rng(); @@ -75,7 +48,7 @@ pub trait Rv { /// /// ``` /// use rv::dist::Gaussian; - /// use rv::traits::Rv; + /// use rv::traits::*; /// /// let gauss = Gaussian::standard(); /// let mut rng = rand::thread_rng(); @@ -94,7 +67,7 @@ pub trait Rv { /// Estimate the mean of a Gamma distribution /// /// ``` - /// use rv::traits::Rv; + /// use rv::traits::*; /// use rv::dist::Gamma; /// /// let mut rng = rand::thread_rng(); @@ -102,7 +75,7 @@ pub trait Rv { /// let gamma = Gamma::new(2.0, 1.0).unwrap(); /// /// let n = 1_000_000_usize; - /// let mean = >::sample_stream(&gamma, &mut rng) + /// let mean = >::sample_stream(&gamma, &mut rng) /// .take(n) /// .sum::() / n as f64;; /// @@ -116,6 +89,53 @@ pub trait Rv { } } +pub trait HasDensity { + /// Probability function + /// + /// # Example + /// + /// ``` + /// use rv::dist::Gaussian; + /// use rv::traits::*; + /// + /// let g = Gaussian::standard(); + /// assert!(g.f(&0.0_f64) > g.f(&0.1_f64)); + /// assert!(g.f(&0.0_f64) > g.f(&-0.1_f64)); + /// ``` + fn f(&self, x: &X) -> f64 { + self.ln_f(x).exp() + } + + /// Probability function + /// + /// # Example + /// + /// ``` + /// use rv::dist::Gaussian; + /// use rv::traits::*; + /// + /// let g = Gaussian::standard(); + /// assert!(g.ln_f(&0.0_f64) > g.ln_f(&0.1_f64)); + /// assert!(g.ln_f(&0.0_f64) > g.ln_f(&-0.1_f64)); + /// ``` + fn ln_f(&self, x: &X) -> f64; +} + +/// Random variable +/// +/// Contains the minimal functionality that a random object must have to be +/// useful: a function defining the un-normalized density/mass at a point, +/// and functions to draw samples from the distribution. +pub trait Rv: Sampleable + HasDensity {} + +impl Rv for T where T: Sampleable + HasDensity {} + +/// Stochastic process +/// +pub trait Process: Sampleable + HasDensity {} + +impl Process for T where T: Sampleable + HasDensity {} + /// Identifies the support of the Rv pub trait Support { /// Returns `true` if `x` is in the support of the `Rv` @@ -140,7 +160,7 @@ pub trait Support { /// /// This trait uses the `Rv` and `Support` implementations to implement /// itself. 
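The headline change in this hunk is the split of `Rv` into `Sampleable` (drawing) and `HasDensity` (scoring), with `Rv` kept as a blanket trait over both. A hedged sketch of how downstream code can use the narrower bounds, assuming rv 0.17 as a dependency and the generic support parameter `X` on these traits (e.g. `Sampleable<f64>` for `Gaussian`):

```rust
use rand::Rng;
use rv::dist::Gaussian;
use rv::traits::*;

// Only needs densities, so it only asks for HasDensity.
fn ln_likelihood<D: HasDensity<f64>>(dist: &D, xs: &[f64]) -> f64 {
    xs.iter().map(|x| dist.ln_f(x)).sum()
}

// Only needs draws, so it only asks for Sampleable.
fn sample_mean<D, R>(dist: &D, n: usize, rng: &mut R) -> f64
where
    D: Sampleable<f64>,
    R: Rng,
{
    dist.sample(n, rng).iter().sum::<f64>() / n as f64
}

fn main() {
    let mut rng = rand::thread_rng();
    let g = Gaussian::standard();

    let xs = g.sample(1_000, &mut rng);
    println!("ln L(xs) = {}", ln_likelihood(&g, &xs));
    println!("sample mean = {}", sample_mean(&g, 10_000, &mut rng));
}
```

Because of the blanket `impl Rv for T where T: Sampleable + HasDensity`, existing code bounded on `Rv` keeps compiling unchanged; the split only lets new code demand less.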
-pub trait ContinuousDistr: Rv + Support { +pub trait ContinuousDistr: HasDensity + Support { /// The value of the Probability Density Function (PDF) at `x` /// /// # Example @@ -205,19 +225,19 @@ pub trait ContinuousDistr: Rv + Support { /// /// let expon = Exponential::new(1.0).unwrap(); /// let f = expon.ln_pdf(&-1.0_f64); - /// assert_eq!(f, std::f64::NEG_INFINITY); + /// assert_eq!(f, f64::NEG_INFINITY); /// ``` fn ln_pdf(&self, x: &X) -> f64 { if self.supports(x) { self.ln_f(x) } else { - std::f64::NEG_INFINITY + f64::NEG_INFINITY } } } /// Has a cumulative distribution function (CDF) -pub trait Cdf: Rv { +pub trait Cdf: HasDensity { /// The value of the Cumulative Density Function at `x` /// /// # Example @@ -241,7 +261,7 @@ pub trait Cdf: Rv { } /// Has an inverse-CDF / quantile function -pub trait InverseCdf: Rv + Support { +pub trait InverseCdf: HasDensity + Support { /// The value of the `x` at the given probability in the CDF /// /// # Example @@ -334,7 +354,7 @@ pub trait DiscreteDistr: Rv + Support { if self.supports(x) { self.ln_f(x) } else { - std::f64::NEG_INFINITY + f64::NEG_INFINITY } } } @@ -424,89 +444,6 @@ pub trait KlDivergence { } } -/// The data for this distribution can be summarized by a statistic -pub trait HasSuffStat: Rv { - type Stat: SuffStat; - fn empty_suffstat(&self) -> Self::Stat; - - /// Return the log likelihood for the data represented by the sufficient - /// statistic. - fn ln_f_stat(&self, stat: &Self::Stat) -> f64; -} - -/// Is a [sufficient statistic](https://en.wikipedia.org/wiki/Sufficient_statistic) for a -/// distribution. -/// -/// # Examples -/// -/// Basic suffstat useage. -/// -/// ``` -/// use rv::data::BernoulliSuffStat; -/// use rv::traits::SuffStat; -/// -/// // Bernoulli sufficient statistics are the number of observations, n, and -/// // the number of successes, k. -/// let mut stat = BernoulliSuffStat::new(); -/// -/// assert!(stat.n() == 0 && stat.k() == 0); -/// -/// stat.observe(&true); // observe `true` -/// assert!(stat.n() == 1 && stat.k() == 1); -/// -/// stat.observe(&false); // observe `false` -/// assert!(stat.n() == 2 && stat.k() == 1); -/// -/// stat.forget_many(&vec![false, true]); // forget `true` and `false` -/// assert!(stat.n() == 0 && stat.k() == 0); -/// ``` -/// -/// Conjugate analysis of coin flips using Bernoulli with a Beta prior on the -/// success probability. -/// -/// ``` -/// use rv::traits::SuffStat; -/// use rv::traits::ConjugatePrior; -/// use rv::data::BernoulliSuffStat; -/// use rv::dist::{Bernoulli, Beta}; -/// -/// let flips = vec![true, false, false]; -/// -/// // Pack the data into a sufficient statistic that holds the number of -/// // trials and the number of successes -/// let mut stat = BernoulliSuffStat::new(); -/// stat.observe_many(&flips); -/// -/// let prior = Beta::jeffreys(); -/// -/// // If we observe more false than true, the posterior predictive -/// // probability of true decreases. 
-/// let pp_no_obs = prior.pp(&true, &(&BernoulliSuffStat::new()).into()); -/// let pp_obs = prior.pp(&true, &(&flips).into()); -/// -/// assert!(pp_obs < pp_no_obs); -/// ``` -pub trait SuffStat { - /// Returns the number of observations - fn n(&self) -> usize; - - /// Assimilate the datum `x` into the statistic - fn observe(&mut self, x: &X); - - /// Remove the datum `x` from the statistic - fn forget(&mut self, x: &X); - - /// Assimilate several observations - fn observe_many(&mut self, xs: &[X]) { - xs.iter().for_each(|x| self.observe(x)); - } - - /// Forget several observations - fn forget_many(&mut self, xs: &[X]) { - xs.iter().for_each(|x| self.forget(x)); - } -} - /// A prior on `Fx` that induces a posterior that is the same form as the prior /// /// # Example @@ -533,14 +470,15 @@ pub trait SuffStat { /// /// ``` /// # use rv::traits::ConjugatePrior; -/// use rv::traits::{Rv, SuffStat}; +/// use rv::traits::*; +/// use rv::traits::SuffStat; /// use rv::dist::{Categorical, SymmetricDirichlet}; /// use rv::data::{CategoricalSuffStat, DataOrSuffStat}; /// use std::time::Instant; /// /// let ncats = 10; /// let symdir = SymmetricDirichlet::jeffreys(ncats).unwrap(); -/// let mut suffstat = CategoricalSuffStat::new(ncats); +/// let mut suffstat = CategoricalSuffStat::new(10); /// let mut rng = rand::thread_rng(); /// /// Categorical::new(&vec![1.0, 1.0, 5.0, 1.0, 2.0, 1.0, 1.0, 2.0, 1.0, 1.0]) @@ -557,7 +495,7 @@ pub trait SuffStat { /// let t_start = Instant::now(); /// let cache = symdir.ln_pp_cache(&stat); /// // Argmax -/// let k_max = (0..ncats).fold((0, std::f64::NEG_INFINITY), |(ix, f), y| { +/// let k_max = (0..ncats).fold((0, f64::NEG_INFINITY), |(ix, f), y| { /// let f_r = symdir.ln_pp_with_cache(&cache, &y); /// if f_r > f { /// (y, f_r) @@ -575,7 +513,7 @@ pub trait SuffStat { /// let t_no_cache = { /// let t_start = Instant::now(); /// // Argmax -/// let k_max = (0..ncats).fold((0, std::f64::NEG_INFINITY), |(ix, f), y| { +/// let k_max = (0..ncats).fold((0, f64::NEG_INFINITY), |(ix, f), y| { /// let f_r = symdir.ln_pp(&y, &stat); /// if f_r > f { /// (y, f_r) @@ -590,29 +528,35 @@ pub trait SuffStat { /// }; /// /// // Using cache improves runtime -/// assert!(t_no_cache.as_nanos() > t_cache.as_nanos()); +/// assert!(t_no_cache.as_nanos() > t_cache.as_nanos()); /// ``` -pub trait ConjugatePrior: Rv +pub trait ConjugatePrior: Sampleable where - Fx: Rv + HasSuffStat, + Fx: HasDensity + HasSuffStat, { /// Type of the posterior distribution - type Posterior: Rv; - /// Type of the `ln_m` cache - type LnMCache; - /// Type of the `ln_pp` cache - type LnPpCache; + type Posterior: Sampleable; + /// Type of the cache for the marginal likelihood + type MCache; + /// Type of the cache for the posterior predictive + type PpCache; /// Computes the posterior distribution from the data + // fn posterior(&self, x: &DataOrSuffStat) -> Self::Posterior; + + fn posterior_from_suffstat(&self, stat: &Fx::Stat) -> Self::Posterior { + self.posterior(&DataOrSuffStat::SuffStat(stat)) + } + fn posterior(&self, x: &DataOrSuffStat) -> Self::Posterior; /// Compute the cache for the log marginal likelihood. - fn ln_m_cache(&self) -> Self::LnMCache; + fn ln_m_cache(&self) -> Self::MCache; /// Log marginal likelihood with supplied cache. fn ln_m_with_cache( &self, - cache: &Self::LnMCache, + cache: &Self::MCache, x: &DataOrSuffStat, ) -> f64; @@ -625,10 +569,10 @@ where /// Compute the cache for the Log posterior predictive of y given x. /// /// The cache should encompass all information about `x`. 
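The cache types renamed in this hunk (`LnMCache` → `MCache`, `LnPpCache` → `PpCache`) keep the same workflow shown in the `SymmetricDirichlet` timing example above: build the predictive cache once for a fixed data set, then evaluate many candidates against it. A hedged Beta–Bernoulli sketch, using fully qualified calls in the style of the `test_conjugate_prior!` macro to pin down `X` and `Fx`:

```rust
use rv::data::{BernoulliSuffStat, DataOrSuffStat};
use rv::dist::{Bernoulli, Beta};
use rv::traits::*;

fn main() {
    let prior = Beta::jeffreys();

    // Three coin flips, packed into a sufficient statistic.
    let mut stat = BernoulliSuffStat::new();
    stat.observe_many(&[true, false, false]);
    let data = DataOrSuffStat::SuffStat(&stat);

    // One cache, many posterior-predictive evaluations.
    let cache =
        <Beta as ConjugatePrior<bool, Bernoulli>>::ln_pp_cache(&prior, &data);
    let p_true = <Beta as ConjugatePrior<bool, Bernoulli>>::ln_pp_with_cache(
        &prior, &cache, &true,
    )
    .exp();
    let p_false = <Beta as ConjugatePrior<bool, Bernoulli>>::ln_pp_with_cache(
        &prior, &cache, &false,
    )
    .exp();

    // The two predictive probabilities are complementary.
    assert!((p_true + p_false - 1.0).abs() < 1e-10);
}
```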
- fn ln_pp_cache(&self, x: &DataOrSuffStat) -> Self::LnPpCache; + fn ln_pp_cache(&self, x: &DataOrSuffStat) -> Self::PpCache; /// Log posterior predictive of y given x with supplied ln(norm) - fn ln_pp_with_cache(&self, cache: &Self::LnPpCache, y: &X) -> f64; + fn ln_pp_with_cache(&self, cache: &Self::PpCache, y: &X) -> f64; /// Log posterior predictive of y given x fn ln_pp(&self, y: &X, x: &DataOrSuffStat) -> f64 { @@ -641,6 +585,10 @@ where self.ln_m(x).exp() } + fn pp_with_cache(&self, cache: &Self::PpCache, y: &X) -> f64 { + self.ln_pp_with_cache(cache, y).exp() + } + /// Posterior Predictive distribution fn pp(&self, y: &X, x: &DataOrSuffStat) -> f64 { self.ln_pp(y, x).exp() @@ -651,3 +599,87 @@ where pub trait QuadBounds { fn quad_bounds(&self) -> (f64, f64); } + +/// The data for this distribution can be summarized by a statistic +pub trait HasSuffStat { + type Stat: SuffStat; + + fn empty_suffstat(&self) -> Self::Stat; + + /// Return the log likelihood for the data represented by the sufficient + /// statistic. + fn ln_f_stat(&self, stat: &Self::Stat) -> f64; +} + +/// Is a [sufficient statistic](https://en.wikipedia.org/wiki/Sufficient_statistic) for a +/// distribution. +/// +/// # Examples +/// +/// Basic suffstat useage. +/// +/// ``` +/// use rv::data::BernoulliSuffStat; +/// use rv::traits::SuffStat; +/// +/// // Bernoulli sufficient statistics are the number of observations, n, and +/// // the number of successes, k. +/// let mut stat = BernoulliSuffStat::new(); +/// +/// assert!(stat.n() == 0 && stat.k() == 0); +/// +/// stat.observe(&true); // observe `true` +/// assert!(stat.n() == 1 && stat.k() == 1); +/// +/// stat.observe(&false); // observe `false` +/// assert!(stat.n() == 2 && stat.k() == 1); +/// +/// stat.forget_many(&vec![false, true]); // forget `true` and `false` +/// assert!(stat.n() == 0 && stat.k() == 0); +/// ``` +/// +/// Conjugate analysis of coin flips using Bernoulli with a Beta prior on the +/// success probability. +/// +/// ``` +/// use rv::traits::SuffStat; +/// use rv::traits::ConjugatePrior; +/// use rv::data::BernoulliSuffStat; +/// use rv::dist::{Bernoulli, Beta}; +/// +/// let flips = vec![true, false, false]; +/// +/// // Pack the data into a sufficient statistic that holds the number of +/// // trials and the number of successes +/// let mut stat = BernoulliSuffStat::new(); +/// stat.observe_many(&flips); +/// +/// let prior = Beta::jeffreys(); +/// +/// // If we observe more false than true, the posterior predictive +/// // probability of true decreases. +/// let pp_no_obs = prior.pp(&true, &(&BernoulliSuffStat::new()).into()); +/// let pp_obs = prior.pp(&true, &(&flips).into()); +/// +/// assert!(pp_obs < pp_no_obs); +/// ``` +pub trait SuffStat { + /// Returns the number of observations + fn n(&self) -> usize; + + /// Assimilate the datum `x` into the statistic + fn observe(&mut self, x: &X); + + /// Remove the datum `x` from the statistic + fn forget(&mut self, x: &X); + + /// Assimilate several observations + fn observe_many(&mut self, xs: &[X]) { + xs.iter().for_each(|x| self.observe(x)); + } + + /// Forget several observations + fn forget_many(&mut self, xs: &[X]) { + xs.iter().for_each(|x| self.forget(x)); + } +} diff --git a/tests/mi.rs b/tests/mi.rs index 4bdec15..985e935 100644 --- a/tests/mi.rs +++ b/tests/mi.rs @@ -1,6 +1,6 @@ use rand::Rng; use rv::dist::{Gaussian, Mixture}; -use rv::traits::{Entropy, Rv}; +use rv::traits::*; #[test] fn bivariate_mixture_mi() {