Skip to content

Commit

Permalink
add pathfinder algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk committed Mar 22, 2024
1 parent d0c6452 commit f7055f3
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 31 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ export(simulate_infections)
export(simulate_secondary)
export(stan_laplace_opts)
export(stan_opts)
export(stan_pathfinder_opts)
export(stan_sampling_opts)
export(stan_vb_opts)
export(summarise_key_measures)
Expand Down
42 changes: 40 additions & 2 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,36 @@ stan_laplace_opts <- function(backend = "cmdstanr",
return(opts)
}

#' Stan pathfinder algorithm Options
#'
#' @description `r lifecycle::badge("stable")` Defines a list specifying the
#' arguments passed to [cmdstanr::laplace()].
#'
#' @inheritParams stan_opts
#' @inheritParams stan_vb_opts
#' @param ... Additional parameters to pass to [cmdstanr::laplace()].
#' @return A list of arguments to pass to [cmdstanr::laplace()].
#' @export
#' @examples
#' stan_laplace_opts()
stan_pathfinder_opts <- function(backend = "cmdstanr",
samples = 2000,
trials = 10,
...) {
if (backend != "cmdstanr") {
stop(
"The pathfinder algorithm is only available with the \"cmdstanr\" ",
"backend."
)
}
opts <- list(
trials = trials,
draws = samples
)
opts <- c(opts, ...)
return(opts)
}

#' Rstan Options
#'
#' @description `r lifecycle::badge("deprecated")`
Expand Down Expand Up @@ -858,7 +888,9 @@ rstan_opts <- function(object = NULL,
#'
#' @param method A character string, defaulting to sampling. Currently supports
#' MCMC sampling ("sampling") or approximate posterior sampling via
#' variational inference ("vb").
#' variational inference ("vb") and, if the "cmdstanr" backend is used,
#' approximate posterior sampling with the laplaces algorithm ("laplace") or
#' pathfinder ("pathfinder").
#'
#' @param backend Character string indicating the backend to use for fitting
#' stan models. Supported arguments are "rstan" (default) or "cmdstanr".
Expand Down Expand Up @@ -904,7 +936,9 @@ stan_opts <- function(object = NULL,
init_fit = NULL,
return_fit = TRUE,
...) {
method <- arg_match(method, values = c("sampling", "vb", "laplace"))
method <- arg_match(
method, values = c("sampling", "vb", "laplace", "pathfinder")
)
backend <- arg_match(backend, values = c("rstan", "cmdstanr"))
if (backend == "cmdstanr" && !requireNamespace("cmdstanr", quietly = TRUE)) {
stop(
Expand Down Expand Up @@ -934,6 +968,10 @@ stan_opts <- function(object = NULL,
opts <- c(
opts, stan_laplace_opts(backend = backend, ...)
)
} else if (method == "pathfinder") {
opts <- c(
opts, stan_pathfinder_opts(samples = samples, backend = backend, ...)
)
}

if (!is.null(init_fit)) {
Expand Down
4 changes: 2 additions & 2 deletions R/stan.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ fit_model <- function(args, id = "stan") {
future = args$future,
max_execution_time = args$max_execution_time, id = id
)
} else if (args$method %in% c("vb", "laplace")) {
} else if (args$method %in% c("vb", "laplace", "pathfinder")) {
fit <- fit_model_approximate(args, id = id)
} else {
stop("args$method unknown")
stop("method ", args$method, " unknown")
}
return(fit)
}
8 changes: 4 additions & 4 deletions man/fit_model_with_vb.Rd → man/fit_model_approximate.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 0 additions & 21 deletions man/fit_model_with_laplace.Rd

This file was deleted.

5 changes: 4 additions & 1 deletion man/stan_laplace_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion man/stan_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions man/stan_pathfinder_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions tests/testthat/test-epinow.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,29 @@ test_that("epinow produces expected output when run with the
)
})

test_that("epinow produces expected output when run with the
pathfinder algorithm", {
skip_on_os("windows")
output <- capture.output(suppressMessages(suppressWarnings(
out <- epinow(
reported_cases = reported_cases,
generation_time = generation_time_opts(example_generation_time),
delays = delay_opts(example_incubation_period + reporting_delay),
stan = stan_opts(method = "pathfinder", backend = "cmdstanr"),
logs = NULL, verbose = FALSE
)
)))
expect_equal(names(out), expected_out)
df_non_zero(out$estimates$samples)
df_non_zero(out$estimates$summarised)
df_non_zero(out$estimated_reported_cases$samples)
df_non_zero(out$estimated_reported_cases$summarised)
df_non_zero(out$summary)
expect_equal(
names(out$plots), c("infections", "reports", "R", "growth_rate", "summary")
)
})

test_that("epinow runs without error when saving to disk", {
expect_null(suppressWarnings(epinow(
reported_cases = reported_cases,
Expand Down

0 comments on commit f7055f3

Please sign in to comment.