Skip to content

Commit

Permalink
add tests, add missing gibbs sampling file
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Nov 15, 2023
1 parent 9cb32ee commit 953df70
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 5 deletions.
1 change: 0 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ uuid = "9963b6a1-5d46-439c-8efc-3a487843c7fa"
version = "0.1.1"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
MultipleTesting = "f8716d33-7c4a-5097-896f-ce0ecbd3ef6b"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Expand Down
92 changes: 92 additions & 0 deletions test/models/normalnormalgibbs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@

struct Model
sigma ::Float64
sigma_eps::Float64
end

struct GibbsRandScan end

struct GibbsRandScanWrongMean end

struct GibbsRandScanWrongVar end

function MCMCTesting.sample_predictive(rng::Random.AbstractRNG, model::Model, θ)
# y ∼ θ₁ + θ₂ + ϵ
rand(rng, Normal(θ[1] + θ[2], model.sigma_eps))
end

function MCMCTesting.sample_joint(rng::Random.AbstractRNG, model::Model)
θ₁ = rand(rng, Normal(0, model.sigma))
θ₂ = rand(rng, Normal(0, model.sigma))
θ = [θ₁, θ₂]
y = MCMCTesting.sample_predictive(rng, model, θ)
θ, y
end

function complete_conditional::Real, σ²::Real, σ²_ϵ::Real, y::Real)
μ = σ²/(σ²_ϵ + σ²)*(y - θ)
σ = 1/sqrt(1/σ²_ϵ + 1/σ²)
Normal(μ, σ)
end

function complete_conditional_wrongmean::Real, σ²::Real, σ²_ϵ::Real, y::Real)
μ = σ²/(σ²_ϵ + σ²)*(y + θ)
σ = 1/sqrt(1/σ²_ϵ + 1/σ²)
Normal(μ, σ)
end

function complete_conditional_wrongvar::Real, σ²::Real, σ²_ϵ::Real, y::Real)
μ = σ²/(σ²_ϵ + σ²)*(y - θ)
Normal(μ, sqrt(σ²))
end

function MCMCTesting.markovchain_transition(
rng::Random.AbstractRNG, model::Model, kernel::GibbsRandScan, θ, y
)
θ = copy(θ)
σ² = model.sigma^2
σ²_ϵ = model.sigma_eps^2

if rand(rng, Bernoulli(0.5))
θ[1] = rand(rng, complete_conditional(θ[2], σ², σ²_ϵ, y))
θ[2] = rand(rng, complete_conditional(θ[1], σ², σ²_ϵ, y))
else
θ[2] = rand(rng, complete_conditional(θ[1], σ², σ²_ϵ, y))
θ[1] = rand(rng, complete_conditional(θ[2], σ², σ²_ϵ, y))
end
θ
end

function MCMCTesting.markovchain_transition(
rng::Random.AbstractRNG, model::Model, kernel::GibbsRandScanWrongMean, θ, y
)
θ = copy(θ)
σ² = model.sigma^2
σ²_ϵ = model.sigma_eps^2

if rand(rng, Bernoulli(0.5))
θ[1] = rand(rng, complete_conditional_wrongmean(θ[2], σ², σ²_ϵ, y))
θ[2] = rand(rng, complete_conditional( θ[1], σ², σ²_ϵ, y))
else
θ[2] = rand(rng, complete_conditional_wrongmean(θ[1], σ², σ²_ϵ, y))
θ[1] = rand(rng, complete_conditional( θ[2], σ², σ²_ϵ, y))
end
θ
end

function MCMCTesting.markovchain_transition(
rng::Random.AbstractRNG, model::Model, kernel::GibbsRandScanWrongVar, θ, y
)
θ = copy(θ)
σ² = model.sigma^2
σ²_ϵ = model.sigma_eps^2

if rand(rng, Bernoulli(0.5))
θ[1] = rand(rng, complete_conditional_wrongvar(θ[2], σ², σ²_ϵ, y))
θ[2] = rand(rng, complete_conditional( θ[1], σ², σ²_ϵ, y))
else
θ[2] = rand(rng, complete_conditional_wrongvar(θ[1], σ², σ²_ϵ, y))
θ[1] = rand(rng, complete_conditional( θ[2], σ², σ²_ϵ, y))
end
θ
end
84 changes: 80 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,90 @@
using Distributions
using Random
using Test
using StableRNGs

using MCMCTesting

include("models/normalnormalgibbs.jl")

@testset "interfaces" begin
model = Model(sqrt(10), sqrt(0.1))
n_samples = 100
mcmc = GibbsRandScan()
subject = TestSubject(model, mcmc)

n_dim_param = 2
n_dim_data = 1

@testset for test in [TwoSampleTest(n_samples, n_samples),
TwoSampleGibbsTest(n_samples, n_samples)]
@testset "mcmctest" begin
@testset "return type" begin
pvalue = mcmctest(test, subject; show_progress=false)
@test eltype(pvalue) <: Real

n_dim_default = if test isa TwoSampleTest
(n_dim_param + n_dim_data)*2
else
n_dim_param*2
end
@test length(pvalue) == n_dim_default
end

@testset "custom statistics" begin
pvalue = mcmctest(test, subject; show_progress=false, statistics = θ -> θ)

n_dim_default = if test isa TwoSampleTest
(n_dim_param + n_dim_data)*2
else
n_dim_param*2
end
@test length(pvalue) == div(n_dim_default,2)
end

@testset "determinism" begin
rng = StableRNG(1)
pvalue = mcmctest(rng, test, subject; show_progress=false)

rng = StableRNG(1)
pvalue_rep = mcmctest(rng, test, subject; show_progress=false)
@test pvalue == pvalue_rep
end
end

@testset "seqmcmctest" begin
@testset "return type" begin
res = seqmcmctest(test, subject, 0.001, 32; show_progress=false)
@test res isa Bool
end

@testset "custom statistics" begin
seqmcmctest(test, subject, 0.001, 32; show_progress=false, statistics = θ -> θ)

rng = StableRNG(1)
seqmcmctest(rng, test, subject, 0.001, 32; show_progress=false, statistics = θ -> θ)
end

@testset "determinism" begin
rng = StableRNG(1)
res = seqmcmctest(rng, test, subject, 0.001, 32; show_progress=false)

rng = StableRNG(1)
res_rep = seqmcmctest(rng, test, subject, 0.001, 32; show_progress=false)
@test res == res_rep
end
end
end

@testset "rank simulation" begin
test = ExactRankTest(n_samples, n_samples)
ranks = simulate_ranks(test, subject; show_progress=false)
@test all(@. 1 ranks n_samples)
@test eltype(ranks) <: Integer
@test size(ranks) == (2*2, n_samples)
end
end

@testset "normal-normal gibbs" begin
model = Model(sqrt(10), sqrt(0.1))

Expand All @@ -18,26 +97,23 @@ include("models/normalnormalgibbs.jl")

@testset "mcmctest" begin
pvalue = mcmctest(test, TestSubject(model, GibbsRandScan()); show_progress=false)
@test eltype(pvalue) <: Real
@test all(@. 0 pvalue 1)
end

@testset "seqmcmctest correct" begin
test_res = seqmcmctest(test, TestSubject(model, GibbsRandScan()), 0.001, 32; show_progress=false)
@test test_res isa Bool
@test test_res
end

@testset "seqmcmctest wrong mean" begin
test_res = seqmcmctest(test, TestSubject(model, GibbsRandScanWrongMean()), 0.001, 32; show_progress=false)
@test test_res isa Bool
@test !test_res
end

@testset "seqmcmctest wrong var" begin
test_res = seqmcmctest(test, TestSubject(model, GibbsRandScanWrongVar()), 0.001, 32; show_progress=false)
@test test_res isa Bool
@test !test_res
end
end
end

0 comments on commit 953df70

Please sign in to comment.