diff --git a/NAMESPACE b/NAMESPACE index 168b52cb0..e78bed18b 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -106,6 +106,7 @@ export(simulate_infections) export(simulate_secondary) export(stan_laplace_opts) export(stan_opts) +export(stan_pathfinder_opts) export(stan_sampling_opts) export(stan_vb_opts) export(summarise_key_measures) diff --git a/R/opts.R b/R/opts.R index 3c1cdc1c4..a9833f25f 100644 --- a/R/opts.R +++ b/R/opts.R @@ -798,6 +798,36 @@ stan_laplace_opts <- function(backend = "cmdstanr", return(opts) } +#' Stan pathfinder algorithm Options +#' +#' @description `r lifecycle::badge("stable")` Defines a list specifying the +#' arguments passed to [cmdstanr::laplace()]. +#' +#' @inheritParams stan_opts +#' @inheritParams stan_vb_opts +#' @param ... Additional parameters to pass to [cmdstanr::laplace()]. +#' @return A list of arguments to pass to [cmdstanr::laplace()]. +#' @export +#' @examples +#' stan_laplace_opts() +stan_pathfinder_opts <- function(backend = "cmdstanr", + samples = 2000, + trials = 10, + ...) { + if (backend != "cmdstanr") { + stop( + "The pathfinder algorithm is only available with the \"cmdstanr\" ", + "backend." + ) + } + opts <- list( + trials = trials, + draws = samples + ) + opts <- c(opts, ...) + return(opts) +} + #' Rstan Options #' #' @description `r lifecycle::badge("deprecated")` @@ -858,7 +888,9 @@ rstan_opts <- function(object = NULL, #' #' @param method A character string, defaulting to sampling. Currently supports #' MCMC sampling ("sampling") or approximate posterior sampling via -#' variational inference ("vb"). +#' variational inference ("vb") and, if the "cmdstanr" backend is used, +#' approximate posterior sampling with the laplaces algorithm ("laplace") or +#' pathfinder ("pathfinder"). #' #' @param backend Character string indicating the backend to use for fitting #' stan models. Supported arguments are "rstan" (default) or "cmdstanr". @@ -904,7 +936,9 @@ stan_opts <- function(object = NULL, init_fit = NULL, return_fit = TRUE, ...) { - method <- arg_match(method, values = c("sampling", "vb", "laplace")) + method <- arg_match( + method, values = c("sampling", "vb", "laplace", "pathfinder") + ) backend <- arg_match(backend, values = c("rstan", "cmdstanr")) if (backend == "cmdstanr" && !requireNamespace("cmdstanr", quietly = TRUE)) { stop( @@ -934,6 +968,10 @@ stan_opts <- function(object = NULL, opts <- c( opts, stan_laplace_opts(backend = backend, ...) ) + } else if (method == "pathfinder") { + opts <- c( + opts, stan_pathfinder_opts(samples = samples, backend = backend, ...) + ) } if (!is.null(init_fit)) { diff --git a/R/stan.R b/R/stan.R index 0eece9fb6..c5e000c66 100644 --- a/R/stan.R +++ b/R/stan.R @@ -86,10 +86,10 @@ fit_model <- function(args, id = "stan") { future = args$future, max_execution_time = args$max_execution_time, id = id ) - } else if (args$method %in% c("vb", "laplace")) { + } else if (args$method %in% c("vb", "laplace", "pathfinder")) { fit <- fit_model_approximate(args, id = id) } else { - stop("args$method unknown") + stop("method ", args$method, " unknown") } return(fit) } diff --git a/man/fit_model_with_vb.Rd b/man/fit_model_approximate.Rd similarity index 81% rename from man/fit_model_with_vb.Rd rename to man/fit_model_approximate.Rd index 262b2b435..e5a2e9fc9 100644 --- a/man/fit_model_with_vb.Rd +++ b/man/fit_model_approximate.Rd @@ -1,10 +1,10 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/estimate_infections.R -\name{fit_model_with_vb} -\alias{fit_model_with_vb} -\title{Fit a Stan Model using Variational Inference} +\name{fit_model_approximate} +\alias{fit_model_approximate} +\title{Fit a Stan Model using an approximate method} \usage{ -fit_model_with_vb(args, future = FALSE, id = "stan") +fit_model_approximate(args, future = FALSE, id = "stan") } \arguments{ \item{args}{List of stan arguments.} diff --git a/man/fit_model_with_laplace.Rd b/man/fit_model_with_laplace.Rd deleted file mode 100644 index eab30bcda..000000000 --- a/man/fit_model_with_laplace.Rd +++ /dev/null @@ -1,21 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/estimate_infections.R -\name{fit_model_with_laplace} -\alias{fit_model_with_laplace} -\title{Fit a Stan Model using the Laplace Algorithm} -\usage{ -fit_model_with_laplace(args, id = "stan") -} -\arguments{ -\item{args}{List of stan arguments.} - -\item{id}{A character string used to assign logging information on error. -Used by \code{\link[=regional_epinow]{regional_epinow()}} to assign errors to regions. Alter the default to -run with error catching.} -} -\value{ -A stan model object -} -\description{ -\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#maturing}{\figure{lifecycle-maturing.svg}{options: alt='[Maturing]'}}}{\strong{[Maturing]}} -} diff --git a/man/stan_laplace_opts.Rd b/man/stan_laplace_opts.Rd index 7bb5b72d5..9a7e5a99e 100644 --- a/man/stan_laplace_opts.Rd +++ b/man/stan_laplace_opts.Rd @@ -4,12 +4,15 @@ \alias{stan_laplace_opts} \title{Stan Laplace algorithm Options} \usage{ -stan_laplace_opts(backend = "cmdstanr", ...) +stan_laplace_opts(backend = "cmdstanr", trials = 10, ...) } \arguments{ \item{backend}{Character string indicating the backend to use for fitting stan models. Supported arguments are "rstan" (default) or "cmdstanr".} +\item{trials}{Numeric, defaults to 10. Number of attempts to use +rstan::vb()] before failing.} + \item{...}{Additional parameters to pass to \code{\link[cmdstanr:model-method-laplace]{cmdstanr::laplace()}}.} } \value{ diff --git a/man/stan_opts.Rd b/man/stan_opts.Rd index daaa4ae23..bcbacdedc 100644 --- a/man/stan_opts.Rd +++ b/man/stan_opts.Rd @@ -26,7 +26,9 @@ When using multiple chains iterations per chain is samples / chains.} \item{method}{A character string, defaulting to sampling. Currently supports MCMC sampling ("sampling") or approximate posterior sampling via -variational inference ("vb").} +variational inference ("vb") and, if the "cmdstanr" backend is used, +approximate posterior sampling with the laplaces algorithm ("laplace") or +pathfinder ("pathfinder").} \item{backend}{Character string indicating the backend to use for fitting stan models. Supported arguments are "rstan" (default) or "cmdstanr".} diff --git a/man/stan_pathfinder_opts.Rd b/man/stan_pathfinder_opts.Rd new file mode 100644 index 000000000..8f6848370 --- /dev/null +++ b/man/stan_pathfinder_opts.Rd @@ -0,0 +1,30 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/opts.R +\name{stan_pathfinder_opts} +\alias{stan_pathfinder_opts} +\title{Stan pathfinder algorithm Options} +\usage{ +stan_pathfinder_opts(backend = "cmdstanr", samples = 2000, trials = 10, ...) +} +\arguments{ +\item{backend}{Character string indicating the backend to use for fitting +stan models. Supported arguments are "rstan" (default) or "cmdstanr".} + +\item{samples}{Numeric, default 2000. Overall number of posterior samples. +When using multiple chains iterations per chain is samples / chains.} + +\item{trials}{Numeric, defaults to 10. Number of attempts to use +rstan::vb()] before failing.} + +\item{...}{Additional parameters to pass to \code{\link[cmdstanr:model-method-laplace]{cmdstanr::laplace()}}.} +} +\value{ +A list of arguments to pass to \code{\link[cmdstanr:model-method-laplace]{cmdstanr::laplace()}}. +} +\description{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} Defines a list specifying the +arguments passed to \code{\link[cmdstanr:model-method-laplace]{cmdstanr::laplace()}}. +} +\examples{ +stan_laplace_opts() +} diff --git a/tests/testthat/test-epinow.R b/tests/testthat/test-epinow.R index acad3e89f..a8ff6c106 100644 --- a/tests/testthat/test-epinow.R +++ b/tests/testthat/test-epinow.R @@ -85,6 +85,29 @@ test_that("epinow produces expected output when run with the ) }) +test_that("epinow produces expected output when run with the + pathfinder algorithm", { + skip_on_os("windows") + output <- capture.output(suppressMessages(suppressWarnings( + out <- epinow( + reported_cases = reported_cases, + generation_time = generation_time_opts(example_generation_time), + delays = delay_opts(example_incubation_period + reporting_delay), + stan = stan_opts(method = "pathfinder", backend = "cmdstanr"), + logs = NULL, verbose = FALSE + ) + ))) + expect_equal(names(out), expected_out) + df_non_zero(out$estimates$samples) + df_non_zero(out$estimates$summarised) + df_non_zero(out$estimated_reported_cases$samples) + df_non_zero(out$estimated_reported_cases$summarised) + df_non_zero(out$summary) + expect_equal( + names(out$plots), c("infections", "reports", "R", "growth_rate", "summary") + ) +}) + test_that("epinow runs without error when saving to disk", { expect_null(suppressWarnings(epinow( reported_cases = reported_cases,