Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plots to Makie in Mishra replication #461

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion EpiAware/docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Changelog = "5217a498-cd5d-4ec6-b8c2-9b85a09b6e3e"
DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
EpiAware = "b2eeebe4-5992-4301-9193-7ebc9f62c855"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PairPlots = "43a3c2be-4208-490b-832a-a21dcd55d7da"
Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"
PlutoStaticHTML = "359b1769-a58e-495b-9770-312e911026ad"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
TimeSeries = "9e3dc215-6440-5c97-bce1-76c03772f85e"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
220 changes: 113 additions & 107 deletions EpiAware/docs/src/showcase/replications/mishra-2020/index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@ using Distributions, Statistics #Statistics packages
using CSV, DataFramesMeta #Data wrangling

# ╔═╡ 9eb03a0b-c6ca-4e23-8109-fb68f87d7fdf
begin #Plotting backend
using StatsPlots
using StatsPlots.PlotMeasures
end
using CairoMakie, PairPlots, TimeSeries #Plotting backend

# ╔═╡ 97b5374e-7653-4b3b-98eb-d8f73aa30580
using ReverseDiff #Automatic differentiation backend
Expand Down Expand Up @@ -155,17 +152,19 @@ We can spaghetti plot generative samples from the AR(2) process with the priors
plt_ar_sample = let
n_samples = 100
ar_mdl_samples = mapreduce(hcat, 1:n_samples) do _
ar_mdl() #Sample Z_t trajectories for the model
ar_mdl() .|> exp #Sample Z_t trajectories for the model
end

plot(ar_mdl_samples .|> exp, #R_t = exp(Z_t)
lab = "",
c = :grey,
alpha = 0.25,
title = "$(n_samples) draws from the prior Rₜ model",
fig = Figure()
ax = Axis(fig[1, 1];
yscale = log10,
ylabel = "Time varying Rₜ",
yticks = [10.0^n for n in -4:4],
yscale = :log10)
title = "$(n_samples) draws from the prior Rₜ model"
)
for col in eachcol(ar_mdl_samples)
lines!(ax, col, color = (:grey, 0.1))
end
fig
end

# ╔═╡ 9f84dec1-70f1-442e-8bef-a9494921549e
Expand Down Expand Up @@ -208,14 +207,21 @@ We can compare the discretized generation interval with the continuous estimate,

# ╔═╡ 71d08f7e-c409-4fbe-b154-b21d09010683
let
bar(model_data.gen_int,
fillalpha = 0.5,
lw = 0,
lab = "Discretized next gen pmf",
fig = Figure()
ax = Axis(fig[1, 1];
xticks = 0:14,
xlabel = "Days",
title = "Continuous and discrete generation intervals")
plot!(truth_GI, lab = "Continuous serial interval")
title = "Continuous and discrete generation intervals"
)
barplot!(ax, model_data.gen_int;
label = "Discretized next gen pmf"
)
lines!(truth_GI;
label = "Continuous serial interval",
color = :green
)
axislegend(ax)
fig
end

# ╔═╡ 4a2b5cf1-623c-4fe7-8365-49fb7972af5a
Expand Down Expand Up @@ -259,24 +265,27 @@ latent_inf_mdl = generate_latent_infs(epi, log.(R_t_fixed))
# ╔═╡ 7a6d4b14-58d3-40c1-81f2-713c830f875f
plt_epi = let
n_samples = 100
#Sample unconditionally the underlying parameters of the model
epi_mdl_samples = mapreduce(hcat, 1:n_samples) do _
latent_inf_mdl() #Sample unconditionally the underlying parameters of the model
latent_inf_mdl()
end

p1 = plot(epi_mdl_samples,
lab = "",
c = :grey,
alpha = 0.25,
fig = Figure()
ax1 = Axis(fig[1, 1];
title = "$(n_samples) draws from renewal model with chosen Rt",
ylabel = "Latent infections"
)
p2 = plot(R_t_fixed,
lab = "",
lw = 2,
ax2 = Axis(fig[2, 1];
ylabel = "Rt"
)

plot(p1, p2, layout = (2, 1))
for col in eachcol(epi_mdl_samples)
lines!(ax1, col;
color = (:grey, 0.1)
)
end
lines!(ax2, R_t_fixed;
linewidth = 2
)
fig
end

# ╔═╡ c8ef8a60-d087-4ae9-ae92-abeea5afc7ae
Expand All @@ -296,7 +305,6 @@ A prior for $\phi$ was not specified in _Mishra et al_, we select one below but

# ╔═╡ 714908a1-dc85-476f-a99f-ec5c95a78b60
obs = NegativeBinomialError(cluster_factor_prior = HalfNormal(0.1))
# obs = PoissonError()

# ╔═╡ dacb8094-89a4-404a-8243-525c0dbfa482
md"
Expand Down Expand Up @@ -324,17 +332,23 @@ plt_obs = let
obs_mdl_samples = mapreduce(hcat, 1:n_samples) do _
θ = obs_mdl() #Sample unconditionally the underlying parameters of the model
end
scatter(obs_mdl_samples,
lab = "",
c = :grey,
alpha = 0.25,
fig = Figure()
ax = Axis(fig[1, 1];
title = "$(n_samples) draws from neg. bin. obs model",
ylabel = "Observed cases"
)
plot!(expected_cases,
c = :red,
lw = 3,
lab = "Expected cases")
for col in eachcol(obs_mdl_samples)
scatter!(ax, col;
color = (:grey, 0.2)
)
end
lines!(ax, expected_cases;
color = :red,
linewidth = 3,
label = "Expected cases"
)
axislegend(ax)
fig
end

# ╔═╡ a06065e1-0e20-4cf8-8d5a-2d588da20bee
Expand Down Expand Up @@ -473,9 +487,10 @@ let
C = south_korea_data.y_t
D = south_korea_data.dates

#Unconditional model for posterior predictive sampling
mdl_unconditional = generate_epiaware(epi_prob, (y_t = fill(missing, length(C)),)) |
(var"obs.cluster_factor" = fixed_cluster_factor,)
#Case unconditional model for posterior predictive sampling
mdl_unconditional = generate_epiaware(epi_prob,
(y_t = fill(missing, length(C)),)
) | (var"obs.cluster_factor" = fixed_cluster_factor,)
posterior_gens = generated_quantities(mdl_unconditional, inference_results.samples)

#plotting quantiles
Expand All @@ -486,30 +501,55 @@ let
predicted_R_t = generated_quantiles(
posterior_gens, :Z_t, qs; transformation = x -> exp.(x))

#Plots
p1 = plot(D, predicted_y_t[:, 3], lw = 2, lab = "post. median", c = :purple)
plot!(p1, D, predicted_y_t[:, 2], fillrange = predicted_y_t[:, 4],
fillalpha = 0.5, lw = 0, c = :purple, lab = "50%")
plot!(p1, D, predicted_y_t[:, 1], fillrange = predicted_y_t[:, 5],
fillalpha = 0.2, lw = 0, c = :purple, lab = "95%")

scatter!(p1, D, C,
lab = "Actual cases",
ylabel = "Daily Cases",
title = "Posterior predictive: Cases",
ylims = (-50, maximum(C) * 2),
c = :black
ts = D .|> d -> d - minimum(D) .|> d -> d.value + 1
t_ticks = string.(D)
fig = Figure()
ax1 = Axis(fig[1, 1];
ylabel = "Daily cases",
xticks = (ts[1:14:end], t_ticks[1:14:end]),
title = "Posterior predictive: Cases"
)
ax2 = Axis(fig[2, 1];
yscale = log10,
title = "Prediction: Reproduction number",
xticks = (ts[1:14:end], t_ticks[1:14:end])
)
linkxaxes!(ax1, ax2)

p2 = plot(D, predicted_R_t[:, 3], lw = 2, lab = "post. median", c = :green,
yscale = :log10, title = "Prediction: Reproduction number")
plot!(p2, D, predicted_R_t[:, 2], fillrange = predicted_R_t[:, 4],
fillalpha = 0.5, lw = 0, c = :green, lab = "50%")
plot!(p2, D, predicted_R_t[:, 1], fillrange = predicted_R_t[:, 5],
fillalpha = 0.2, lw = 0, c = :green, lab = "95%")
hline!(p2, [1.0], lab = "Rt = 1", lw = 2, c = :blue)
lines!(ax1, ts, predicted_y_t[:, 3];
color = :purple,
linewidth = 2,
label = "Post. median"
)
band!(ax1, 1:size(predicted_y_t, 1), predicted_y_t[:, 2], predicted_y_t[:, 4];
color = (:purple, 0.4),
label = "50%"
)
band!(ax1, 1:size(predicted_y_t, 1), predicted_y_t[:, 1], predicted_y_t[:, 5];
color = (:purple, 0.2),
label = "95%"
)
scatter!(ax1, C;
color = :black,
label = "Actual cases")
axislegend(ax1)

lines!(ax2, ts, predicted_R_t[:, 3];
color = :green,
linewidth = 2,
label = "Post. median"
)
band!(ax2, 1:size(predicted_R_t, 1), predicted_R_t[:, 2], predicted_R_t[:, 4];
color = (:green, 0.4),
label = "50%"
)
band!(ax2, 1:size(predicted_R_t, 1), predicted_R_t[:, 1], predicted_R_t[:, 5];
color = (:green, 0.2),
label = "95%"
)
axislegend(ax2)

plot(p1, p2, layout = (2, 1), size = (500, 700), left_margin = 5mm)
fig
end

# ╔═╡ c05ed977-7a89-4ac8-97be-7078d69fce9f
Expand All @@ -521,51 +561,17 @@ We can interrogate the sampled chains directly from the `samples` field of the `

# ╔═╡ ff21c9ec-1581-405f-8db1-0f522b5bc296
let
p1 = histogram(inference_results.samples["latent.σ_AR"],
lab = "chain " .* string.([1 2 3 4]),
fillalpha = 0.4,
lw = 0,
norm = :pdf,
title = "Posterior dist: AR noise std")
plot!(p1, ar.std_prior,
lw = 3,
c = :black,
lab = "prior")

p2 = histogram(inference_results.samples[:init_incidence],
lab = "chain " .* string.([1 2 3 4]),
fillalpha = 0.4,
lw = 0,
norm = :pdf,
title = "Posterior dist: log-initial incidence")
plot!(p2, epi.initialisation_prior,
lw = 3,
c = :black,
lab = "prior")

p3 = histogram(inference_results.samples["latent.damp_AR[2]"],
lab = "chain " .* string.([1 2 3 4]),
fillalpha = 0.4,
lw = 0,
norm = :pdf,
title = "Posterior dist: rho_1")
plot!(p3, ar.damp_prior.v[2],
lw = 3,
c = :black,
lab = "prior")

p4 = histogram(inference_results.samples["latent.damp_AR[1]"],
lab = "chain " .* string.([1 2 3 4]),
fillalpha = 0.4,
lw = 0,
norm = :pdf,
title = "Posterior dist: rho_2")
plot!(p4, ar.damp_prior.v[1],
lw = 3,
c = :black,
lab = "prior")

plot(p1, p2, p3, p4, layout = (2, 2), size = (800, 600))
sub_chn = inference_results.samples[inference_results.samples.name_map.parameters[[1:5;
end]]]
fig = pairplot(sub_chn)
lines!(fig[1, 1], ar.std_prior, label = "Prior")
lines!(fig[2, 2], ar.init_prior.v[1], label = "Prior")
lines!(fig[3, 3], ar.init_prior.v[2], label = "Prior")
lines!(fig[4, 4], ar.damp_prior.v[1], label = "Prior")
lines!(fig[5, 5], ar.damp_prior.v[2], label = "Prior")
lines!(fig[6, 6], epi.initialisation_prior, label = "Prior")

fig
end

# ╔═╡ Cell order:
Expand Down
Loading