diff --git a/Project.toml b/Project.toml index 5c4b503..23e6fa0 100644 --- a/Project.toml +++ b/Project.toml @@ -11,5 +11,7 @@ ExpectationMaximization = "e1fe09cc-5134-44c2-a941-50f4cd97986a" FreqTables = "da1fdf0e-e0ff-5433-a45f-9bb5ff651cb1" Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" diff --git a/src/BSSD.jl b/src/BSSD.jl new file mode 100644 index 0000000..76520e1 --- /dev/null +++ b/src/BSSD.jl @@ -0,0 +1,12 @@ +module BSSD + +using Turing +using StatsPlots +using Distributions + +export + meta_analytic + +include("meta_analytic.jl") + +end \ No newline at end of file diff --git a/src/meta_analytic.jl b/src/meta_analytic.jl new file mode 100644 index 0000000..4662efa --- /dev/null +++ b/src/meta_analytic.jl @@ -0,0 +1,28 @@ +""" + Meta Analytic Prior Model + +This Turing model is used to generate posterior samples of the parameters `a` and `b`. +""" +@model function meta_analytic( + y::Vector{Bool}, + time::Vector{Float64}, + trialindex::Vector{Int64}, + prior_a::Distribution, + prior_b::Distribution) + + n = length(y) + n_trials = maximum(trialindex) + + a ~ prior_a + b ~ prior_b + pis ~ filldist(Beta(a * b * n, (1 - a) * b * n), n_trials) + + for i in 1:n + pi = pis[trialindex[i]] + mu = log(-log(1 - pi)) + x = mu + log(time[i]) + prob = 1 - exp(-exp(x)) + y[i] ~ Bernoulli(prob) + end + +end; \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl new file mode 100644 index 0000000..e806b9d --- /dev/null +++ b/test/runtests.jl @@ -0,0 +1,45 @@ +using Test +using StableRNGs +using Random +using Distributions +using DataFrames +using Turing +using BSSD + +# Helper function for numerical tests. +# Taken from https://github.com/TuringLang/Turing.jl/blob/master/test/test_utils/numerical_tests.jl#L41 for now. +function check_numerical(chain, + symbols::Vector, + exact_vals::Vector; + atol=0.2, + rtol=0.0) + for (sym, val) in zip(symbols, exact_vals) + E = val isa Real ? + mean(chain[sym]) : + vec(mean(chain[sym], dims=1)) + @info (symbol=sym, exact=val, evaluated=E) + @test E ≈ val atol=atol rtol=rtol + end +end + +@testset "meta_analytic.jl" begin + rng = StableRNG(123) + + n_trials = 5 + n_patients = 50 + df = DataFrame( + y = rand(rng, Bernoulli(0.2), n_trials * n_patients), + time = rand(rng, Exponential(1), n_trials * n_patients), + trialindex = repeat(1:n_trials, n_patients) + ) + + chain = sample( + rng, + meta_analytic(df.y, df.time, df.trialindex, Beta(2, 8), Beta(9, 10)), + HMC(0.05, 10), + 1000 + ) + + check_numerical(chain, [:a], [0.223], rtol=0.001) + check_numerical(chain, [:b], [0.485], rtol=0.001) +end