diff --git a/NAMESPACE b/NAMESPACE index b477eeea6..fe2bb0cec 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -69,6 +69,7 @@ export(obs_opts) export(opts_list) export(plot_estimates) export(plot_summary) +export(process_opts) export(regional_epinow) export(regional_runtimes) export(regional_summary) diff --git a/R/create.R b/R/create.R index f02bfdc59..8f6b1afd3 100644 --- a/R/create.R +++ b/R/create.R @@ -483,7 +483,8 @@ create_obs_model <- function(obs = obs_opts(), dates) { #' } create_stan_data <- function(data, seeding_time, rt, gp, obs, horizon, - backcalc, shifted_cases) { + backcalc, shifted_cases, + process_model) { cases <- data[(seeding_time + 1):(.N - horizon)] complete_cases <- create_complete_cases(cases) @@ -497,7 +498,8 @@ create_stan_data <- function(data, seeding_time, t = length(data$date), horizon = horizon, burn_in = 0, - seeding_time = seeding_time + seeding_time = seeding_time, + process_model = process_model ) # add Rt data stan_data <- c( @@ -610,7 +612,7 @@ create_initial_conditions <- function(data) { out$rho <- array(numeric(0)) out$alpha <- array(numeric(0)) } - if (data$model_type == 1) { + if (data$obs_dist == 1) { out$rep_phi <- array( truncnorm::rtruncnorm( 1, @@ -623,7 +625,7 @@ create_initial_conditions <- function(data) { if (data$seeding_time > 1) { out$initial_growth <- array(rnorm(1, data$prior_growth, 0.01)) } - out$log_R <- array(rnorm( + out$base_cov <- array(rnorm( n = 1, mean = convert_to_logmean(data$r_mean, data$r_sd), sd = convert_to_logsd(data$r_mean, data$r_sd) * 0.1 )) diff --git a/R/depreciated.R b/R/depreciated.R new file mode 100644 index 000000000..d31e77f4d --- /dev/null +++ b/R/depreciated.R @@ -0,0 +1,130 @@ +#' Back Calculation Options +#' +#' @description `r lifecycle::badge("deprecated")` +#' Defines a list specifying the optional arguments for the back calculation +#' of cases. Only used if `rt = NULL`. +#' +#' @param prior A character string defaulting to "reports". 
Defines the prior +#' to use when deconvolving. Currently implemented options are to use smoothed +#' mean delay shifted reported cases ("reports"), to use the estimated +#' infections from the previous time step seeded for the first time step using +#' mean shifted reported cases ("infections"), or no prior ("none"). Using no +#' prior will result in poor real time performance. No prior and using +#' infections are only supported when a Gaussian process is present. If +#' observed data is not reliable then a sensible first step is to explore +#' increasing the `prior_window` with a sensible second step being to no longer +#' use reported cases as a prior (i.e. set `prior = "none"`). +#' +#' @param prior_window Integer, defaults to 14 days. The mean centred smoothing +#' window to apply to mean shifted reports (used as a prior during back +#' calculation). 7 days is the minimum recommended setting as this smooths day +#' of the week effects, but the best value depends on the quality of the data +#' and the amount of information users wish to use as a prior (higher values +#' equalling a less informative prior). +#' +#' @param rt_window Integer, defaults to 1. The size of the centred rolling +#' average to use when estimating Rt. This must be odd so that the central +#' estimate is included. +#' +#' @return A list of back calculation settings. 
+#' @author Sam Abbott +#' @export +#' @examples +#' # default settings +#' backcalc_opts() +backcalc_opts <- function(prior = "reports", prior_window = 14, rt_window = 1) { + stop("backcalc_opts is deprecated - use process_opts instead") + backcalc <- list( + prior = match.arg(prior, choices = c("reports", "none", "infections")), + prior_window = prior_window, + rt_window = as.integer(rt_window) + ) + if (backcalc$rt_window %% 2 == 0) { + stop( + "Rt rolling average window must be odd in order to include the current + estimate" + ) + } + return(backcalc) +} + +#' Time-Varying Reproduction Number Options +#' +#' @description `r lifecycle::badge("deprecated")` +#' Defines a list specifying the optional arguments for the time-varying +#' reproduction number. Custom settings can be supplied which override the +#' defaults. +#' +#' @param prior List containing named numeric elements "mean" and "sd". The +#' mean and standard deviation of the log normal Rt prior. Defaults to mean of +#' 1 and standard deviation of 1. +#' +#' @param use_rt Logical, defaults to `TRUE`. Should Rt be used to generate +#' infections and hence reported cases. +#' +#' @param rw Numeric step size of the random walk, defaults to 0. To specify a +#' weekly random walk set `rw = 7`. For more custom break point settings +#' consider passing in a `breakpoints` variable as outlined in the next section. +#' +#' @param use_breakpoints Logical, defaults to `TRUE`. Should break points be +#' used if present as a `breakpoint` variable in the input data. Break points +#' should be defined as 1 if present and otherwise 0. By default breakpoints +#' are fit jointly with a global non-parametric effect and so represent a +#' conservative estimate of break point changes (alter this by setting +#' `gp = NULL`). +#' +#' @param pop Integer, defaults to 0. Susceptible population initially present. +#' Used to adjust Rt estimates when otherwise fixed based on the proportion of +#' the population that is susceptible. 
When set to 0 no population adjustment
+#' is done.
+#'
+#' @param gp_on Character string, defaulting to "R_t-1". Indicates how the
+#' Gaussian process, if in use, should be applied to Rt. Currently supported
+#' options are applying the Gaussian process to the last estimated Rt (i.e.
+#' Rt = Rt-1 * GP), and applying the Gaussian process to a global mean (i.e. Rt
+#' = R0 * GP). Both should produce comparable results when data is not sparse
+#' but the method relying on a global mean will revert to this for real time
+#' estimates, which may not be desirable.
+#'
+#' @return A list of settings defining the time-varying reproduction number.
+#' @author Sam Abbott
+#'
+#' @inheritParams create_future_rt
+#' @export
+#' @examples
+#' # rt_opts() is deprecated and always errors, so calls are wrapped in try()
+#' try(rt_opts())
+#'
+#' # add a custom prior
+#' try(rt_opts(prior = list(mean = 2, sd = 1)))
+#'
+#' # add a weekly random walk
+#' try(rt_opts(rw = 7))
+rt_opts <- function(prior = list(mean = 1, sd = 1),
+                    use_rt = TRUE,
+                    rw = 0,
+                    use_breakpoints = TRUE,
+                    future = "latest",
+                    gp_on = "R_t-1",
+                    pop = 0) {
+  stop("rt_opts is deprecated - use process_opts instead")
+  rt <- list(
+    prior = prior,
+    use_rt = use_rt,
+    rw = rw,
+    use_breakpoints = use_breakpoints,
+    future = future,
+    pop = pop,
+    gp_on = match.arg(gp_on, choices = c("R_t-1", "R0"))
+  )
+
+  # replace default settings with those specified by user
+  if (rt$rw > 0) {
+    rt$use_breakpoints <- TRUE
+  }
+
+  if (!("mean" %in% names(rt$prior) && "sd" %in% names(rt$prior))) {
+    stop("prior must have both a mean and sd specified")
+  }
+  return(rt)
+}
diff --git a/R/estimate_infections.R b/R/estimate_infections.R
index b5e9f6f6d..282dbc9d9 100644
--- a/R/estimate_infections.R
+++ b/R/estimate_infections.R
@@ -22,6 +22,11 @@
 #'
 #' @param reported_cases Deprecated; use `data` instead.
 #'
+#' @param process_model A character string that defines what is being
+#' modelled: "infections", "growth" or "R" (default). 
If set to "R", +#' a generation time distribution needs to be defined via the `generation_time` +#' argument. +#' #' @param generation_time A call to [generation_time_opts()] defining the #' generation time distribution used. For backwards compatibility a list of #' summary parameters can also be passed. @@ -111,6 +116,7 @@ #' options(old_opts) #' } estimate_infections <- function(data, + process_opts = NULL, generation_time = generation_time_opts(), delays = delay_opts(), truncation = trunc_opts(), @@ -208,10 +214,15 @@ estimate_infections <- function(data, ) reported_cases <- reported_cases[-(1:backcalc$prior_window)] + # NULL selects the defaults (a `process_opts = process_opts()` default would be a recursive reference) + if (is.null(process_opts)) process_opts <- process_opts() + process_model <- process_opts$process_model + # Define stan model parameters stan_data <- create_stan_data( reported_cases, seeding_time = seeding_time, + process_model = process_model, rt = rt, gp = gp, obs = obs, diff --git a/R/extract.R b/R/extract.R index 331c100cf..32ab818e1 100644 --- a/R/extract.R +++ b/R/extract.R @@ -211,7 +211,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates, out$growth_rate <- extract_parameter( "r", samples, - reported_dates[-1] + reported_dates ) if (data$week_effect > 1) { out$day_of_week <- extract_parameter( @@ -233,7 +233,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates, date := NULL ] } - if (data$model_type == 1) { + if (data$obs_dist == 1) { out$reporting_overdispersion <- extract_static_parameter("rep_phi", samples) out$reporting_overdispersion <- out$reporting_overdispersion[, value := value.V1][, diff --git a/R/opts.R b/R/opts.R index 3d212614d..4648fc618 100644 --- a/R/opts.R +++ b/R/opts.R @@ -302,7 +302,7 @@ trunc_opts <- function(dist = Fixed(0), tolerance = 0.001, #' Time-Varying Reproduction Number Options #' -#' @description `r lifecycle::badge("stable")` +#' @description `r lifecycle::badge("deprecated")` #' Defines a list specifying 
the optional arguments for the time-varying #' reproduction number. Custom settings can be supplied which override the #' defaults. @@ -359,6 +359,7 @@ rt_opts <- function(prior = list(mean = 1, sd = 1), future = "latest", gp_on = c("R_t-1", "R0"), pop = 0) { + stop("rt_opts is deprecated - use process_opts instead") rt <- list( prior = prior, use_rt = use_rt, @@ -381,9 +382,93 @@ rt_opts <- function(prior = list(mean = 1, sd = 1), return(rt) } -#' Back Calculation Options +#' Process model options #' #' @description `r lifecycle::badge("stable")` +#' Defines a list specifying the optional arguments for the process model. +#' Custom settings can be supplied which override the defaults. +#' @param prior List containing named numeric elements "mean" and "sd". The mean and +#' standard deviation of the log normal Rt prior. Defaults to mean of 1 and standard +#' deviation of 1. +#' @param use_rt Logical, defaults to `TRUE`. Should Rt be used to generate infections +#' and hence reported cases. +#' @param rw Numeric step size of the random walk, defaults to 0. To specify a weekly random +#' walk set `rw = 7`. For more custom break point settings consider passing in a `breakpoints` +#' variable as outlined in the next section. +#' @param use_breakpoints Logical, defaults to `TRUE`. Should break points be used if present +#' as a `breakpoint` variable in the input data. Break points should be defined as 1 if present +#' and otherwise 0. By default breakpoints are fit jointly with a global non-parametric effect +#' and so represent a conservative estimate of break point changes (alter this by setting `gp = NULL`). +#' @param pop Integer, defaults to 0. Susceptible population initially present. Used to adjust +#' Rt estimates when otherwise fixed based on the proportion of the population that is +#' susceptible. When set to 0 no population adjustment is done. +#' @param gp_on Character string, defaulting to "R_t-1". 
Indicates how the Gaussian process,
+#' if in use, should be applied to Rt. Currently supported options are applying the Gaussian
+#' process to the last estimated Rt (i.e. Rt = Rt-1 * GP), and applying the Gaussian process to
+#' a global mean (i.e. Rt = R0 * GP). Both should produce comparable results when data is not
+#' sparse but the method relying on a global mean will revert to this for real time estimates,
+#' which may not be desirable.
+#' @param model Character string: "infections", "growth" or "R" (default). Stored as `process_model` (0, 1, 2).
+#' @param prior_mean List with numeric elements "mean" and "sd", or `NULL` (the "infections" default).
+#' @param prior_t Alternative to `prior_mean`; exactly one of the two must be `NULL`. Defaults to `NULL`.
+#' @param stationary Logical, defaults to `FALSE`. Stored unchanged in the returned list.
+#' @return A list of settings defining the process model.
+#' @inheritParams create_future_rt
+#' @export
+#' @examples
+#' # default settings
+#' process_opts()
+#' # use a custom prior
+#' process_opts(prior_mean = list(mean = 2, sd = 1))
+#' # add a weekly random walk
+#' process_opts(rw = 7)
+process_opts <- function(model = "R",
+                         prior_mean = switch(model,
+                           R = list(mean = 1, sd = 1),
+                           growth = list(mean = 0, sd = 1),
+                           infections = NULL
+                         ),
+                         prior_t = NULL,
+                         rw = 0,
+                         use_breakpoints = TRUE,
+                         future = "latest",
+                         stationary = FALSE,
+                         pop = 0) {
+  ## validate the model first: the lazy `prior_mean` default depends on it
+  model_choices <- c("infections", "growth", "R")
+  model <- match.arg(model, choices = model_choices)
+  process_model <- which(model == model_choices) - 1
+
+  if (!(xor(is.null(prior_mean), is.null(prior_t)))) {
+    stop("Exactly one of 'prior_mean' and 'prior_t' must be NULL")
+  }
+  process <- list(
+    process_model = process_model,
+    prior_mean = prior_mean,
+    prior_t = prior_t,
+    rw = rw,
+    use_breakpoints = use_breakpoints,
+    future = future,
+    stationary = stationary,
+    pop = pop
+  )
+
+  # a positive random walk step size always switches breakpoints on
+  if (process$rw > 0) {
+    process$use_breakpoints <- TRUE
+  }
+
+  if (!is.null(prior_mean) &&
+    !("mean" %in% names(process$prior_mean) &&
+      "sd" %in% names(process$prior_mean))) {
+    stop("prior must have both a mean and sd specified")
+  }
+  return(process)
+}
+
+#' Back Calculation Options
+#'
+#' @description `r 
lifecycle::badge("deprecated")` #' Defines a list specifying the optional arguments for the back calculation #' of cases. Only used if `rt = NULL`. #' @@ -417,6 +502,7 @@ rt_opts <- function(prior = list(mean = 1, sd = 1), #' backcalc_opts() backcalc_opts <- function(prior = c("reports", "none", "infections"), prior_window = 14, rt_window = 1) { + stop("backcalc_opts is deprecated - use process_opts instead") backcalc <- list( prior = arg_match(prior), prior_window = prior_window, diff --git a/inst/stan/data/backcalc.stan b/inst/stan/data/backcalc.stan index ab7a42286..5fecb07fa 100644 --- a/inst/stan/data/backcalc.stan +++ b/inst/stan/data/backcalc.stan @@ -1,2 +1 @@ int backcalc_prior; // Prior type to use for backcalculation - int rt_half_window; // Half the moving average window used when calculating Rt diff --git a/inst/stan/data/covariates.stan b/inst/stan/data/covariates.stan new file mode 100644 index 000000000..56f0b1a1a --- /dev/null +++ b/inst/stan/data/covariates.stan @@ -0,0 +1,7 @@ +int process_model; // 0 = infections; 1 = growth; 2 = rt +int bp_n; // no of breakpoints (0 = no breakpoints) +int breakpoints[t - seeding_time]; // when do breakpoints occur +int cov_mean_const; // 0 = not const mean; 1 = const mean +real cov_mean_mean[cov_mean_const]; // const covariate mean +real cov_mean_sd[cov_mean_const]; // const covariate sd +vector[cov_mean_const ? 
0 : t] cov_t; // time-varying covariate mean diff --git a/inst/stan/data/observation_model.stan b/inst/stan/data/observation_model.stan index 671004ef4..40e6c9a84 100644 --- a/inst/stan/data/observation_model.stan +++ b/inst/stan/data/observation_model.stan @@ -1,5 +1,5 @@ array[t - seeding_time] int day_of_week; // day of the week indicator (1 - 7) - int model_type; // type of model: 0 = poisson otherwise negative binomial + int obs_dist; // type of model: 0 = poisson otherwise negative binomial real phi_mean; // Mean and sd of the normal prior for the real phi_sd; // reporting process int week_effect; // length of week effect diff --git a/inst/stan/data/observations.stan b/inst/stan/data/observations.stan index 11fe8463c..6b1d2aa22 100644 --- a/inst/stan/data/observations.stan +++ b/inst/stan/data/observations.stan @@ -1,6 +1,6 @@ int t; // unobserved time int lt; // timepoints in the likelihood - int seeding_time; // time period used for seeding and not observed + int seeding_time; // time period used for seeding and not observed int horizon; // forecast horizon int future_time; // time in future for Rt array[lt] int cases; // observed cases diff --git a/inst/stan/data/rt.stan b/inst/stan/data/rt.stan index 11b1989ae..7cf1153f6 100644 --- a/inst/stan/data/rt.stan +++ b/inst/stan/data/rt.stan @@ -1,10 +1,6 @@ int estimate_r; // should the reproduction no be estimated (1 = yes) real prior_infections; // prior for initial infections real prior_growth; // prior on initial growth rate - real r_mean; // prior mean of reproduction number - real r_sd; // prior standard deviation of reproduction number - int bp_n; // no of breakpoints (0 = no breakpoints) - array[t - seeding_time] int breakpoints; // when do breakpoints occur int future_fixed; // is underlying future Rt assumed to be fixed int fixed_from; // Reference date for when Rt estimation should be fixed int pop; // Initial susceptible population diff --git a/inst/stan/data/simulation_observation_model.stan 
b/inst/stan/data/simulation_observation_model.stan index c8cab6b35..c85b4b470 100644 --- a/inst/stan/data/simulation_observation_model.stan +++ b/inst/stan/data/simulation_observation_model.stan @@ -3,6 +3,6 @@ array[n, week_effect] real day_of_week_simplex; int obs_scale; array[n, obs_scale] real frac_obs; - int model_type; - array[n, model_type] real rep_phi; // overdispersion of the reporting process + int obs_dist; + array[n, obs_dist] real rep_phi; // overdispersion of the reporting process int trunc_id; // id of truncation diff --git a/inst/stan/data/simulation_rt.stan b/inst/stan/data/simulation_rt.stan index 3beae3161..f53ba0a4d 100644 --- a/inst/stan/data/simulation_rt.stan +++ b/inst/stan/data/simulation_rt.stan @@ -1,4 +1,4 @@ - array[n, 1] real initial_infections; // initial logged infections + array[n, seeding_time ? 1 : 0] real initial_infections; // initial logged infections array[n, seeding_time > 1 ? 1 : 0] real initial_growth; //initial growth matrix[n, t - seeding_time] R; // reproduction number diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index 4bd1b8fb2..1cc52e17a 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -3,7 +3,7 @@ functions { #include functions/pmfs.stan #include functions/delays.stan #include functions/gaussian_process.stan -#include functions/rt.stan +#include functions/covariates.stan #include functions/infections.stan #include functions/observation_model.stan #include functions/generated_quantities.stan @@ -13,6 +13,7 @@ functions { data { #include data/observations.stan #include data/delays.stan +#include data/covariates.stan #include data/gaussian_process.stan #include data/rt.stan #include data/backcalc.stan @@ -26,9 +27,14 @@ transformed data{ // gaussian process int noise_terms = setup_noise(ot_h, t, horizon, estimate_r, stationary, future_fixed, fixed_from); matrix[noise_terms, M] PHI = setup_gp(M, L, noise_terms); // basis function - // Rt - 
real r_logmean = log(r_mean^2 / sqrt(r_sd^2 + r_mean^2)); - real r_logsd = sqrt(log(1 + (r_sd^2 / r_mean^2))); + // covariate mean + real cov_mean_logmean[cov_mean_const]; + real cov_mean_logsd[cov_mean_const]; + + if (cov_mean_const) { + cov_mean_logmean[1] = log(cov_mean_mean[1]^2 / sqrt(cov_mean_sd[1]^2 + cov_mean_mean[1]^2)); + cov_mean_logsd[1] = sqrt(log(1 + (cov_mean_sd[1]^2 / cov_mean_mean[1]^2))); + } array[delay_types] int delay_type_max; profile("assign max") { @@ -45,7 +51,7 @@ parameters{ array[fixed ? 0 : 1] real alpha; // scale of of noise GP vector[fixed ? 0 : M] eta; // unconstrained noise // Rt - vector[estimate_r] log_R; // baseline reproduction number estimate (log) + vector[cov_mean_const] log_cov_mean; // covariate (R/r) array[estimate_r] real initial_infections ; // seed infections array[estimate_r && seeding_time > 1 ? 1 : 0] real initial_growth; // seed growth rate array[bp_n > 0 ? 1 : 0] real bp_sd; // standard deviation of breakpoint effect @@ -55,13 +61,14 @@ parameters{ vector[delay_params_length] delay_params; // delay parameters simplex[week_effect] day_of_week_simplex;// day of week reporting effect array[obs_scale_sd > 0 ? 1 : 0] real frac_obs; // fraction of cases that are ultimately observed - array[model_type] real rep_phi; // overdispersion of the reporting process + array[obs_dist] real rep_phi; // overdispersion of the reporting process } transformed parameters { vector[fixed ? 0 : noise_terms] noise; // noise generated by the gaussian process - vector[estimate_r > 0 ? 
ot_h : 0] R; // reproduction number + vector[seeding_time] uobs_inf; vector[t] infections; // latent infections + vector[ot_h] cov; // covariates vector[ot_h] reports; // estimated reported cases vector[ot] obs_reports; // observed estimated reported cases vector[estimate_r * (delay_type_max[gt_id] + 1)] gt_rev_pmf; @@ -71,8 +78,23 @@ transformed parameters { noise = update_gp(PHI, M, L, alpha[1], rho[1], eta, gp_type); } } + // update covariates + cov = update_covariate( + log_cov_mean, cov_t, noise, breakpoints, bp_effects, + stationary, ot_h + ); + uobs_inf = generate_seed(initial_infections, initial_growth, seeding_time); // Estimate latent infections - if (estimate_r) { + if (process_model == 0) { + // via deconvolution + profile("infections") { + infections = infection_model(cov, uobs_inf, future_time); + } + } else if (process_model == 1) { + // via growth + infections = growth_model(cov, uobs_inf, future_time); + } else if (process_model == 2) { + // via Rt profile("gt") { gt_rev_pmf = get_delay_rev_pmf( gt_id, delay_type_max[gt_id] + 1, delay_types_p, delay_types_id, @@ -81,22 +103,9 @@ transformed parameters { 1, 1, 0 ); } - profile("R") { - R = update_Rt( - ot_h, log_R[estimate_r], noise, breakpoints, bp_effects, stationary - ); - } - profile("infections") { - infections = generate_infections( - R, seeding_time, gt_rev_pmf, initial_infections, initial_growth, pop, - future_time - ); - } - } else { - // via deconvolution profile("infections") { - infections = deconvolve_infections( - shifted_cases, noise, fixed, backcalc_prior + infections = renewal_model( + cov, uobs_inf, gt_rev_pmf, pop, future_time ); } } @@ -166,14 +175,16 @@ model { delay_dist, delay_weight ); } - if (estimate_r) { - // priors on Rt - profile("rt lp") { - rt_lp( - log_R, initial_infections, initial_growth, bp_effects, bp_sd, bp_n, - seeding_time, r_logmean, r_logsd, prior_infections, prior_growth - ); - } + profile("covariate lp") { + covariate_lp( + log_cov_mean, bp_effects, bp_sd, 
bp_n, cov_mean_logmean, cov_mean_logsd + ); + } + profile("infections lp") { + infections_lp( + initial_infections, initial_growth, prior_infections, prior_growth, + seeding_time + ); } // prior observation scaling if (obs_scale_sd > 0) { @@ -185,7 +196,7 @@ model { if (likelihood) { profile("report lp") { report_lp( - cases, cases_time, obs_reports, rep_phi, phi_mean, phi_sd, model_type, + cases, cases_time, obs_reports, rep_phi, phi_mean, phi_sd, obs_dist, obs_weight, accumulate ); } @@ -194,11 +205,11 @@ model { generated quantities { array[ot_h] int imputed_reports; - vector[estimate_r > 0 ? 0: ot_h] gen_R; + vector[estimate_r > 0 ? 0: ot_h] R; vector[ot_h - 1] r; vector[return_likelihood ? ot : 0] log_lik; profile("generated quantities") { - if (estimate_r == 0){ + if (estimate_r == 0 && process_model != 2) { // sample generation time vector[delay_params_length] delay_params_sample = to_vector(normal_lb_rng( delay_params_mean, delay_params_sd, delay_params_lower @@ -210,18 +221,22 @@ generated quantities { delay_dist, 1, 1, 0 ); // calculate Rt using infections and generation time - gen_R = calculate_Rt( - infections, seeding_time, sampled_gt_rev_pmf, rt_half_window - ); + R = calculate_Rt(infections, seeding_time, sampled_gt_rev_pmf); + } else { + R = cov; } // estimate growth from infections - r = calculate_growth(infections, seeding_time + 1); + if (process_model != 1) { + r = calculate_growth(infections, seeding_time); + } else { + r = cov; + } // simulate reported cases - imputed_reports = report_rng(reports, rep_phi, model_type); + imputed_reports = report_rng(reports, rep_phi, obs_dist); // log likelihood of model if (return_likelihood) { log_lik = report_log_lik( - cases, obs_reports[cases_time], rep_phi, model_type, obs_weight + cases, obs_reports[cases_time], rep_phi, obs_dist, obs_weight ); } } diff --git a/inst/stan/estimate_secondary.stan b/inst/stan/estimate_secondary.stan index 70fcc8d4a..6f777e2bd 100644 --- a/inst/stan/estimate_secondary.stan 
+++ b/inst/stan/estimate_secondary.stan @@ -30,7 +30,7 @@ parameters{ vector[delay_params_length] delay_params; simplex[week_effect] day_of_week_simplex; // day of week reporting effect array[obs_scale] real frac_obs; // fraction of cases that are ultimately observed - array[model_type] real rep_phi; // overdispersion of the reporting process + array[obs_dist] real rep_phi; // overdispersion of the reporting process } transformed parameters { @@ -97,7 +97,7 @@ model { if (likelihood) { report_lp( obs[(burn_in + 1):t][obs_time], obs_time, secondary[(burn_in + 1):t], - rep_phi, phi_mean, phi_sd, model_type, 1, accumulate + rep_phi, phi_mean, phi_sd, obs_dist, 1, accumulate ); } } @@ -106,10 +106,10 @@ generated quantities { array[t - burn_in] int sim_secondary; vector[return_likelihood > 1 ? t - burn_in : 0] log_lik; // simulate secondary reports - sim_secondary = report_rng(secondary[(burn_in + 1):t], rep_phi, model_type); + sim_secondary = report_rng(secondary[(burn_in + 1):t], rep_phi, obs_dist); // log likelihood of model if (return_likelihood) { log_lik = report_log_lik(obs[(burn_in + 1):t], secondary[(burn_in + 1):t], - rep_phi, model_type, obs_weight); + rep_phi, obs_dist, obs_weight); } } diff --git a/inst/stan/functions/rt.stan b/inst/stan/functions/covariates.stan similarity index 51% rename from inst/stan/functions/rt.stan rename to inst/stan/functions/covariates.stan index a2ba59e25..ecc19dd5a 100644 --- a/inst/stan/functions/rt.stan +++ b/inst/stan/functions/covariates.stan @@ -1,6 +1,7 @@ -// update a vector of Rts -vector update_Rt(int t, real log_R, vector noise, array[] int bps, - array[] real bp_effects, int stationary) { +// update combined covariates +vector update_covariate(array[] real log_cov_mean, vector cov_t, + vector noise, array[] int bps, + array[] real bp_effects, int stationary, int t) { // define control parameters int bp_n = num_elements(bp_effects); int bp_c = 0; @@ -8,7 +9,7 @@ vector update_Rt(int t, real log_R, vector noise, 
array[] int bps, // define result vectors vector[t] bp = rep_vector(0, t); vector[t] gp = rep_vector(0, t); - vector[t] R; + vector[t] cov; // initialise breakpoints if (bp_n) { for (s in 1:t) { @@ -32,26 +33,25 @@ vector update_Rt(int t, real log_R, vector noise, array[] int bps, gp = cumulative_sum(gp); } } - // Calculate Rt - R = rep_vector(log_R, t) + bp + gp; - R = exp(R); - return(R); + if (num_elements(log_cov_mean) > 0) { + cov = rep_vector(log_cov_mean[1], t); + } else { + cov = log(cov_t); + } + // Calculate combined covariates + cov = cov + bp + gp; + return(cov); } -// Rt priors -void rt_lp(vector log_R, array[] real initial_infections, array[] real initial_growth, - array[] real bp_effects, array[] real bp_sd, int bp_n, int seeding_time, - real r_logmean, real r_logsd, real prior_infections, - real prior_growth) { - // prior on R - log_R ~ normal(r_logmean, r_logsd); - //breakpoint effects on Rt +void covariate_lp(real[] log_cov_mean, + real[] bp_effects, real[] bp_sd, int bp_n, + real[] cov_mean_logmean, real[] cov_mean_logsd) { + // initial prior + if (num_elements(log_cov_mean) > 0) { + log_cov_mean ~ normal(cov_mean_logmean[1], cov_mean_logsd[1]); + } + // breakpoint effects if (bp_n > 0) { bp_sd[1] ~ normal(0, 0.1) T[0,]; bp_effects ~ normal(0, bp_sd[1]); } - // initial infections - initial_infections ~ normal(prior_infections, 0.2); - if (seeding_time > 1) { - initial_growth ~ normal(prior_growth, 0.2); - } } diff --git a/inst/stan/functions/generated_quantities.stan b/inst/stan/functions/generated_quantities.stan index d418a9d7b..77c447646 100644 --- a/inst/stan/functions/generated_quantities.stan +++ b/inst/stan/functions/generated_quantities.stan @@ -1,6 +1,6 @@ // calculate Rt directly from inferred infections vector calculate_Rt(vector infections, int seeding_time, - vector gt_rev_pmf, int smooth) { + vector gt_rev_pmf) { int t = num_elements(infections); int ot = t - seeding_time; vector[ot] R; @@ -13,27 +13,15 @@ vector calculate_Rt(vector 
infections, int seeding_time, ); R[s] = infections[s + seeding_time] / infectiousness[s]; } - if (smooth) { - for (s in 1:ot) { - real window = 0; - sR[s] = 0; - for (i in max(1, s - smooth):min(ot, s + smooth)) { - sR[s] += R[i]; - window += 1; - } - sR[s] = sR[s] / window; - } - }else{ - sR = R; - } - return(sR); + return(R); } // Calculate growth rate vector calculate_growth(vector infections, int seeding_time) { int t = num_elements(infections); - int ot = t - seeding_time; + int ot = t - seeding_time - 1; vector[t] log_inf = log(infections); - vector[ot] growth = log_inf[(seeding_time + 1):t] - log_inf[seeding_time:(t - 1)]; + vector[ot] growth = + log_inf[(seeding_time + 2):t] - log_inf[(seeding_time + 1):(t - 1)]; return(growth); } diff --git a/inst/stan/functions/infections.stan b/inst/stan/functions/infections.stan index b7790c582..484adcf1d 100644 --- a/inst/stan/functions/infections.stan +++ b/inst/stan/functions/infections.stan @@ -17,27 +17,32 @@ real update_infectiousness(vector infections, vector gt_rev_pmf, ); return(new_inf); } -// generate infections by using Rt = Rt-1 * sum(reversed generation time pmf * infections) -vector generate_infections(vector oR, int uot, vector gt_rev_pmf, - array[] real initial_infections, array[] real initial_growth, - int pop, int ht) { +// generate seed infections (an empty vector when uot is 0) +vector generate_seed(array[] real initial_infections, array[] real initial_growth, int uot) { + vector[uot] seed_infs; + if (uot > 0) { + seed_infs[1] = exp(initial_infections[1]); + for (s in 2:uot) { + seed_infs[s] = exp(initial_infections[1] + initial_growth[1] * (s - 1)); + } + } + return(seed_infs); +} +// generate infections using infectiousness; r is the reproduction number on the log scale +vector renewal_model(vector r, vector uobs_inf, vector gt_rev_pmf, + int pop, int ht) { // time indices and storage - int ot = num_elements(oR); + int ot = num_elements(r); + int uot = num_elements(uobs_inf); int nht = ot - ht; int t = ot + uot; - vector[ot] R = oR; + vector[ot] R = exp(r); real exp_adj_Rt; - 
vector[t] infections = rep_vector(0, t); + vector[t] infections; vector[ot] cum_infections; vector[ot] infectiousness; - // Initialise infections using daily growth - infections[1] = exp(initial_infections[1]); - if (uot > 1) { - real growth = exp(initial_growth[1]); - for (s in 2:uot) { - infections[s] = infections[s - 1] * growth; - } - } + // Initialise infections + infections[1:uot] = uobs_inf; // calculate cumulative infections if (pop) { cum_infections[1] = sum(infections[1:uot]); @@ -58,25 +63,40 @@ vector generate_infections(vector oR, int uot, vector gt_rev_pmf, } return(infections); } -// backcalculate infections using mean shifted cases and non-parametric noise -vector deconvolve_infections(vector shifted_cases, vector noise, int fixed, - int prior) { - int t = num_elements(shifted_cases); - vector[t] infections = rep_vector(1e-5, t); - if(!fixed) { - vector[t] exp_noise = exp(noise); - if (prior == 1) { - infections = infections + shifted_cases .* exp_noise; - }else if (prior == 0) { - infections = infections + exp_noise; - }else if (prior == 2) { - infections[1] = infections[1] + shifted_cases[1] * exp_noise[1]; - for (i in 2:t) { - infections[i] = infections[i - 1] * exp_noise[i]; - } - } - }else{ - infections = infections + shifted_cases; - } + +// update infections using a growth model (linear,log, or non-parametric growth) +vector growth_model(vector r, vector uobs_inf, int ht) { + // time indices and storage + int ot = num_elements(r); + int uot = num_elements(uobs_inf); + int nht = ot - ht; + int t = ot + uot; + vector[t] infections; + // Update observed infections + infections[1:uot] = uobs_inf; + infections[(uot + 1):t] = exp(log(uobs_inf[uot]) + cumulative_sum(r)); + return(infections); +} + +// update infections using a growth model (linear,log, or non-parametric growth) +vector infection_model(vector cov, vector uobs_inf, int ht) { + // time indices and storage + int ot = num_elements(cov); + int uot = num_elements(uobs_inf); + int nht = ot 
- ht; + int t = ot + uot; + vector[t] infections; + infections[1:uot] = uobs_inf; + infections[(uot + 1):t] = exp(cov); return(infections); } + +void infections_lp(real[] initial_infections, real[] initial_growth, + real prior_infections, real prior_growth, + int seeding_time) { + // initial infections + initial_infections ~ normal(prior_infections, 0.2); + if (seeding_time > 1) { + initial_growth ~ normal(prior_growth, 0.2); + } +} diff --git a/inst/stan/functions/observation_model.stan b/inst/stan/functions/observation_model.stan index ac54496c7..1778365a1 100644 --- a/inst/stan/functions/observation_model.stan +++ b/inst/stan/functions/observation_model.stan @@ -51,35 +51,11 @@ void truncation_lp(array[] real truncation_mean, array[] real truncation_sd, } } // update log density for reported cases -void report_lp(array[] int cases, array[] int cases_time, vector reports, - array[] real rep_phi, real phi_mean, real phi_sd, - int model_type, real weight, int accumulate) { - int n = num_elements(cases_time) - accumulate; // number of observations - vector[n] obs_reports; // reports at observation time - array[n] int obs_cases; // observed cases at observation time - if (accumulate) { - int t = num_elements(reports); - int i = 0; - int current_obs = 0; - obs_reports = rep_vector(0, n); - while (i <= t && current_obs <= n) { - if (current_obs > 0) { // first observation gets ignored when accumulating - obs_reports[current_obs] += reports[i]; - } - if (i == cases_time[current_obs + 1]) { - current_obs += 1; - } - i += 1; - } - obs_cases = cases[2:(n + 1)]; - } else { - obs_reports = reports[cases_time]; - obs_cases = cases; } - if (model_type) { - real dispersion = inv_square(phi_sd > 0 ? rep_phi[model_type] : phi_mean); + if (obs_dist) { + real dispersion = inv_square(phi_sd > 0 ? 
rep_phi[obs_dist] : phi_mean); if (phi_sd > 0) { - rep_phi[model_type] ~ normal(phi_mean, phi_sd) T[0,]; + rep_phi[obs_dist] ~ normal(phi_mean, phi_sd) T[0,]; } if (weight == 1) { obs_cases ~ neg_binomial_2(obs_reports, dispersion); @@ -98,17 +74,17 @@ void report_lp(array[] int cases, array[] int cases_time, vector reports, } // update log likelihood (as above but not vectorised and returning log likelihood) vector report_log_lik(array[] int cases, vector reports, - array[] real rep_phi, int model_type, real weight) { + array[] real rep_phi, int obs_dist, real weight) { int t = num_elements(reports); vector[t] log_lik; // defer to poisson if phi is large, to avoid overflow - if (model_type == 0) { + if (obs_dist == 0) { for (i in 1:t) { log_lik[i] = poisson_lpmf(cases[i] | reports[i]) * weight; } } else { - real dispersion = inv_square(rep_phi[model_type]); + real dispersion = inv_square(rep_phi[obs_dist]); for (i in 1:t) { log_lik[i] = neg_binomial_2_lpmf(cases[i] | reports[i], dispersion) * weight; } @@ -116,12 +92,12 @@ vector report_log_lik(array[] int cases, vector reports, return(log_lik); } // sample reported cases from the observation model -array[] int report_rng(vector reports, array[] real rep_phi, int model_type) { +array[] int report_rng(vector reports, array[] real rep_phi, int obs_dist) { int t = num_elements(reports); array[t] int sampled_reports; real dispersion = 1e5; - if (model_type) { - dispersion = inv_square(rep_phi[model_type]); + if (obs_dist) { + dispersion = inv_square(rep_phi[obs_dist]); } for (s in 1:t) { diff --git a/inst/stan/simulate_infections.stan b/inst/stan/simulate_infections.stan index 1f4f65cb9..de89ef92a 100644 --- a/inst/stan/simulate_infections.stan +++ b/inst/stan/simulate_infections.stan @@ -3,7 +3,6 @@ functions { #include functions/pmfs.stan #include functions/delays.stan #include functions/gaussian_process.stan -#include functions/rt.stan #include functions/infections.stan #include functions/observation_model.stan 
#include functions/generated_quantities.stan @@ -36,6 +35,7 @@ generated quantities { matrix[n, t - seeding_time] reports; // observed cases array[n, t - seeding_time] int imputed_reports; matrix[n, t - seeding_time - 1] r; + vector[seeding_time] uobs_inf; for (i in 1:n) { // generate infections from Rt trace vector[delay_type_max[gt_id] + 1] gt_rev_pmf; @@ -46,11 +46,10 @@ generated quantities { 1, 1, 0 ); - infections[i] = to_row_vector(generate_infections( - to_vector(R[i]), seeding_time, gt_rev_pmf, initial_infections[i], - initial_growth[i], pop, future_time - )); - + uobs_inf = generate_seed(initial_infections[i], initial_growth[i], seeding_time); + // generate infections from Rt trace + infections[i] = renewal_model(R[i], uobs_inf, gt_rev_pmf, pop, future_time); + // convolve from latent infections to mean of observations if (delay_id) { vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf( delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, @@ -58,21 +57,18 @@ generated quantities { delay_np_pmf_groups, delay_params[i], delay_params_groups, delay_dist, 0, 1, 0 ); - // convolve from latent infections to mean of observations - reports[i] = to_row_vector(convolve_to_report( - to_vector(infections[i]), delay_rev_pmf, seeding_time) - ); + reports[i] = convolve_to_report(infections[i], delay_rev_pmf, seeding_time); } else { reports[i] = to_row_vector( infections[i, (seeding_time + 1):t] ); + reports[i] = infections[(seeding_time + 1):t]; } - // weekly reporting effect if (week_effect > 1) { - reports[i] = to_row_vector( - day_of_week_effect(to_vector(reports[i]), day_of_week, - to_vector(day_of_week_simplex[i]))); + reports[i] = day_of_week_effect( + reports[i], day_of_week, to_vector(day_of_week_simplex[i]) + ); } // truncate near time cases to observed reports if (trunc_id) { @@ -88,14 +84,14 @@ generated quantities { } // scale observations if (obs_scale) { - reports[i] = to_row_vector(scale_obs(to_vector(reports[i]), 
frac_obs[i, 1])); + reports[i] = scale_obs(reports[i], frac_obs[i, 1]); } // simulate reported cases imputed_reports[i] = report_rng( - to_vector(reports[i]), rep_phi[i], model_type + to_vector(reports[i]), rep_phi[i], obs_dist ); r[i] = to_row_vector( - calculate_growth(to_vector(infections[i]), seeding_time + 1) + calculate_growth(to_vector(infections[i]), seeding_time) ); } } diff --git a/inst/stan/simulate_secondary.stan b/inst/stan/simulate_secondary.stan index d59f1d484..6fabff25c 100644 --- a/inst/stan/simulate_secondary.stan +++ b/inst/stan/simulate_secondary.stan @@ -62,7 +62,9 @@ generated quantities { // weekly reporting effect if (week_effect > 1) { - secondary = day_of_week_effect(secondary, day_of_week, to_vector(day_of_week_simplex[i])); + secondary = day_of_week_effect( + secondary, day_of_week, to_vector(day_of_week_simplex[i]) + ); } // truncate near time cases to observed reports @@ -80,7 +82,7 @@ generated quantities { // simulate secondary reports sim_secondary[i] = report_rng( - tail(secondary, all_dates ? t : h), rep_phi[i], model_type + tail(secondary, all_dates ? t : h), rep_phi[i], obs_dist ); } } diff --git a/man/backcalc_opts.Rd b/man/backcalc_opts.Rd index 7535f4497..fd80501ad 100644 --- a/man/backcalc_opts.Rd +++ b/man/backcalc_opts.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/opts.R +% Please edit documentation in R/depreciated.R \name{backcalc_opts} \alias{backcalc_opts} \title{Back Calculation Options} @@ -37,7 +37,7 @@ estimate is included.} A \verb{} object of back calculation settings. 
} \description{ -\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#deprecated}{\figure{lifecycle-deprecated.svg}{options: alt='[Deprecated]'}}}{\strong{[Deprecated]}} Defines a list specifying the optional arguments for the back calculation of cases. Only used if \code{rt = NULL}. } diff --git a/man/create_stan_data.Rd b/man/create_stan_data.Rd index e755fca8d..35375bec7 100644 --- a/man/create_stan_data.Rd +++ b/man/create_stan_data.Rd @@ -12,7 +12,8 @@ create_stan_data( obs, horizon, backcalc, - shifted_cases + shifted_cases, + process_model ) } \arguments{ diff --git a/man/estimate_infections.Rd b/man/estimate_infections.Rd index ca9a549af..1a805585b 100644 --- a/man/estimate_infections.Rd +++ b/man/estimate_infections.Rd @@ -7,6 +7,7 @@ Growth} \usage{ estimate_infections( data, + process_opts = process_opts(), generation_time = generation_time_opts(), delays = delay_opts(), truncation = trunc_opts(), @@ -91,6 +92,10 @@ Corresponds to the "DEBUG" level from \code{futile.logger}. See \code{setup_logg for more detailed logging options.} \item{reported_cases}{Deprecated; use \code{data} instead.} +\item{process_model}{A character string that defines what is being +modelled: "infections", "growth" or "R" (default). 
If set to "R", +a generation time distribution needs to be defined via the \code{generation_time} +argument.} } \value{ A list of output including: posterior samples, summarised posterior diff --git a/man/process_opts.Rd b/man/process_opts.Rd new file mode 100644 index 000000000..8cce5af5f --- /dev/null +++ b/man/process_opts.Rd @@ -0,0 +1,75 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/opts.R +\name{process_opts} +\alias{process_opts} +\title{Process model options} +\usage{ +process_opts( + model = "R", + prior_mean = data.table::fcase(model == "R", list(mean = 1, sd = 1), model == "growth", + list(mean = 0, sd = 1), model == "infections", NULL), + prior_t = NULL, + rw = 0, + use_breakpoints = TRUE, + future = "latest", + stationary = FALSE, + pop = 0 +) +} +\arguments{ +\item{rw}{Numeric step size of the random walk, defaults to 0. To specify a +weekly random walk set \code{rw = 7}. For more custom break point settings +consider passing in a \code{breakpoints} variable as outlined in the next section.} + +\item{use_breakpoints}{Logical, defaults to \code{TRUE}. Should break points be +used if present as a \code{breakpoint} variable in the input data. Break points +should be defined as 1 if present and otherwise 0. By default breakpoints +are fit jointly with a global non-parametric effect and so represent a +conservative estimate of break point changes (alter this by setting +\code{gp = NULL}).} + +\item{future}{A character string or integer. This argument indicates how to set future Rt values. Supported +options are to project using the Rt model ("project"), to use the latest estimate based on partial data ("latest"), +to use the latest estimate based on data that is over 50\% complete ("estimate"). If an integer is supplied then the Rt estimate +from this many days into the future (or past if negative) will be used forwards in time.} + +\item{pop}{Integer, defaults to 0. Susceptible population initially present. 
+Used to adjust Rt estimates when otherwise fixed based on the proportion of +the population that is susceptible. When set to 0 no population adjustment +is done.} + +\item{prior}{List containing named numeric elements "mean" and "sd". The +mean and standard deviation of the log normal Rt prior. Defaults to mean of +1 and standard deviation of 1.} + +\item{use_rt}{Logical, defaults to \code{TRUE}. Should Rt be used to generate +infections and hence reported cases.} + +\item{gp_on}{Character string, defaulting to "R_t-1". Indicates how the +Gaussian process, if in use, should be applied to Rt. Currently supported +options are applying the Gaussian process to the last estimated Rt (i.e. +Rt = Rt-1 * GP), and applying the Gaussian process to a global mean (i.e. Rt += R0 * GP). Both should produce comparable results when data is not sparse +but the method relying on a global mean will revert to this for real time +estimates, which may not be desirable.} +} +\value{ +A list of settings defining the time-varying reproduction number. +} +\description{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} +Defines a list specifying the optional arguments for the process model. +Custom settings can be supplied which override the defaults. +} +\examples{ +# default settings +process_opts() + +# add a weekly random walk +process_opts(rw = 7) +} +\author{ +Sebastian Funk + +Sam Abbott +} diff --git a/man/rt_opts.Rd b/man/rt_opts.Rd index 24774d891..c92bb86f2 100644 --- a/man/rt_opts.Rd +++ b/man/rt_opts.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/opts.R +% Please edit documentation in R/depreciated.R \name{rt_opts} \alias{rt_opts} \title{Time-Varying Reproduction Number Options} @@ -59,7 +59,7 @@ An \verb{} object with settings defining the time-varying reproduction number. 
} \description{ -\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#deprecated}{\figure{lifecycle-deprecated.svg}{options: alt='[Deprecated]'}}}{\strong{[Deprecated]}} Defines a list specifying the optional arguments for the time-varying reproduction number. Custom settings can be supplied which override the defaults. diff --git a/tests/testthat/test-create_obs_model.R b/tests/testthat/test-create_obs_model.R index b2940a284..ce6d4d865 100644 --- a/tests/testthat/test-create_obs_model.R +++ b/tests/testthat/test-create_obs_model.R @@ -5,11 +5,11 @@ test_that("create_obs_model works with default settings", { obs <- create_obs_model(dates = dates) expect_equal(length(obs), 12) expect_equal(names(obs), c( - "model_type", "phi_mean", "phi_sd", "week_effect", "obs_weight", + "obs_dist", "phi_mean", "phi_sd", "week_effect", "obs_weight", "obs_scale", "obs_scale_mean", "obs_scale_sd", "accumulate", "likelihood", "return_likelihood", "day_of_week" )) - expect_equal(obs$model_type, 1) + expect_equal(obs$obs_dist, 1) expect_equal(obs$week_effect, 7) expect_equal(obs$obs_scale, 0) expect_equal(obs$likelihood, 1) @@ -21,7 +21,7 @@ test_that("create_obs_model works with default settings", { test_that("create_obs_model can be used with a Poisson model", { obs <- create_obs_model(dates = dates, obs = obs_opts(family = "poisson")) - expect_equal(obs$model_type, 0) + expect_equal(obs$obs_dist, 0) }) test_that("create_obs_model can be used with a scaling", {