From 60c041cc8a01056e0c2c238dd9422a16336f8691 Mon Sep 17 00:00:00 2001 From: Daniel Sabanes Bove Date: Wed, 22 Nov 2023 22:25:04 +0100 Subject: [PATCH 1/2] first stab - need to complete test --- Project.toml | 7 +++++++ src/BSSD.jl | 12 ++++++++++++ src/meta_analytic.jl | 28 ++++++++++++++++++++++++++++ test/runtests.jl | 42 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 89 insertions(+) create mode 100644 src/BSSD.jl create mode 100644 src/meta_analytic.jl create mode 100644 test/runtests.jl diff --git a/Project.toml b/Project.toml index 8a76e30..3dfc028 100644 --- a/Project.toml +++ b/Project.toml @@ -1,3 +1,8 @@ +name = "BSSD" +uuid = "ad9a3248-762c-4ade-a843-8f185e2c18ff" +authors = ["Daniel Sabanes Bove ", "Kristian Brock "] +version = "0.0.1" + [deps] CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" @@ -6,5 +11,7 @@ ExpectationMaximization = "e1fe09cc-5134-44c2-a941-50f4cd97986a" FreqTables = "da1fdf0e-e0ff-5433-a45f-9bb5ff651cb1" Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" diff --git a/src/BSSD.jl b/src/BSSD.jl new file mode 100644 index 0000000..76520e1 --- /dev/null +++ b/src/BSSD.jl @@ -0,0 +1,12 @@ +module BSSD + +using Turing +using StatsPlots +using Distributions + +export + meta_analytic + +include("meta_analytic.jl") + +end \ No newline at end of file diff --git a/src/meta_analytic.jl b/src/meta_analytic.jl new file mode 100644 index 0000000..f006482 --- /dev/null +++ b/src/meta_analytic.jl @@ -0,0 +1,28 @@ +""" + Meta Analytic Prior Model + +This Turing model is used to generate posterior samples of the parameters `a` and `b`. +""" +@model function meta_analytic( + y::Vector{Bool}, + time::Vector{AbstractFloat}, + trialindex::Vector{Integer}, + prior_a::Distribution, + prior_b::Distribution) + + n = length(y) + n_trials = maximum(trialindex) + + a ~ prior_a + b ~ prior_b + pis ~ filldist(Beta(a * b * n, (1 - a) * b * n), n_trials) + + for i in 1:n + pi = pis[trialindex[i]] + mu = log(-log(1 - pi)) + x = mu + log(time[i]) + prob = 1 - exp(-exp(x)) + y[i] ~ Bernoulli(prob) + end + +end; \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl new file mode 100644 index 0000000..ddecc79 --- /dev/null +++ b/test/runtests.jl @@ -0,0 +1,42 @@ +using Test +using StableRNGs +using Random +using Distributions +using BSSD + +# Helper function for numerical tests. +# Taken from https://github.com/TuringLang/Turing.jl/blob/master/test/test_utils/numerical_tests.jl#L41 for now. +function check_numerical(chain, + symbols::Vector, + exact_vals::Vector; + atol=0.2, + rtol=0.0) + for (sym, val) in zip(symbols, exact_vals) + E = val isa Real ? + mean(chain[sym]) : + vec(mean(chain[sym], dims=1)) + @info (symbol=sym, exact=val, evaluated=E) + @test E ≈ val atol=atol rtol=rtol + end +end + +@testset "meta_analytic.jl" begin + rng = StableRNG(123) + + n_trials = 5 + n_patients = 50 + df = DataFrame( + y = rand(rng, Bernoulli(0.2), n_trials * n_patients), + time = rand(rng, Exponential(1), n_trials * n_patients), + trialindex = repeat(1:n_trials, n_patients) + ) + + chain = sample( + rng, + meta_analytic(df.y, df.time, df.trialindex, Beta(2, 8), Beta(9, 10)), + HMC(0.05, 10), + 1000 + ) + + check_numerical(chain, ) +end From 6a2797e3d2140413565ad2a2d6bb79657aa6b3d9 Mon Sep 17 00:00:00 2001 From: Daniel Sabanes Bove Date: Wed, 22 Nov 2023 22:32:09 +0100 Subject: [PATCH 2/2] fix types, complete test --- src/meta_analytic.jl | 4 ++-- test/runtests.jl | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/meta_analytic.jl b/src/meta_analytic.jl index f006482..4662efa 100644 --- a/src/meta_analytic.jl +++ b/src/meta_analytic.jl @@ -5,8 +5,8 @@ This Turing model is used to generate posterior samples of the parameters `a` an """ @model function meta_analytic( y::Vector{Bool}, - time::Vector{AbstractFloat}, - trialindex::Vector{Integer}, + time::Vector{Float64}, + trialindex::Vector{Int64}, prior_a::Distribution, prior_b::Distribution) diff --git a/test/runtests.jl b/test/runtests.jl index ddecc79..e806b9d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,8 @@ using Test using StableRNGs using Random using Distributions +using DataFrames +using Turing using BSSD # Helper function for numerical tests. @@ -38,5 +40,6 @@ end 1000 ) - check_numerical(chain, ) + check_numerical(chain, [:a], [0.223], rtol=0.001) + check_numerical(chain, [:b], [0.485], rtol=0.001) end