Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace scan usage with accumulate scan #444

Merged
merged 5 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ R_1 = 1 \Big{/} \sum_{t\geq 1} e^{-rt} g_t
log_I0_prior = Normal(log(1.0), 1.0)

# ╔═╡ 8487835e-d430-4300-bd7c-e33f5769ee32
epi = Renewal(model_data, log_I0_prior)
epi = Renewal(model_data; initialisation_prior = log_I0_prior)

# ╔═╡ 2119319f-a2ef-4c96-82c4-3c7eaf40d2e0
md"
Expand Down
6 changes: 3 additions & 3 deletions EpiAware/src/EpiInfModels/EpiInfModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@ Module for defining epidemiological models.
module EpiInfModels

using ..EpiAwareBase

using ..EpiAwareUtils: scan, censored_pmf
using ..EpiAwareUtils

using Turing, Distributions, DocStringExtensions, LinearAlgebra

#Export models
export EpiData, DirectInfections, ExpGrowthRate, Renewal, RenewalWithPopulation
export EpiData, DirectInfections, ExpGrowthRate, Renewal

#Export functions
export R_to_r, r_to_R, expected_Rt
Expand All @@ -19,6 +18,7 @@ include("docstrings.jl")
include("EpiData.jl")
include("DirectInfections.jl")
include("ExpGrowthRate.jl")
include("RenewalSteps.jl")
include("Renewal.jl")
include("utils.jl")

Expand Down
227 changes: 40 additions & 187 deletions EpiAware/src/EpiInfModels/Renewal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@ number.
`initialisation_prior` for the prior distribution of ``\hat{I}_0``. The default
`initialisation_prior` is `Normal()`.

## Constructor
## Constructors

- `Renewal(; data, initialisation_prior)`.
- `Renewal(; data, initialisation_prior)`. Construct a `Renewal` model with default update steps.
- `Renewal(data; initialisation_prior)`. Construct a `Renewal` model with default update steps.
- `Renewal(data, initialisation_prior, recurrent_step)` Construct a `Renewal` model with `recurrent_step` update step function.

## Example usage with `generate_latent_infs`

Expand All @@ -53,7 +55,7 @@ g = exp
data = EpiData(gen_int, g)

# Create an Renewal model
renewal_model = Renewal(data = data, initialisation_prior = Normal())
renewal_model = Renewal(data; initialisation_prior = Normal())
```

Then, we can use `generate_latent_infs` to construct a Turing model for the unobserved
Expand All @@ -77,198 +79,49 @@ unobserved infections.
I_t = generated_quantities(latent_inf, θ)
```
"
@kwdef struct Renewal{S <: Sampleable} <: EpiAwareBase.AbstractTuringRenewal
data::EpiData
initialisation_prior::S = Normal()
end

@doc """
function (epi_model::Renewal)(recent_incidence, Rt)

Callable on a `Renewal` struct for compute new incidence based on recent incidence and Rt.

## Mathematical specification

The new incidence is given by

```math
I_t = R_t \\sum_{i=1}^{n-1} I_{t-i} g_i
```

where `I_t` is the new incidence, `R_t` is the reproduction number, `I_{t-i}` is the recent incidence
and `g_i` is the generation interval.

# Arguments
- `recent_incidence`: Array of recent incidence values.
- `Rt`: Reproduction number.

# Returns
- Tuple containing the updated incidence array and the new incidence value.
"""
function (epi_model::Renewal)(recent_incidence, Rt)
new_incidence = Rt * dot(recent_incidence, epi_model.data.gen_int)
return ([new_incidence; recent_incidence[1:(epi_model.data.len_gen_int - 1)]],
new_incidence)
struct Renewal{E, S <: Sampleable, A} <:
EpiAwareBase.AbstractTuringRenewal where {
E <: EpiData, A <: AbstractConstantRenewalStep}
data::E
initialisation_prior::S
recurrent_step::A

function Renewal(data::EpiData; initialisation_prior = Normal())
rev_gen_int = reverse(data.gen_int)
recurrent_step = ConstantRenewalStep(rev_gen_int)
return Renewal(data, initialisation_prior, recurrent_step)
end

function Renewal(; data::EpiData, initialisation_prior = Normal())
rev_gen_int = reverse(data.gen_int)
recurrent_step = ConstantRenewalStep(rev_gen_int)
return Renewal(data, initialisation_prior, recurrent_step)
end

function Renewal(data::E,
initialisation_prior::S,
recurrent_step::A) where {
E <: EpiData, S <: Sampleable, A <: AbstractConstantRenewalStep}
return new{E, S, A}(data, initialisation_prior, recurrent_step)
end
end

"""
Create the initial vector of infected individuals for a renewal model.
Create the initial state of the `Renewal` model.

# Arguments
- `epi_model::Renewal`: The renewal model.
- `epi_model::Renewal`: The Renewal model.
- `I₀`: The initial number of infected individuals.
- `Rt`: The time-varying reproduction number.
- `Rt`: The initial time-varying reproduction number.

# Returns
The initial vector of infected individuals.

"""
function make_renewal_init(epi_model::Renewal, I₀, Rt)
r_approx = R_to_r(Rt[1], epi_model)
return I₀ * [exp(-r_approx * t) for t in 0:(epi_model.data.len_gen_int - 1)]
end

@doc raw"
Model unobserved/latent infections as due to time-varying Renewal model with reproduction
number ``\mathcal{R}_t`` which is generated by a latent process and a population size
of available people who can be infected `N`.

## Mathematical specification

If ``Z_t`` is a realisation of the latent model, then the unobserved/latent infections are
given by

```math
\begin{align}
\mathcal{R}_t &= g(Z_t),\\
S_t &= S_{t-1} - I_t,\\
I_t &= {S_{t-1} \over N}\mathcal{R}_t \sum_{i=1}^{n-1} I_{t-i} g_i, \qquad t \geq 1, \\
I_t &= g(\hat{I}_0) \exp(r(\mathcal{R}_1) t), \qquad t \leq 0.
\end{align}
```

where ``g`` is a transformation function and the unconstrained initial infections
``\hat{I}_0`` are sampled from a prior distribution. The discrete generation interval is
given by ``g_i``.

``r(\mathcal{R}_1)`` is the exponential growth rate implied by ``\mathcal{R}_1)``
using the implicit relationship between the exponential growth rate and the reproduction
number.

```math
\mathcal{R} \sum_{j \geq 1} g_j \exp(- r j)= 1.
```

`Renewal` are constructed by passing an `EpiData` object `data` and an
`initialisation_prior` for the prior distribution of ``\hat{I}_0``. The default
`initialisation_prior` is `Normal()`.

## Constructor

- `RenewalWithPopulation(; data, initialisation_prior, pop_size)`.

## Example usage with `generate_latent_infs`

`generate_latent_infs` can be used to construct a `Turing` model for the latent infections
conditional on the sample path of a latent process. In this example, we generate a sample
of a white noise latent process.

First, we construct an `Renewal` struct with an `EpiData` object, an initialisation
prior and a transformation function.

```julia
using Distributions, Turing, EpiAware
gen_int = [0.2, 0.3, 0.5]
g = exp

# Create an EpiData object
data = EpiData(gen_int, g)

# Create an Renewal model
renewal_model = RenewalWithPopulation(data = data, initialisation_prior = Normal(), pop_size = 1e6)
```

Then, we can use `generate_latent_infs` to construct a Turing model for the unobserved
infection generation model set by the type of `direct_inf_model`.

```julia
# Construct a Turing model
Z_t = randn(100) * 0.05
latent_inf = generate_latent_infs(renewal_model, Z_t)
```

Now we can use the `Turing` PPL API to sample underlying parameters and generate the
unobserved infections.

```julia
# Sample from the unobserved infections model

#Sample random parameters from prior
θ = rand(latent_inf)
#Get unobserved infections as a generated quantities from the model
I_t = generated_quantities(latent_inf, θ)
```
"
@kwdef struct RenewalWithPopulation{S <: Sampleable} <: EpiAwareBase.AbstractTuringRenewal
data::EpiData
initialisation_prior::S = Normal()
pop_size::Float64 = 1e6
end

@doc """
function (epi_model::RenewalWithPopulation)(recent_incidence_and_available_sus, Rt)

Callable on a `RenewalWithPopulation` struct for compute new incidence based on
recent incidence, Rt and depletion of susceptibles.

## Mathematical specification

The new incidence is given by

```math
I_t = {S_{t-1} / N} R_t \\sum_{i=1}^{n-1} I_{t-i} g_i
```

where `I_t` is the new incidence, `R_t` is the reproduction number, `I_{t-i}` is the recent incidence
and `g_i` is the generation interval.

# Arguments
- `recent_incidence_and_available_sus`: A tuple with an array of recent incidence
values and the remaining susceptible/available individuals.
- `Rt`: Reproduction number.

# Returns
- Tuple containing the updated incidence array and the new `recent_incidence_and_available_sus`
value.
"""
function (epi_model::RenewalWithPopulation)(recent_incidence_and_available_sus, Rt)
recent_incidence, S = recent_incidence_and_available_sus
new_incidence = max(S / epi_model.pop_size, 0.0) * Rt *
dot(recent_incidence, epi_model.data.gen_int)
new_S = S - new_incidence
new_recent_incidence_and_available_sus = (
[new_incidence; recent_incidence[1:(epi_model.data.len_gen_int - 1)]], new_S)

return (new_recent_incidence_and_available_sus, new_incidence)
end

"""
Constructs the initial conditions for a renewal model with population.

# Arguments
- `epi_model::RenewalWithPopulation`: The renewal model with population.
- `I₀`: The initial number of infected individuals.
- `Rt`: The time-varying reproduction number.

# Returns
- A tuple containing the initial number of infected individuals at each generation
interval and the population size of susceptible/available people.

"""
function make_renewal_init(epi_model::RenewalWithPopulation, I₀, Rt)
r_approx = R_to_r(Rt[1], epi_model)
return (I₀ * [exp(-r_approx * t) for t in 0:(epi_model.data.len_gen_int - 1)],
epi_model.pop_size)
function make_renewal_init(epi_model::Renewal, I₀, Rt₀)
r_approx = R_to_r(Rt₀, epi_model)
return renewal_init_state(
epi_model.recurrent_step, I₀, r_approx, epi_model.data.len_gen_int)
end

@doc raw"
Expand All @@ -292,7 +145,7 @@ g = exp
data = EpiData(gen_int, g)

# Create an Renewal model
renewal_model = Renewal(data = data, initialisation_prior = Normal())
renewal_model = Renewal(data; initialisation_prior = Normal())
```

Then, we can use `generate_latent_infs` to construct a Turing model for the unobserved
Expand Down Expand Up @@ -322,8 +175,8 @@ I_t = generated_quantities(latent_inf, θ)
I₀ = epi_model.data.transformation(init_incidence)
Rt = epi_model.data.transformation.(_Rt)

init = make_renewal_init(epi_model, I₀, Rt)
I_t, _ = scan(epi_model, init, Rt)
init = make_renewal_init(epi_model, I₀, Rt[1])
I_t = accumulate_scan(epi_model.recurrent_step, init, Rt)

return I_t
end
Loading
Loading