Skip to content

Commit

Permalink
Merge branch 'main' into primary-cens-example
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs authored Sep 30, 2024
2 parents 67767bd + 8ed3ec5 commit ee6cc58
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 107 deletions.
1 change: 1 addition & 0 deletions EpiAware/docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
TimeSeries = "9e3dc215-6440-5c97-bce1-76c03772f85e"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
TuringBenchmarking = "0db1332d-5c25-4deb-809f-459bc696f94f"
220 changes: 113 additions & 107 deletions EpiAware/docs/src/showcase/replications/mishra-2020/index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@ using Distributions, Statistics #Statistics packages
using CSV, DataFramesMeta #Data wrangling

# ╔═╡ 9eb03a0b-c6ca-4e23-8109-fb68f87d7fdf
begin #Plotting backend
using StatsPlots
using StatsPlots.PlotMeasures
end
using CairoMakie, PairPlots, TimeSeries #Plotting backend

# ╔═╡ 97b5374e-7653-4b3b-98eb-d8f73aa30580
using ReverseDiff #Automatic differentiation backend
Expand Down Expand Up @@ -155,17 +152,19 @@ We can spaghetti plot generative samples from the AR(2) process with the priors
plt_ar_sample = let
n_samples = 100
ar_mdl_samples = mapreduce(hcat, 1:n_samples) do _
ar_mdl() #Sample Z_t trajectories for the model
ar_mdl() .|> exp #Sample Z_t trajectories for the model
end

plot(ar_mdl_samples .|> exp, #R_t = exp(Z_t)
lab = "",
c = :grey,
alpha = 0.25,
title = "$(n_samples) draws from the prior Rₜ model",
fig = Figure()
ax = Axis(fig[1, 1];
yscale = log10,
ylabel = "Time varying Rₜ",
yticks = [10.0^n for n in -4:4],
yscale = :log10)
title = "$(n_samples) draws from the prior Rₜ model"
)
for col in eachcol(ar_mdl_samples)
lines!(ax, col, color = (:grey, 0.1))
end
fig
end

# ╔═╡ 9f84dec1-70f1-442e-8bef-a9494921549e
Expand Down Expand Up @@ -208,14 +207,21 @@ We can compare the discretized generation interval with the continuous estimate,

# ╔═╡ 71d08f7e-c409-4fbe-b154-b21d09010683
let
bar(model_data.gen_int,
fillalpha = 0.5,
lw = 0,
lab = "Discretized next gen pmf",
fig = Figure()
ax = Axis(fig[1, 1];
xticks = 0:14,
xlabel = "Days",
title = "Continuous and discrete generation intervals")
plot!(truth_GI, lab = "Continuous serial interval")
title = "Continuous and discrete generation intervals"
)
barplot!(ax, model_data.gen_int;
label = "Discretized next gen pmf"
)
lines!(truth_GI;
label = "Continuous serial interval",
color = :green
)
axislegend(ax)
fig
end

# ╔═╡ 4a2b5cf1-623c-4fe7-8365-49fb7972af5a
Expand Down Expand Up @@ -259,24 +265,27 @@ latent_inf_mdl = generate_latent_infs(epi, log.(R_t_fixed))
# ╔═╡ 7a6d4b14-58d3-40c1-81f2-713c830f875f
plt_epi = let
n_samples = 100
#Sample unconditionally the underlying parameters of the model
epi_mdl_samples = mapreduce(hcat, 1:n_samples) do _
latent_inf_mdl() #Sample unconditionally the underlying parameters of the model
latent_inf_mdl()
end

p1 = plot(epi_mdl_samples,
lab = "",
c = :grey,
alpha = 0.25,
fig = Figure()
ax1 = Axis(fig[1, 1];
title = "$(n_samples) draws from renewal model with chosen Rt",
ylabel = "Latent infections"
)
p2 = plot(R_t_fixed,
lab = "",
lw = 2,
ax2 = Axis(fig[2, 1];
ylabel = "Rt"
)

plot(p1, p2, layout = (2, 1))
for col in eachcol(epi_mdl_samples)
lines!(ax1, col;
color = (:grey, 0.1)
)
end
lines!(ax2, R_t_fixed;
linewidth = 2
)
fig
end

# ╔═╡ c8ef8a60-d087-4ae9-ae92-abeea5afc7ae
Expand All @@ -296,7 +305,6 @@ A prior for $\phi$ was not specified in _Mishra et al_, we select one below but

# ╔═╡ 714908a1-dc85-476f-a99f-ec5c95a78b60
obs = NegativeBinomialError(cluster_factor_prior = HalfNormal(0.1))
# obs = PoissonError()

# ╔═╡ dacb8094-89a4-404a-8243-525c0dbfa482
md"
Expand Down Expand Up @@ -324,17 +332,23 @@ plt_obs = let
obs_mdl_samples = mapreduce(hcat, 1:n_samples) do _
θ = obs_mdl() #Sample unconditionally the underlying parameters of the model
end
scatter(obs_mdl_samples,
lab = "",
c = :grey,
alpha = 0.25,
fig = Figure()
ax = Axis(fig[1, 1];
title = "$(n_samples) draws from neg. bin. obs model",
ylabel = "Observed cases"
)
plot!(expected_cases,
c = :red,
lw = 3,
lab = "Expected cases")
for col in eachcol(obs_mdl_samples)
scatter!(ax, col;
color = (:grey, 0.2)
)
end
lines!(ax, expected_cases;
color = :red,
linewidth = 3,
label = "Expected cases"
)
axislegend(ax)
fig
end

# ╔═╡ a06065e1-0e20-4cf8-8d5a-2d588da20bee
Expand Down Expand Up @@ -473,9 +487,10 @@ let
C = south_korea_data.y_t
D = south_korea_data.dates

#Unconditional model for posterior predictive sampling
mdl_unconditional = generate_epiaware(epi_prob, (y_t = fill(missing, length(C)),)) |
(var"obs.cluster_factor" = fixed_cluster_factor,)
#Case unconditional model for posterior predictive sampling
mdl_unconditional = generate_epiaware(epi_prob,
(y_t = fill(missing, length(C)),)
) | (var"obs.cluster_factor" = fixed_cluster_factor,)
posterior_gens = generated_quantities(mdl_unconditional, inference_results.samples)

#plotting quantiles
Expand All @@ -486,30 +501,55 @@ let
predicted_R_t = generated_quantiles(
posterior_gens, :Z_t, qs; transformation = x -> exp.(x))

#Plots
p1 = plot(D, predicted_y_t[:, 3], lw = 2, lab = "post. median", c = :purple)
plot!(p1, D, predicted_y_t[:, 2], fillrange = predicted_y_t[:, 4],
fillalpha = 0.5, lw = 0, c = :purple, lab = "50%")
plot!(p1, D, predicted_y_t[:, 1], fillrange = predicted_y_t[:, 5],
fillalpha = 0.2, lw = 0, c = :purple, lab = "95%")

scatter!(p1, D, C,
lab = "Actual cases",
ylabel = "Daily Cases",
title = "Posterior predictive: Cases",
ylims = (-50, maximum(C) * 2),
c = :black
ts = D .|> d -> d - minimum(D) .|> d -> d.value + 1
t_ticks = string.(D)
fig = Figure()
ax1 = Axis(fig[1, 1];
ylabel = "Daily cases",
xticks = (ts[1:14:end], t_ticks[1:14:end]),
title = "Posterior predictive: Cases"
)
ax2 = Axis(fig[2, 1];
yscale = log10,
title = "Prediction: Reproduction number",
xticks = (ts[1:14:end], t_ticks[1:14:end])
)
linkxaxes!(ax1, ax2)

p2 = plot(D, predicted_R_t[:, 3], lw = 2, lab = "post. median", c = :green,
yscale = :log10, title = "Prediction: Reproduction number")
plot!(p2, D, predicted_R_t[:, 2], fillrange = predicted_R_t[:, 4],
fillalpha = 0.5, lw = 0, c = :green, lab = "50%")
plot!(p2, D, predicted_R_t[:, 1], fillrange = predicted_R_t[:, 5],
fillalpha = 0.2, lw = 0, c = :green, lab = "95%")
hline!(p2, [1.0], lab = "Rt = 1", lw = 2, c = :blue)
lines!(ax1, ts, predicted_y_t[:, 3];
color = :purple,
linewidth = 2,
label = "Post. median"
)
band!(ax1, 1:size(predicted_y_t, 1), predicted_y_t[:, 2], predicted_y_t[:, 4];
color = (:purple, 0.4),
label = "50%"
)
band!(ax1, 1:size(predicted_y_t, 1), predicted_y_t[:, 1], predicted_y_t[:, 5];
color = (:purple, 0.2),
label = "95%"
)
scatter!(ax1, C;
color = :black,
label = "Actual cases")
axislegend(ax1)

lines!(ax2, ts, predicted_R_t[:, 3];
color = :green,
linewidth = 2,
label = "Post. median"
)
band!(ax2, 1:size(predicted_R_t, 1), predicted_R_t[:, 2], predicted_R_t[:, 4];
color = (:green, 0.4),
label = "50%"
)
band!(ax2, 1:size(predicted_R_t, 1), predicted_R_t[:, 1], predicted_R_t[:, 5];
color = (:green, 0.2),
label = "95%"
)
axislegend(ax2)

plot(p1, p2, layout = (2, 1), size = (500, 700), left_margin = 5mm)
fig
end

# ╔═╡ c05ed977-7a89-4ac8-97be-7078d69fce9f
Expand All @@ -521,51 +561,17 @@ We can interrogate the sampled chains directly from the `samples` field of the `

# ╔═╡ ff21c9ec-1581-405f-8db1-0f522b5bc296
let
p1 = histogram(inference_results.samples["latent.σ_AR"],
lab = "chain " .* string.([1 2 3 4]),
fillalpha = 0.4,
lw = 0,
norm = :pdf,
title = "Posterior dist: AR noise std")
plot!(p1, ar.std_prior,
lw = 3,
c = :black,
lab = "prior")

p2 = histogram(inference_results.samples[:init_incidence],
lab = "chain " .* string.([1 2 3 4]),
fillalpha = 0.4,
lw = 0,
norm = :pdf,
title = "Posterior dist: log-initial incidence")
plot!(p2, epi.initialisation_prior,
lw = 3,
c = :black,
lab = "prior")

p3 = histogram(inference_results.samples["latent.damp_AR[2]"],
lab = "chain " .* string.([1 2 3 4]),
fillalpha = 0.4,
lw = 0,
norm = :pdf,
title = "Posterior dist: rho_1")
plot!(p3, ar.damp_prior.v[2],
lw = 3,
c = :black,
lab = "prior")

p4 = histogram(inference_results.samples["latent.damp_AR[1]"],
lab = "chain " .* string.([1 2 3 4]),
fillalpha = 0.4,
lw = 0,
norm = :pdf,
title = "Posterior dist: rho_2")
plot!(p4, ar.damp_prior.v[1],
lw = 3,
c = :black,
lab = "prior")

plot(p1, p2, p3, p4, layout = (2, 2), size = (800, 600))
sub_chn = inference_results.samples[inference_results.samples.name_map.parameters[[1:5;
end]]]
fig = pairplot(sub_chn)
lines!(fig[1, 1], ar.std_prior, label = "Prior")
lines!(fig[2, 2], ar.init_prior.v[1], label = "Prior")
lines!(fig[3, 3], ar.init_prior.v[2], label = "Prior")
lines!(fig[4, 4], ar.damp_prior.v[1], label = "Prior")
lines!(fig[5, 5], ar.damp_prior.v[2], label = "Prior")
lines!(fig[6, 6], epi.initialisation_prior, label = "Prior")

fig
end

# ╔═╡ Cell order:
Expand Down

0 comments on commit ee6cc58

Please sign in to comment.