Skip to content

Commit

Permalink
implement different model types
Browse files Browse the repository at this point in the history
implements suggestions by @hsbadr

See
#213 (comment)
#213 (comment)
#213 (comment)

For now not implementing comparison to the approximate growth rate as
this seems quite a specific use case that could also be done outside the
stan model.

Also not implementing any approximate growth rate from seeding time -
instead minimum seeding time is now set to 1, so the last seeding time
is used to calculate the first growth rate.
  • Loading branch information
sbfnk committed Jul 4, 2023
1 parent 81c314d commit 550b156
Show file tree
Hide file tree
Showing 22 changed files with 281 additions and 187 deletions.
12 changes: 7 additions & 5 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ create_gp_data <- function(gp = gp_opts(), data) {
#' create_obs_model(obs_opts(week_length = 3), dates = dates)
create_obs_model <- function(obs = obs_opts(), dates) {
data <- list(
model_type = as.numeric(obs$family %in% "negbin"),
obs_dist = as.integer(obs$family %in% "negbin"),
phi_mean = obs$phi[1],
phi_sd = obs$phi[2],
week_effect = ifelse(obs$week_effect, obs$week_length, 1),
Expand Down Expand Up @@ -430,7 +430,8 @@ create_obs_model <- function(obs = obs_opts(), dates) {
#' @export
create_stan_data <- function(reported_cases, seeding_time,
rt, gp, obs, horizon,
backcalc, shifted_cases) {
backcalc, shifted_cases,
process_model) {

cases <- reported_cases[(seeding_time + 1):(.N - horizon)]$confirm

Expand All @@ -440,7 +441,8 @@ create_stan_data <- function(reported_cases, seeding_time,
t = length(reported_cases$date),
horizon = horizon,
burn_in = 0,
seeding_time = seeding_time
seeding_time = seeding_time,
process_model = process_model
)
# add Rt data
data <- c(
Expand Down Expand Up @@ -547,7 +549,7 @@ create_initial_conditions <- function(data) {
out$rho <- array(numeric(0))
out$alpha <- array(numeric(0))
}
if (data$model_type == 1) {
if (data$obs_dist == 1) {
out$rep_phi <- array(
truncnorm::rtruncnorm(
1,
Expand All @@ -560,7 +562,7 @@ create_initial_conditions <- function(data) {
if (data$seeding_time > 1) {
out$initial_growth <- array(rnorm(1, data$prior_growth, 0.01))
}
out$log_R <- array(rnorm(
out$base_cov <- rnorm(
n = 1, mean = convert_to_logmean(data$r_mean, data$r_sd),
sd = convert_to_logsd(data$r_mean, data$r_sd) * 0.1
))
Expand Down
11 changes: 11 additions & 0 deletions R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
#' @param reported_cases A data frame of confirmed cases (confirm) by date
#' (date). confirm must be integer and date must be in date format.
#'
#' @param process_model A character string that defines what is being
#' modelled: "infections", "growth" or "R" (default). If set to "R",
#' a generation time distribution needs to be defined via the `generation_time`
#' argument.
#'
#' @param generation_time A call to `generation_time_opts()` defining the
#' generation time distribution used. For backwards compatibility a list of
#' summary parameters can also be passed.
Expand Down Expand Up @@ -233,6 +238,7 @@
#' options(old_opts)
#' }
estimate_infections <- function(reported_cases,
process_opts = process_opts(),
generation_time = generation_time_opts(),
delays = delay_opts(),
truncation = trunc_opts(),
Expand Down Expand Up @@ -295,10 +301,15 @@ estimate_infections <- function(reported_cases,
)
reported_cases <- reported_cases[-(1:backcalc$prior_window)]

model_choices <- c("infections", "growth", "R")
model <- match.arg(model, choices = model_choices)
process_model <- which(model == model_choices) - 1

# Define stan model parameters
data <- create_stan_data(
reported_cases = reported_cases,
seeding_time = seeding_time,
process_opts = process_opts,
rt = rt,
gp = gp,
obs = obs,
Expand Down
4 changes: 2 additions & 2 deletions R/extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates,
out$growth_rate <- extract_parameter(
"r",
samples,
reported_dates[-1]
reported_dates
)
if (data$week_effect > 1) {
out$day_of_week <- extract_parameter(
Expand Down Expand Up @@ -168,7 +168,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates,
date := NULL
]
}
if (data$model_type == 1) {
if (data$obs_dist == 1) {
out$reporting_overdispersion <- extract_static_parameter("rep_phi", samples)
out$reporting_overdispersion <- out$reporting_overdispersion[,
value := value.V1][,
Expand Down
93 changes: 90 additions & 3 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ trunc_opts <- function(dist = dist_spec()) {

#' Time-Varying Reproduction Number Options
#'
#' @description `r lifecycle::badge("stable")`
#' @description `r lifecycle::badge("deprecated")`
#' Defines a list specifying the optional arguments for the time-varying
#' reproduction number. Custom settings can be supplied which override the
#' defaults.
Expand Down Expand Up @@ -158,6 +158,7 @@ trunc_opts <- function(dist = dist_spec()) {
#'
#' @return A list of settings defining the time-varying reproduction number.
#' @author Sam Abbott

#' @inheritParams create_future_rt
#' @export
#' @examples
Expand All @@ -176,6 +177,7 @@ rt_opts <- function(prior = list(mean = 1, sd = 1),
future = "latest",
gp_on = "R_t-1",
pop = 0) {
stop("rt_opts is deprecated - use process_opts instead")
rt <- list(
prior = prior,
use_rt = use_rt,
Expand All @@ -197,9 +199,93 @@ rt_opts <- function(prior = list(mean = 1, sd = 1),
return(rt)
}

#' Back Calculation Options
#' Process Model Options
#'
#' @description `r lifecycle::badge("stable")`
#' Defines a list specifying the optional arguments for the process model.
#' Custom settings can be supplied which override the defaults.
#'
#' @param model Character string, defaults to "R". The quantity being
#' modelled; one of "infections", "growth" or "R". It is stored in the
#' returned list as the zero-based integer code `process_model`
#' (0 = infections, 1 = growth, 2 = R).
#' @param prior_mean List containing named numeric elements "mean" and "sd",
#' or `NULL`. Defaults to a mean of 1 and standard deviation of 1 for
#' `model = "R"`, a mean of 0 and standard deviation of 1 for
#' `model = "growth"`, and `NULL` for `model = "infections"`. Exactly one of
#' `prior_mean` and `prior_t` has to be non-`NULL`.
#' @param prior_t Defaults to `NULL`. Alternative to `prior_mean`; stored
#' unchanged in the returned list. NOTE(review): presumably a time-varying
#' prior - confirm the expected format against the Stan data preparation.
#' @param rw Numeric step size of the random walk, defaults to 0. To specify a
#' weekly random walk set `rw = 7`. Any value greater than 0 forces
#' `use_breakpoints` to `TRUE`. For more custom break point settings consider
#' passing in a `breakpoints` variable as outlined in the next section.
#' @param use_breakpoints Logical, defaults to `TRUE`. Should break points be
#' used if present as a `breakpoint` variable in the input data. Break points
#' should be defined as 1 if present and otherwise 0. By default breakpoints
#' are fit jointly with a global non-parametric effect and so represent a
#' conservative estimate of break point changes (alter this by setting
#' `gp = NULL`).
#' @param stationary Logical, defaults to `FALSE`. Stored unchanged in the
#' returned list. NOTE(review): semantics are not visible in this function -
#' confirm against the model code.
#' @param pop Integer, defaults to 0. Susceptible population initially
#' present. Used to adjust Rt estimates when otherwise fixed based on the
#' proportion of the population that is susceptible. When set to 0 no
#' population adjustment is done.
#' @return A list of settings defining the process model, with elements
#' `process_model`, `prior_mean`, `prior_t`, `rw`, `use_breakpoints`,
#' `future`, `stationary` and `pop`.
#' @inheritParams create_future_rt
#' @export
#' @examples
#' # default settings
#' process_opts()
#'
#' # use a custom prior
#' process_opts(prior_mean = list(mean = 2, sd = 1))
#'
#' # add a weekly random walk
#' process_opts(rw = 7)
process_opts <- function(model = "R",
                         prior_mean = switch(model,
                           R = list(mean = 1, sd = 1),
                           growth = list(mean = 0, sd = 1),
                           infections = NULL
                         ),
                         prior_t = NULL,
                         rw = 0,
                         use_breakpoints = TRUE,
                         future = "latest",
                         stationary = FALSE,
                         pop = 0) {
  # Validate `model` before `prior_mean` is first touched: the default for
  # `prior_mean` is evaluated lazily and dispatches on the matched value.
  model_choices <- c("infections", "growth", "R")
  model <- match.arg(model, choices = model_choices)
  # zero-based code: 0 = infections, 1 = growth, 2 = R
  process_model <- which(model == model_choices) - 1

  # exactly one of the two prior specifications may be supplied
  if (!(xor(is.null(prior_mean), is.null(prior_t)))) {
    stop("Either 'prior_mean' or 'prior_t' must be set to NULL")
  }

  # check the full element name; `$prior` would partially match two elements
  if (!is.null(prior_mean) &&
    !all(c("mean", "sd") %in% names(prior_mean))) {
    stop("prior must have both a mean and sd specified")
  }

  process <- list(
    process_model = process_model,
    prior_mean = prior_mean,
    prior_t = prior_t,
    rw = rw,
    use_breakpoints = use_breakpoints,
    future = future,
    stationary = stationary,
    pop = pop
  )

  # a random walk is implemented through breakpoints, so force them on
  if (process$rw > 0) {
    process$use_breakpoints <- TRUE
  }

  return(process)
}

#' Back Calculation Options
#'
#' @description `r lifecycle::badge("deprecated")`
#' Defines a list specifying the optional arguments for the back calculation
#' of cases. Only used if `rt = NULL`.
#'
Expand Down Expand Up @@ -232,7 +318,8 @@ rt_opts <- function(prior = list(mean = 1, sd = 1),
#' # default settings
#' backcalc_opts()
backcalc_opts <- function(prior = "reports", prior_window = 14, rt_window = 1) {
backcalc <- list(
stop("backcalc_opts is deprecated - use process_opts instead")
backcalc <- list(
prior = match.arg(prior, choices = c("reports", "none", "infections")),
prior_window = prior_window,
rt_window = as.integer(rt_window)
Expand Down
1 change: 0 additions & 1 deletion inst/stan/data/backcalc.stan
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
int backcalc_prior; // Prior type to use for backcalculation
int rt_half_window; // Half the moving average window used when calculating Rt
7 changes: 7 additions & 0 deletions inst/stan/data/covariates.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
int process_model; // 0 = infections; 1 = growth; 2 = rt
int bp_n; // no of breakpoints (0 = no breakpoints)
int breakpoints[t - seeding_time]; // when do breakpoints occur
int cov_mean_const; // 0 = not const mean; 1 = const mean
real<lower = 0> cov_mean_mean[cov_mean_const]; // const covariate mean
real<lower = 0> cov_mean_sd[cov_mean_const]; // const covariate sd
vector<lower = 0>[cov_mean_const ? 0 : t] cov_t; // time-varying covariate mean
2 changes: 1 addition & 1 deletion inst/stan/data/observation_model.stan
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
int day_of_week[t - seeding_time]; // day of the week indicator (1 - 7)
int model_type; // type of model: 0 = poisson otherwise negative binomial
int obs_dist; // type of model: 0 = poisson otherwise negative binomial
real phi_mean; // Mean and sd of the normal prior for the
real phi_sd; // reporting process
int week_effect; // length of week effect
Expand Down
2 changes: 1 addition & 1 deletion inst/stan/data/observations.stan
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
int t; // unobserved time
int seeding_time; // time period used for seeding and not observed
int<lower = 1> seeding_time; // time period used for seeding and not observed
int horizon; // forecast horizon
int future_time; // time in future for Rt
int<lower = 0> cases[t - horizon - seeding_time]; // observed cases
Expand Down
4 changes: 0 additions & 4 deletions inst/stan/data/rt.stan
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
int estimate_r; // should the reproduction no be estimated (1 = yes)
real prior_infections; // prior for initial infections
real prior_growth; // prior on initial growth rate
real <lower = 0> r_mean; // prior mean of reproduction number
real <lower = 0> r_sd; // prior standard deviation of reproduction number
int bp_n; // no of breakpoints (0 = no breakpoints)
int breakpoints[t - seeding_time]; // when do breakpoints occur
int future_fixed; // is underlying future Rt assumed to be fixed
int fixed_from; // Reference date for when Rt estimation should be fixed
int pop; // Initial susceptible population
Expand Down
2 changes: 1 addition & 1 deletion inst/stan/data/simulation_observation_model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
real<lower = 0> day_of_week_simplex[n, week_effect];
int obs_scale;
real<lower = 0, upper = 1> frac_obs[n, obs_scale];
int model_type;
int obs_dist;
real<lower = 0> rep_phi[n, model_type]; // overdispersion of the reporting process
int<lower = 0> trunc_id; // id of truncation
4 changes: 2 additions & 2 deletions inst/stan/data/simulation_rt.stan
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
real initial_infections[seeding_time ? n : 0, 1]; // initial logged infections
real initial_growth[seeding_time > 1 ? n : 0, 1]; //initial growth

matrix[n, t - seeding_time] R; // reproduction number
int gt_dist[1]; // 0 = lognormal; 1 = gamma
vector[n] R[t - seeding_time]; // reproduction number
int pop; // susceptible population

int<lower = 0> gt_id; // id of generation time
Loading

0 comments on commit 550b156

Please sign in to comment.