Issue 757: Vectorise GP stan code (#742)
* refactor gps along approach used by aki

* add some skeleton unit tests

* update model

* update data chunk

* add r interface code

* add some R side tests where missing for gp related code

* get tests passing

* fix tests for create_gp_data

* fix tests

* update docs

* fix tests warnings to error for uncommon Matern orders

* remove spurious plots

* update EpiNow2 vignette

* review kernels

* update and test

* correct scaling of L

* rescale lengthscale

* change adapt-delta default to 0.9

* expand inits for GP as causing issues as too close to 1/0

* get rid of normalisation and use unnormalised lpdf where possible (equiv)

* widen optimisation sweep to include delay weight default

* non-center random walk

* tune prior specification

* tune dispersion prior

* tune phi

* update vignette

* get rid of rw change

* revert Rt

* catch update_rt

* add news

* revert vignette changes

* fix gp_opts tests

* fix create tests

* update gp tests

* skip tests as required on windows

* fix linting

* constrain delay uncertainty

* correct gp stan tests

* fix GP test

* put the deprecation warning behind a gate

* fix linting

* Update NEWS.md

* add linear kernel support

* add docs and news

* integration tests and minor issues

* fixes for periodic kernel dimension differences

* drop linear kernel support

* lint space

* catch outstanding linear tests

* catch stan tests

* make the eecdf in convolve test less random

* Update R/create.R

Co-authored-by: James Azam <[email protected]>

* Update NEWS.md

Co-authored-by: James Azam <[email protected]>

* Update create.R - remove out of date gp type 3 check

* Update opts.R - remove linear kernel references

* Update R/opts.R

Co-authored-by: James Azam <[email protected]>

* Update opts.R  - fix review suggestions

* Update opts.R - remove linear reference

* Update estimate_infections.stan

* Update tests/testthat/test-create_gp_data.R

Co-authored-by: James Azam <[email protected]>

* Update NEWS.md

* Update opts.R

* Document

---------

Co-authored-by: James Azam <[email protected]>
Co-authored-by: GitHub Actions <[email protected]>
3 people authored Aug 29, 2024
1 parent 39cdaff commit 025693c
Showing 25 changed files with 606 additions and 220 deletions.
6 changes: 6 additions & 0 deletions NEWS.md
@@ -11,6 +11,12 @@
- `epinow()` now returns the "timing" output in a "time difference" format that is easier to understand and work with. By @jamesmbaazam in #688 and reviewed by @sbfnk.
- The interface for defining delay distributions has been generalised to also cater for continuous distributions
- When defining probability distributions these can now be truncated using the `tolerance` argument
- Ornstein-Uhlenbeck and 5 / 2 Matérn kernels have been added. By @sbfnk in #741 and reviewed by @seabbs.
- Gaussian processes have been vectorised, leading to some speed gains 🚀, and the `gp_opts()` function has gained three more options, "periodic", "ou", and "se", to specify periodic, Ornstein-Uhlenbeck, and squared exponential kernels respectively. By @seabbs in #742 and reviewed by @jamesmbaazam.
- Prior predictive checks have been used to update the following priors: the prior on the magnitude of the Gaussian process (from HalfNormal(0, 1) to HalfNormal(0, 0.1)), and the prior on the overdispersion (from 1 / HalfNormal(0, 1)^2 to 1 / HalfNormal(0, 0.25)^2). In the user-facing API, this is a change in default values of the `sd` of `phi` in `obs_opts()` from 1 to 0.25. By @seabbs in #742 and reviewed by @jamesmbaazam.
- The default stan control options have been updated from `list(adapt_delta = 0.95, max_treedepth = 15)` to `list(adapt_delta = 0.9, max_treedepth = 12)` due to improved performance and to reduce the runtime of the default parameterisations. By @seabbs in #742 and reviewed by @jamesmbaazam.
- Initialisation has been simplified by sampling directly from the priors, where possible, rather than from a constrained space. By @seabbs in #742 and reviewed by @jamesmbaazam.
- Unnecessary normalisation of delay priors has been removed. By @seabbs in #742 and reviewed by @jamesmbaazam.
- Switch to broadcasting from random walks and added unit tests. By @seabbs in #747 and reviewed by @jamesmbaazam.
- Optimised convolution code to take into account the relative length of the vectors being convolved. See #745 by @seabbs and reviewed by @jamesmbaazam.
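
The change to the overdispersion prior described in the entries above can be explored with a small prior predictive sketch. The snippet below uses base R only and is illustrative: `draw_phi()` is a hypothetical helper, not an EpiNow2 function, and it assumes that the overdispersion is one over the square of a half-normal draw, as the NEWS entry describes.

```r
# Hypothetical helper (not part of EpiNow2): draw the implied overdispersion
# as 1 / x^2, where x follows a HalfNormal(0, sd) distribution.
draw_phi <- function(n, sd) {
  x <- abs(rnorm(n, mean = 0, sd = sd)) # half-normal via absolute value
  1 / x^2
}

set.seed(123)
phi_old <- draw_phi(1e4, sd = 1)    # previous default sd
phi_new <- draw_phi(1e4, sd = 0.25) # new default sd

# Compare the central tendency of the implied overdispersion
c(old = median(phi_old), new = median(phi_new))
```

Under these assumptions the narrower prior shifts the implied overdispersion parameter towards larger values. If that parameter enters the likelihood as in Stan's `neg_binomial_2` (variance `mu + mu^2 / phi`), larger values correspond to less extra-Poisson variation.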
80 changes: 50 additions & 30 deletions R/create.R
@@ -346,6 +346,7 @@ create_backcalc_data <- function(backcalc = backcalc_opts()) {
)
return(data)
}

#' Create Gaussian Process Data
#'
#' @description `r lifecycle::badge("stable")`
@@ -387,33 +388,50 @@ create_gp_data <- function(gp = gp_opts(), data) {
} else {
fixed <- FALSE
}
# reset ls_max if larger than observed time
time <- data$t - data$seeding_time - data$horizon
if (gp$ls_max > time) {
gp$ls_max <- time

time <- data$t - data$seeding_time
if (data$future_fixed > 0) {
time <- time + data$fixed_from - data$horizon
}
if (data$stationary == 1) {
time <- time - 1
}

obs_time <- data$t - data$seeding_time
if (gp$ls_max > obs_time) {
gp$ls_max <- obs_time
}

times <- seq_len(time)

rescaled_times <- (times - mean(times)) / sd(times)
gp$ls_mean <- gp$ls_mean / sd(times)
gp$ls_sd <- gp$ls_sd / sd(times)
gp$ls_min <- gp$ls_min / sd(times)
gp$ls_max <- gp$ls_max / sd(times)

# basis functions
M <- data$t - data$seeding_time
M <- ifelse(data$future_fixed == 1, M - (data$horizon - data$fixed_from), M)
M <- ceiling(M * gp$basis_prop)
M <- ceiling(time * gp$basis_prop)

# map settings to underlying gp stan requirements
gp_data <- list(
fixed = as.numeric(fixed),
M = M,
L = gp$boundary_scale,
L = gp$boundary_scale * max(rescaled_times),
ls_meanlog = convert_to_logmean(gp$ls_mean, gp$ls_sd),
ls_sdlog = convert_to_logsd(gp$ls_mean, gp$ls_sd),
ls_min = gp$ls_min,
ls_max = data$t - data$seeding_time - data$horizon,
ls_max = gp$ls_max,
alpha_mean = gp$alpha_mean,
alpha_sd = gp$alpha_sd,
gp_type = data.table::fcase(
is.infinite(gp$matern_order), 0,
gp$matern_order == 1 / 2, 1,
gp$matern_order == 3 / 2, 2,
default = 3
)
gp$kernel == "se", 0,
gp$kernel == "periodic", 1,
gp$kernel == "matern" || gp$kernel == "ou", 2,
default = 2
),
nu = gp$matern_order,
w0 = gp$w0
)

gp_data <- c(data, gp_data)
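
The rescaling step in the hunk above standardises the time index and divides the length scale settings by the same standard deviation, so that the prior and its bounds live on the rescaled time axis. A minimal base R sketch of that step follows. `rescale_gp_inputs()` is a hypothetical helper written for illustration; its defaults are copied from `gp_opts()` in this commit, and it is not the package's implementation.

```r
# Hypothetical helper mirroring the rescaling step shown in the diff.
rescale_gp_inputs <- function(time, ls_mean = 21, ls_sd = 7, ls_min = 0,
                              ls_max = 60, boundary_scale = 1.5) {
  stopifnot(time >= 2) # sd() needs at least two time points
  times <- seq_len(time)
  time_sd <- sd(times)
  # Centre and scale the time index to mean 0 and standard deviation 1
  rescaled_times <- (times - mean(times)) / time_sd
  list(
    rescaled_times = rescaled_times,
    # Length scale settings are expressed in units of the rescaled axis
    ls_mean = ls_mean / time_sd,
    ls_sd = ls_sd / time_sd,
    ls_min = ls_min / time_sd,
    ls_max = ls_max / time_sd,
    # Boundary of the approximation domain, scaled from the largest time
    L = boundary_scale * max(rescaled_times)
  )
}

scaled <- rescale_gp_inputs(time = 60)
round(c(ls_mean = scaled$ls_mean, ls_max = scaled$ls_max, L = scaled$L), 3)
```

Because the times are standardised, a length scale of 21 days over 60 time points becomes a value of order one on the rescaled axis, which is the scale on which `rescaled_rho` is initialised later in this commit.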
@@ -606,42 +624,44 @@ create_initial_conditions <- function(data) {
out <- create_delay_inits(data)

if (data$fixed == 0) {
out$eta <- array(rnorm(data$M, mean = 0, sd = 0.1))
out$rho <- array(rlnorm(1,
out$eta <- array(rnorm(
ifelse(data$gp_type == 1, data$M * 2, data$M), mean = 0, sd = 0.1))
out$rescaled_rho <- array(rlnorm(1,
meanlog = data$ls_meanlog,
sdlog = ifelse(data$ls_sdlog > 0, data$ls_sdlog * 0.1, 0.01)
sdlog = ifelse(data$ls_sdlog > 0, data$ls_sdlog, 0.01)
))
out$rescaled_rho <- array(data.table::fcase(
out$rescaled_rho > data$ls_max, data$ls_max - 0.001,
out$rescaled_rho < data$ls_min, data$ls_min + 0.001,
default = out$rescaled_rho
))

out$rho <- array(data.table::fcase(
out$rho > data$ls_max, data$ls_max - 0.001,
out$rho < data$ls_min, data$ls_min + 0.001,
default = out$rho
))

out$alpha <- array(
truncnorm::rtruncnorm(1, a = 0, mean = 0, sd = data$alpha_sd)
truncnorm::rtruncnorm(
1, a = 0, mean = data$alpha_mean, sd = data$alpha_sd
)
)
} else {
out$eta <- array(numeric(0))
out$rho <- array(numeric(0))
out$rescaled_rho <- array(numeric(0))
out$alpha <- array(numeric(0))
}
if (data$model_type == 1) {
out$rep_phi <- array(
truncnorm::rtruncnorm(
1,
a = 0, mean = data$phi_mean, sd = data$phi_sd / 10
a = 0, mean = data$phi_mean, sd = data$phi_sd
)
)
}
if (data$estimate_r == 1) {
out$initial_infections <- array(rnorm(1, data$prior_infections, 0.02))
out$initial_infections <- array(rnorm(1, data$prior_infections, 0.2))
if (data$seeding_time > 1) {
out$initial_growth <- array(rnorm(1, data$prior_growth, 0.01))
out$initial_growth <- array(rnorm(1, data$prior_growth, 0.02))
}
out$log_R <- array(rnorm(
n = 1, mean = convert_to_logmean(data$r_mean, data$r_sd),
sd = convert_to_logsd(data$r_mean, data$r_sd) * 0.1
sd = convert_to_logsd(data$r_mean, data$r_sd)
))
}

@@ -656,7 +676,7 @@ create_initial_conditions <- function(data) {
out$frac_obs <- array(truncnorm::rtruncnorm(1,
a = 0, b = 1,
mean = data$obs_scale_mean,
sd = data$obs_scale_sd * 0.1
sd = data$obs_scale_sd
))
} else {
out$frac_obs <- array(numeric(0))
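
The initialisation changes above draw starting values directly from the priors and then pull any draw that falls outside the allowed range back just inside it. The sketch below shows that pattern for the length scale using base R only. `init_lengthscale()` is a hypothetical helper; the diff itself uses `data.table::fcase()` for the bounds check, and the parameter values in the usage lines are made up for illustration.

```r
# Hypothetical helper (base R only): draw an initial length scale from its
# lognormal prior and move it just inside the bounds if it falls outside.
init_lengthscale <- function(ls_meanlog, ls_sdlog, ls_min, ls_max,
                             eps = 0.001) {
  # Fall back to a small spread when the prior standard deviation is zero
  sdlog <- if (ls_sdlog > 0) ls_sdlog else 0.01
  rho <- rlnorm(1, meanlog = ls_meanlog, sdlog = sdlog)
  if (rho > ls_max) {
    rho <- ls_max - eps
  } else if (rho < ls_min) {
    rho <- ls_min + eps
  }
  rho
}

set.seed(42)
inits <- replicate(
  1000,
  init_lengthscale(ls_meanlog = 0, ls_sdlog = 1, ls_min = 0.1, ls_max = 3)
)
range(inits)
```

Every draw ends up inside the closed interval between `ls_min` and `ls_max`, which avoids starting the sampler at a value the bounded parameter cannot take.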
3 changes: 1 addition & 2 deletions R/estimate_infections.R
@@ -101,8 +101,7 @@
#' def <- estimate_infections(reported_cases,
#' generation_time = gt_opts(generation_time),
#' delays = delay_opts(incubation_period + reporting_delay),
#' rt = rt_opts(prior = list(mean = 2, sd = 0.1)),
#' stan = stan_opts(control = list(adapt_delta = 0.95))
#' rt = rt_opts(prior = list(mean = 2, sd = 0.1))
#' )
#' # real time estimates
#' summary(def)
98 changes: 56 additions & 42 deletions R/opts.R
@@ -15,9 +15,9 @@
#' @param weight_prior Logical; if TRUE (default), any priors given in `dist`
#' will be weighted by the number of observation data points, in doing so
#' approximately placing an independent prior at each time step and usually
#' preventing the posteriors from shifting. If FALSE, no weight will be
#' applied, i.e. any parameters in `dist` will be treated as a single
#' parameters.
#' preventing the posteriors from shifting. If FALSE, no weight
#' will be applied, i.e. any parameters in `dist` will be treated as a single
#' parameters.
#' @inheritParams apply_default_tolerance
#' @return A `<generation_time_opts>` object summarising the input delay
#' distributions.
@@ -401,7 +401,7 @@ backcalc_opts <- function(prior = c("reports", "none", "infections"),
#'
#' @description `r lifecycle::badge("stable")`
#' Defines a list specifying the structure of the approximate Gaussian
#' process. Custom settings can be supplied which override the defaults.
#' process. Custom settings can be supplied which override the defaults.
#'
#' @param ls_mean Numeric, defaults to 21 days. The mean of the lognormal
#' length scale.
@@ -411,31 +411,34 @@ backcalc_opts <- function(prior = c("reports", "none", "infections"),
#' process length scale will be used with recommended parameters
#' \code{inv_gamma(1.499007, 0.057277 * ls_max)}.
#'
#' @param ls_min Numeric, defaults to 0. The minimum value of the length scale.
#'
#' @param ls_max Numeric, defaults to 60. The maximum value of the length
#' scale. Updated in [create_gp_data()] to be the length of the input data if
#' this is smaller.
#'
#' @param ls_min Numeric, defaults to 0. The minimum value of the length scale.
#' @param alpha_mean Numeric, defaults to 0. The mean of the magnitude parameter
#' of the Gaussian process kernel. Should be approximately the expected variance
#' of the logged Rt.
#'
#' @param alpha_sd Numeric, defaults to 0.05. The standard deviation of the
#' magnitude parameter of the Gaussian process kernel. Should be approximately
#' @param alpha_sd Numeric, defaults to 0.01. The standard deviation of the
#' magnitude parameter of the Gaussian process kernel. Should be approximately
#' the expected standard deviation of the logged Rt.
#'
#' @param kernel Character string, the type of kernel required. Currently
#' supporting the squared exponential kernel ("se", or "matern" with
#' 'matern_order = Inf'), 3 over 2 oder 5 over 2 Matern kernel ("matern", with
#' `matern_order = 3/2` (default) or `matern_order = 5/2`, respectively), or
#' Orstein-Uhlenbeck ("ou", or "matern" with 'matern_order = 1/2'). Defaulting
#' to the Matérn 3 over 2 kernel for a balance of smoothness and
#' discontinuities.
#' supporting the Matern kernel ("matern"), squared exponential kernel ("se"),
#' Ornstein-Uhlenbeck kernel ("ou"), and the periodic kernel
#' ("periodic").
#'
#' @param matern_order Numeric, defaults to 3/2. Order of Matérn Kernel to use.
#' Currently the orders 1/2, 3/2, 5/2 and Inf are supported.
#' Common choices are 1/2, 3/2, and 5/2. If `kernel` is set
#' to "ou", `matern_order` will be automatically set to 1/2. Only used if
#' the kernel is set to "matern".
#'
#' @param matern_type Deprated; Numeric, defaults to 3/2. Order of Matérn Kernel
#' to use. Currently the orders 1/2, 3/2, 5/2 and Inf are supported.
#' @param matern_type Deprecated; Numeric, defaults to 3/2. Order of Matérn
#' Kernel to use. Currently, the orders 1/2, 3/2, 5/2 and Inf are supported.
#'
#' @param basis_prop Numeric, proportion of time points to use as basis
#' @param basis_prop Numeric, the proportion of time points to use as basis
#' functions. Defaults to 0.2. Decreasing this value results in a decrease in
#' accuracy but a faster compute time (with increasing it having the first
#' effect). In general smaller posterior length scales require a higher
@@ -446,6 +449,9 @@ backcalc_opts <- function(prior = c("reports", "none", "infections"),
#' approximate Gaussian process. See (Riutort-Mayol et al. 2020
#' <https://arxiv.org/abs/2004.11408>) for advice on updating this default.
#'
#' @param w0 Numeric, defaults to 1.0. Fundamental frequency for the periodic
#' kernel. Only used if `kernel` is set to "periodic".
#'
#' @importFrom rlang arg_match
#' @return A `<gp_opts>` object of settings defining the Gaussian process
#' @export
@@ -455,21 +461,30 @@ backcalc_opts <- function(prior = c("reports", "none", "infections"),
#'
#' # add a custom length scale
#' gp_opts(ls_mean = 4)
#'
#' # use periodic kernel
#' gp_opts(kernel = "periodic")
gp_opts <- function(basis_prop = 0.2,
boundary_scale = 1.5,
ls_mean = 21,
ls_sd = 7,
ls_min = 0,
ls_max = 60,
alpha_sd = 0.05,
kernel = c("matern", "se", "ou"),
alpha_mean = 0,
alpha_sd = 0.01,
kernel = c("matern", "se", "ou", "periodic"),
matern_order = 3 / 2,
matern_type) {
lifecycle::deprecate_warn(
"1.6.0", "gp_opts(matern_type)", "gp_opts(matern_order)"
)
matern_type,
w0 = 1.0) {

if (!missing(matern_type)) {
if (!missing(matern_order) && matern_type == matern_order) {
lifecycle::deprecate_warn(
"1.6.0", "gp_opts(matern_type)", "gp_opts(matern_order)"
)
}

if (!missing(matern_type)) {
if (!missing(matern_order) && matern_type != matern_order) {
stop(
"Incompatible `matern_order` and `matern_type`. ",
"Use `matern_order` only."
@@ -480,20 +495,15 @@ gp_opts <- function(basis_prop = 0.2,

kernel <- arg_match(kernel)
if (kernel == "se") {
if (!missing(matern_order) && is.finite(matern_order)) {
stop("Squared exponential kernel must have matern order unset or `Inf`.")
}
matern_order <- Inf
} else if (kernel == "ou") {
if (!missing(matern_order) && matern_order != 1 / 2) {
stop("Ornstein-Uhlenbeck kernel must have matern order unset or `1 / 2`.") ## nolint: nonportable_path_linter
}
matern_order <- 1 / 2
} else if (!(is.infinite(matern_order) ||
matern_order %in% c(1 / 2, 3 / 2, 5 / 2))) {
stop(
"only the Matern kernels of order `1 / 2`, `3 / 2`, `5 / 2` or `Inf` ", ## nolint: nonportable_path_linter
"are currently supported"
} else if (
!(is.infinite(matern_order) || matern_order %in% c(1 / 2, 3 / 2, 5 / 2))
) {
warning(
"Uncommon Matern kernel order. Common orders are `1 / 2`, `3 / 2`,", # nolint
" and `5 / 2`" # nolint
)
}

@@ -504,9 +514,11 @@ gp_opts <- function(basis_prop = 0.2,
ls_sd = ls_sd,
ls_min = ls_min,
ls_max = ls_max,
alpha_mean = alpha_mean,
alpha_sd = alpha_sd,
kernel = kernel,
matern_order = matern_order
matern_order = matern_order,
w0 = w0
)

attr(gp, "class") <- c("gp_opts", class(gp))
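
The kernel handling in `gp_opts()` and the `gp_type` mapping in `create_gp_data()` can be summarised in one small function. The following is a sketch only, using base R `match.arg()` in place of `rlang::arg_match()`. `resolve_gp_kernel()` is a hypothetical name, and the `gp_type` codes are taken from the comment in `inst/stan/data/gaussian_process.stan` in this commit.

```r
# Hypothetical helper combining the kernel checks from gp_opts() with the
# gp_type mapping from create_gp_data().
resolve_gp_kernel <- function(kernel = c("matern", "se", "ou", "periodic"),
                              matern_order = 3 / 2) {
  kernel <- match.arg(kernel)
  if (kernel == "se") {
    # Squared exponential is the infinite-order limit of the Matern family
    matern_order <- Inf
  } else if (kernel == "ou") {
    # Ornstein-Uhlenbeck is the Matern kernel of order 1 / 2
    matern_order <- 1 / 2
  } else if (!(is.infinite(matern_order) ||
    matern_order %in% c(1 / 2, 3 / 2, 5 / 2))) {
    warning("Uncommon Matern kernel order: ", matern_order)
  }
  # 0 = squared exponential, 1 = periodic, 2 = Matern (including "ou")
  gp_type <- switch(kernel, se = 0, periodic = 1, matern = 2, ou = 2)
  list(kernel = kernel, matern_order = matern_order, gp_type = gp_type)
}

str(resolve_gp_kernel("ou"))
str(resolve_gp_kernel("periodic"))
```

As in the diff, an unusual Matern order produces a warning rather than an error, so users can experiment with other orders while still being told that they are off the well-tested path.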
@@ -523,8 +535,10 @@ gp_opts <- function(basis_prop = 0.2,
#' @param phi Overdispersion parameter of the reporting process, used only if
#' `family` is "negbin". Can be supplied either as a single numeric value
#' (fixed overdispersion) or a list with numeric elements mean (`mean`) and
#' standard deviation (`sd`) defining a normally distributed overdispersion.
#' Defaults to a list with elements `mean = 0` and `sd = 1`.
#' standard deviation (`sd`) defining a normally distributed prior.
#' Internally parameterised such that the overdispersion is one over the
#' square of this prior. Defaults to a list with elements
#' `mean = 0` and `sd = 0.25`.
#' @param weight Numeric, defaults to 1. Weight to give the observed data in the
#' log density.
#' @param week_effect Logical defaulting to `TRUE`. Should a day of the week
@@ -563,7 +577,7 @@ gp_opts <- function(basis_prop = 0.2,
#' # Scale reported data
#' obs_opts(scale = list(mean = 0.2, sd = 0.02))
obs_opts <- function(family = c("negbin", "poisson"),
phi = list(mean = 0, sd = 1),
phi = list(mean = 0, sd = 0.25),
weight = 1,
week_effect = TRUE,
week_length = 7,
@@ -634,8 +648,8 @@ obs_opts <- function(family = c("negbin", "poisson"),
#' @param chains Numeric, defaults to 4. Number of MCMC chains to use.
#'
#' @param control List, defaults to empty. control parameters to pass to
#' underlying `rstan` function. By default `adapt_delta = 0.95` and
#' `max_treedepth = 15` though these settings can be overwritten.
#' underlying `rstan` function. By default `adapt_delta = 0.9` and
#' `max_treedepth = 12` though these settings can be overwritten.
#'
#' @param save_warmup Logical, defaults to FALSE. Should warmup progress be
#' saved.
@@ -684,7 +698,7 @@ stan_sampling_opts <- function(cores = getOption("mc.cores", 1L),
future = future,
max_execution_time = max_execution_time
)
control_def <- list(adapt_delta = 0.95, max_treedepth = 15)
control_def <- list(adapt_delta = 0.9, max_treedepth = 12)
control_def <- modifyList(control_def, control)
if (any(c("iter", "iter_sampling") %in% names(dot_args))) {
warning(
3 changes: 1 addition & 2 deletions R/regional_epinow.R
@@ -83,8 +83,7 @@
#' delays = delay_opts(example_incubation_period + example_reporting_delay),
#' rt = rt_opts(prior = list(mean = 2, sd = 0.2)),
#' stan = stan_opts(
#' samples = 100, warmup = 200,
#' control = list(adapt_delta = 0.95)
#' samples = 100, warmup = 200
#' ),
#' verbose = interactive()
#' )
1 change: 0 additions & 1 deletion R/simulate_infections.R
@@ -273,7 +273,6 @@ simulate_infections <- function(estimates, R, initial_infections,
#' generation_time = generation_time_opts(example_generation_time),
#' delays = delay_opts(example_incubation_period + example_reporting_delay),
#' rt = rt_opts(prior = list(mean = 2, sd = 0.1), rw = 7),
#' stan = stan_opts(control = list(adapt_delta = 0.9)),
#' obs = obs_opts(scale = list(mean = 0.1, sd = 0.01)),
#' gp = NULL, horizon = 0
#' )
5 changes: 4 additions & 1 deletion inst/stan/data/gaussian_process.stan
@@ -4,7 +4,10 @@
real ls_sdlog; // sdlog for gp lengthscale prior
real<lower=0> ls_min; // Lower bound for the lengthscale
real<lower=0> ls_max; // Upper bound for the lengthscale
real alpha_mean; // mean of the alpha gp kernel parameter
real alpha_sd; // standard deviation of the alpha gp kernel parameter
int gp_type; // type of gp, 0 = squared exponential, 1 = 3/2 matern
int gp_type; // type of gp, 0 = squared exponential, 1 = periodic, 2 = Matern
real nu; // smoothness parameter for Matern kernel (used if gp_type = 2)
real w0; // fundamental frequency for periodic kernel (used if gp_type = 1)
int stationary; // is underlying gaussian process first or second order
int fixed; // should a gaussian process be used
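
The data block above feeds an approximate Gaussian process built from `M` basis functions on a domain of half-width `L`. The Stan functions themselves are not part of this excerpt, so the following R sketch is not the package's code. It follows the basis function formulation in the Riutort-Mayol et al. reference cited in the `gp_opts()` documentation, for the squared exponential case (`gp_type = 0`) only. `hsgp_basis()` and `se_spectral_density()` are hypothetical names, and the values of `M`, `alpha`, and `rho` are chosen for illustration.

```r
# Basis functions on [-L, L]:
# phi_m(x) = sin(m * pi * (x + L) / (2 * L)) / sqrt(L)
hsgp_basis <- function(x, M, L) {
  sin(outer(x + L, seq_len(M)) * pi / (2 * L)) / sqrt(L)
}

# Spectral density of the one-dimensional squared exponential kernel
se_spectral_density <- function(omega, alpha, rho) {
  alpha^2 * sqrt(2 * pi) * rho * exp(-0.5 * rho^2 * omega^2)
}

x <- as.numeric(scale(seq_len(60))) # standardised times
L <- 1.5 * max(x)                   # boundary_scale * max(rescaled_times)
M <- 30                             # number of basis functions (illustrative)
alpha <- 1                          # magnitude (illustrative)
rho <- 0.5                          # length scale on the rescaled axis

PHI <- hsgp_basis(x, M, L)
omega <- seq_len(M) * pi / (2 * L)  # square roots of the eigenvalues
S <- se_spectral_density(omega, alpha, rho)

# Approximate covariance: PHI %*% diag(S) %*% t(PHI)
K_approx <- PHI %*% (S * t(PHI))
K_exact <- alpha^2 * exp(-0.5 * outer(x, x, "-")^2 / rho^2)
max(abs(K_approx - K_exact))
```

In a model, a draw of the latent function would be `PHI %*% (sqrt(S) * eta)` with standard normal `eta`, which is why the initial conditions in this commit create an `eta` vector whose length depends on `M`.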