Skip to content

Commit

Permalink
Plots to Makie
Browse files Browse the repository at this point in the history
Remove Statsplots and convert all plots to CairoMakie in Mishra replication
  • Loading branch information
SamuelBrand1 committed Sep 27, 2024
1 parent 7372fba commit 11b9034
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 108 deletions.
4 changes: 3 additions & 1 deletion EpiAware/docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Changelog = "5217a498-cd5d-4ec6-b8c2-9b85a09b6e3e"
DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
EpiAware = "b2eeebe4-5992-4301-9193-7ebc9f62c855"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PairPlots = "43a3c2be-4208-490b-832a-a21dcd55d7da"
Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"
PlutoStaticHTML = "359b1769-a58e-495b-9770-312e911026ad"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
TimeSeries = "9e3dc215-6440-5c97-bce1-76c03772f85e"
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
220 changes: 113 additions & 107 deletions EpiAware/docs/src/showcase/replications/mishra-2020/index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@ using Distributions, Statistics #Statistics packages
using CSV, DataFramesMeta #Data wrangling

# ╔═╡ 9eb03a0b-c6ca-4e23-8109-fb68f87d7fdf
begin #Plotting backend
using StatsPlots
using StatsPlots.PlotMeasures
end
using CairoMakie, PairPlots, TimeSeries #Plotting backend

# ╔═╡ 97b5374e-7653-4b3b-98eb-d8f73aa30580
using ReverseDiff #Automatic differentiation backend
Expand Down Expand Up @@ -155,17 +152,19 @@ We can spaghetti plot generative samples from the AR(2) process with the priors
plt_ar_sample = let
n_samples = 100
ar_mdl_samples = mapreduce(hcat, 1:n_samples) do _
ar_mdl() #Sample Z_t trajectories for the model
ar_mdl() .|> exp #Sample Z_t trajectories for the model
end

plot(ar_mdl_samples .|> exp, #R_t = exp(Z_t)
lab = "",
c = :grey,
alpha = 0.25,
title = "$(n_samples) draws from the prior Rₜ model",
fig = Figure()
ax = Axis(fig[1, 1];
yscale = log10,
ylabel = "Time varying Rₜ",
yticks = [10.0^n for n in -4:4],
yscale = :log10)
title = "$(n_samples) draws from the prior Rₜ model"
)
for col in eachcol(ar_mdl_samples)
lines!(ax, col, color = (:grey, 0.1))
end
fig
end

# ╔═╡ 9f84dec1-70f1-442e-8bef-a9494921549e
Expand Down Expand Up @@ -208,14 +207,21 @@ We can compare the discretized generation interval with the continuous estimate,

# ╔═╡ 71d08f7e-c409-4fbe-b154-b21d09010683
let
bar(model_data.gen_int,
fillalpha = 0.5,
lw = 0,
lab = "Discretized next gen pmf",
fig = Figure()
ax = Axis(fig[1, 1];
xticks = 0:14,
xlabel = "Days",
title = "Continuous and discrete generation intervals")
plot!(truth_GI, lab = "Continuous serial interval")
title = "Continuous and discrete generation intervals"
)
barplot!(ax, model_data.gen_int;
label = "Discretized next gen pmf"
)
lines!(truth_GI;
label = "Continuous serial interval",
color = :green
)
axislegend(ax)
fig
end

# ╔═╡ 4a2b5cf1-623c-4fe7-8365-49fb7972af5a
Expand Down Expand Up @@ -259,24 +265,27 @@ latent_inf_mdl = generate_latent_infs(epi, log.(R_t_fixed))
# ╔═╡ 7a6d4b14-58d3-40c1-81f2-713c830f875f
plt_epi = let
n_samples = 100
#Sample unconditionally the underlying parameters of the model
epi_mdl_samples = mapreduce(hcat, 1:n_samples) do _
latent_inf_mdl() #Sample unconditionally the underlying parameters of the model
latent_inf_mdl()
end

p1 = plot(epi_mdl_samples,
lab = "",
c = :grey,
alpha = 0.25,
fig = Figure()
ax1 = Axis(fig[1, 1];
title = "$(n_samples) draws from renewal model with chosen Rt",
ylabel = "Latent infections"
)
p2 = plot(R_t_fixed,
lab = "",
lw = 2,
ax2 = Axis(fig[2, 1];
ylabel = "Rt"
)

plot(p1, p2, layout = (2, 1))
for col in eachcol(epi_mdl_samples)
lines!(ax1, col;
color = (:grey, 0.1)
)
end
lines!(ax2, R_t_fixed;
linewidth = 2
)
fig
end

# ╔═╡ c8ef8a60-d087-4ae9-ae92-abeea5afc7ae
Expand All @@ -296,7 +305,6 @@ A prior for $\phi$ was not specified in _Mishra et al_, we select one below but

# ╔═╡ 714908a1-dc85-476f-a99f-ec5c95a78b60
obs = NegativeBinomialError(cluster_factor_prior = HalfNormal(0.1))
# obs = PoissonError()

# ╔═╡ dacb8094-89a4-404a-8243-525c0dbfa482
md"
Expand Down Expand Up @@ -324,17 +332,23 @@ plt_obs = let
obs_mdl_samples = mapreduce(hcat, 1:n_samples) do _
θ = obs_mdl() #Sample unconditionally the underlying parameters of the model
end
scatter(obs_mdl_samples,
lab = "",
c = :grey,
alpha = 0.25,
fig = Figure()
ax = Axis(fig[1, 1];
title = "$(n_samples) draws from neg. bin. obs model",
ylabel = "Observed cases"
)
plot!(expected_cases,
c = :red,
lw = 3,
lab = "Expected cases")
for col in eachcol(obs_mdl_samples)
scatter!(ax, col;
color = (:grey, 0.2)
)
end
lines!(ax, expected_cases;
color = :red,
linewidth = 3,
label = "Expected cases"
)
axislegend(ax)
fig
end

# ╔═╡ a06065e1-0e20-4cf8-8d5a-2d588da20bee
Expand Down Expand Up @@ -473,9 +487,10 @@ let
C = south_korea_data.y_t
D = south_korea_data.dates

#Unconditional model for posterior predictive sampling
mdl_unconditional = generate_epiaware(epi_prob, (y_t = fill(missing, length(C)),)) |
(var"obs.cluster_factor" = fixed_cluster_factor,)
#Case unconditional model for posterior predictive sampling
mdl_unconditional = generate_epiaware(epi_prob,
(y_t = fill(missing, length(C)),)
) | (var"obs.cluster_factor" = fixed_cluster_factor,)
posterior_gens = generated_quantities(mdl_unconditional, inference_results.samples)

#plotting quantiles
Expand All @@ -486,30 +501,55 @@ let
predicted_R_t = generated_quantiles(
posterior_gens, :Z_t, qs; transformation = x -> exp.(x))

#Plots
p1 = plot(D, predicted_y_t[:, 3], lw = 2, lab = "post. median", c = :purple)
plot!(p1, D, predicted_y_t[:, 2], fillrange = predicted_y_t[:, 4],
fillalpha = 0.5, lw = 0, c = :purple, lab = "50%")
plot!(p1, D, predicted_y_t[:, 1], fillrange = predicted_y_t[:, 5],
fillalpha = 0.2, lw = 0, c = :purple, lab = "95%")

scatter!(p1, D, C,
lab = "Actual cases",
ylabel = "Daily Cases",
title = "Posterior predictive: Cases",
ylims = (-50, maximum(C) * 2),
c = :black
ts = D .|> d -> d - minimum(D) .|> d -> d.value + 1
t_ticks = string.(D)
fig = Figure()
ax1 = Axis(fig[1, 1];
ylabel = "Daily cases",
xticks = (ts[1:14:end], t_ticks[1:14:end]),
title = "Posterior predictive: Cases"
)
ax2 = Axis(fig[2, 1];
yscale = log10,
title = "Prediction: Reproduction number",
xticks = (ts[1:14:end], t_ticks[1:14:end])
)
linkxaxes!(ax1, ax2)

p2 = plot(D, predicted_R_t[:, 3], lw = 2, lab = "post. median", c = :green,
yscale = :log10, title = "Prediction: Reproduction number")
plot!(p2, D, predicted_R_t[:, 2], fillrange = predicted_R_t[:, 4],
fillalpha = 0.5, lw = 0, c = :green, lab = "50%")
plot!(p2, D, predicted_R_t[:, 1], fillrange = predicted_R_t[:, 5],
fillalpha = 0.2, lw = 0, c = :green, lab = "95%")
hline!(p2, [1.0], lab = "Rt = 1", lw = 2, c = :blue)
lines!(ax1, ts, predicted_y_t[:, 3];
color = :purple,
linewidth = 2,
label = "Post. median"
)
band!(ax1, 1:size(predicted_y_t, 1), predicted_y_t[:, 2], predicted_y_t[:, 4];
color = (:purple, 0.4),
label = "50%"
)
band!(ax1, 1:size(predicted_y_t, 1), predicted_y_t[:, 1], predicted_y_t[:, 5];
color = (:purple, 0.2),
label = "95%"
)
scatter!(ax1, C;
color = :black,
label = "Actual cases")
axislegend(ax1)

lines!(ax2, ts, predicted_R_t[:, 3];
color = :green,
linewidth = 2,
label = "Post. median"
)
band!(ax2, 1:size(predicted_R_t, 1), predicted_R_t[:, 2], predicted_R_t[:, 4];
color = (:green, 0.4),
label = "50%"
)
band!(ax2, 1:size(predicted_R_t, 1), predicted_R_t[:, 1], predicted_R_t[:, 5];
color = (:green, 0.2),
label = "95%"
)
axislegend(ax2)

plot(p1, p2, layout = (2, 1), size = (500, 700), left_margin = 5mm)
fig
end

# ╔═╡ c05ed977-7a89-4ac8-97be-7078d69fce9f
Expand All @@ -521,51 +561,17 @@ We can interrogate the sampled chains directly from the `samples` field of the `

# ╔═╡ ff21c9ec-1581-405f-8db1-0f522b5bc296
let
p1 = histogram(inference_results.samples["latent.σ_AR"],
lab = "chain " .* string.([1 2 3 4]),
fillalpha = 0.4,
lw = 0,
norm = :pdf,
title = "Posterior dist: AR noise std")
plot!(p1, ar.std_prior,
lw = 3,
c = :black,
lab = "prior")

p2 = histogram(inference_results.samples[:init_incidence],
lab = "chain " .* string.([1 2 3 4]),
fillalpha = 0.4,
lw = 0,
norm = :pdf,
title = "Posterior dist: log-initial incidence")
plot!(p2, epi.initialisation_prior,
lw = 3,
c = :black,
lab = "prior")

p3 = histogram(inference_results.samples["latent.damp_AR[2]"],
lab = "chain " .* string.([1 2 3 4]),
fillalpha = 0.4,
lw = 0,
norm = :pdf,
title = "Posterior dist: rho_1")
plot!(p3, ar.damp_prior.v[2],
lw = 3,
c = :black,
lab = "prior")

p4 = histogram(inference_results.samples["latent.damp_AR[1]"],
lab = "chain " .* string.([1 2 3 4]),
fillalpha = 0.4,
lw = 0,
norm = :pdf,
title = "Posterior dist: rho_2")
plot!(p4, ar.damp_prior.v[1],
lw = 3,
c = :black,
lab = "prior")

plot(p1, p2, p3, p4, layout = (2, 2), size = (800, 600))
sub_chn = inference_results.samples[inference_results.samples.name_map.parameters[[1:5;
end]]]
fig = pairplot(sub_chn)
lines!(fig[1, 1], ar.std_prior, label = "Prior")
lines!(fig[2, 2], ar.init_prior.v[1], label = "Prior")
lines!(fig[3, 3], ar.init_prior.v[2], label = "Prior")
lines!(fig[4, 4], ar.damp_prior.v[1], label = "Prior")
lines!(fig[5, 5], ar.damp_prior.v[2], label = "Prior")
lines!(fig[6, 6], epi.initialisation_prior, label = "Prior")

fig
end

# ╔═╡ Cell order:
Expand Down

0 comments on commit 11b9034

Please sign in to comment.