Skip to content

Commit

Permalink
add sequential testing
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Nov 7, 2023
1 parent 970f332 commit f4f40b1
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 3 deletions.
5 changes: 5 additions & 0 deletions examples/gibbs/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,9 @@ function main()
mcmctest(test, TestSubject(model, GibbsRandScan())) |> display
mcmctest(test, TestSubject(model, GibbsRandScanWrongMean())) |> display
mcmctest(test, TestSubject(model, GibbsRandScanWrongVar())) |> display

test = TwoSampleGibbsTest(100, 100, 100)
seqmcmctest(test, TestSubject(model, GibbsRandScan()), 0.001, 32) |> display
seqmcmctest(test, TestSubject(model, GibbsRandScanWrongMean()), 0.001, 32) |> display
seqmcmctest(test, TestSubject(model, GibbsRandScanWrongVar()), 0.001, 32) |> display
end
7 changes: 6 additions & 1 deletion src/MCMCTesting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ export
sample_predictive,
sample_joint,
sample_markov_chain,
mcmctest
mcmctest,
seqmcmctest

using Random
using HypothesisTests
using ProgressMeter
using MultipleTesting

function sample_joint end
function sample_predictive end
Expand All @@ -25,6 +27,9 @@ struct TestSubject{M, K}
kernel::K
end

abstract type AbstractMCMCTest end

include("twosample.jl")
include("seqtest.jl")

end
54 changes: 54 additions & 0 deletions src/seqtest.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@

function seqmcmctest(
test ::AbstractMCMCTest,
subject ::TestSubject,
false_rejection_rate::Real,
samplesize ::Int,
max_iter ::Real = 3,
samplesize_increase ::Real = 2.;
show_progress = true,
pvalue_adjustment::MultipleTesting.PValueAdjustment = MultipleTesting.Bonferroni(),
kwargs...)
α = false_rejection_rate
k = max_iter
β = α / k
γ = β^(1/k)
Δ = samplesize_increase

for i = 1:k
prog = ProgressMeter.Progress(
samplesize;
barlen = 31,
showspeed = true,
enabled = show_progress
)
pvals_all = mapreduce(hcat, 1:samplesize) do n
pval = mcmctest(test, subject; kwargs...)
next!(prog,
showvalues = [
(:test_iteration, i),
(:pvalue_sampling, "$(n)/$(samplesize)")
])
pval
end

pvals_adjusted = mapreduce(vcat, eachcol(pvals_all)) do pvals_paramwise
adjust(Vector(pvals_paramwise), pvalue_adjustment)
end

q = minimum(pvals_adjusted)*length(pvals_adjusted)

if q β
return false
elseif q > γ + β
break
end

β /= γ

if i == 1
samplesize = ceil(Int, samplesize*Δ)
end
end
true
end
4 changes: 2 additions & 2 deletions src/twosample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ function markovchain_multiple_transition(
θ
end

struct TwoSampleTest
struct TwoSampleTest <: AbstractMCMCTest
n_control ::Int
n_treatment ::Int
n_mcmc_steps::Int
end

struct TwoSampleGibbsTest
struct TwoSampleGibbsTest <: AbstractMCMCTest
n_control ::Int
n_treatment ::Int
n_mcmc_steps::Int
Expand Down

0 comments on commit f4f40b1

Please sign in to comment.