Skip to content

Commit

Permalink
Merge pull request #6 from openpharma/2_meta_analytic
Browse files Browse the repository at this point in the history
2: Add meta_analytic Turing model
  • Loading branch information
brockk authored Nov 30, 2023
2 parents 96bf369 + 5d434ba commit f9716a9
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,7 @@ ExpectationMaximization = "e1fe09cc-5134-44c2-a941-50f4cd97986a"
FreqTables = "da1fdf0e-e0ff-5433-a45f-9bb5ff651cb1"
Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
12 changes: 12 additions & 0 deletions src/BSSD.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module BSSD

using Turing
using StatsPlots
using Distributions

export
meta_analytic

include("meta_analytic.jl")

end
28 changes: 28 additions & 0 deletions src/meta_analytic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
Meta Analytic Prior Model
This Turing model is used to generate posterior samples of the parameters `a` and `b`.
"""
@model function meta_analytic(
y::Vector{Bool},
time::Vector{Float64},
trialindex::Vector{Int64},
prior_a::Distribution,
prior_b::Distribution)

n = length(y)
n_trials = maximum(trialindex)

a ~ prior_a
b ~ prior_b
pis ~ filldist(Beta(a * b * n, (1 - a) * b * n), n_trials)

for i in 1:n
pi = pis[trialindex[i]]
mu = log(-log(1 - pi))
x = mu + log(time[i])
prob = 1 - exp(-exp(x))
y[i] ~ Bernoulli(prob)
end

end;
45 changes: 45 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using Test
using StableRNGs
using Random
using Distributions
using DataFrames
using Turing
using BSSD

# Helper function for numerical tests.
# Taken from https://github.com/TuringLang/Turing.jl/blob/master/test/test_utils/numerical_tests.jl#L41 for now.
function check_numerical(chain,
symbols::Vector,
exact_vals::Vector;
atol=0.2,
rtol=0.0)
for (sym, val) in zip(symbols, exact_vals)
E = val isa Real ?
mean(chain[sym]) :
vec(mean(chain[sym], dims=1))
@info (symbol=sym, exact=val, evaluated=E)
@test E val atol=atol rtol=rtol
end
end

@testset "meta_analytic.jl" begin
rng = StableRNG(123)

n_trials = 5
n_patients = 50
df = DataFrame(
y = rand(rng, Bernoulli(0.2), n_trials * n_patients),
time = rand(rng, Exponential(1), n_trials * n_patients),
trialindex = repeat(1:n_trials, n_patients)
)

chain = sample(
rng,
meta_analytic(df.y, df.time, df.trialindex, Beta(2, 8), Beta(9, 10)),
HMC(0.05, 10),
1000
)

check_numerical(chain, [:a], [0.223], rtol=0.001)
check_numerical(chain, [:b], [0.485], rtol=0.001)
end

0 comments on commit f9716a9

Please sign in to comment.