diff --git a/NAMESPACE b/NAMESPACE index 1fd427ff1..04ac3a520 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -66,8 +66,11 @@ export(forecast_secondary) export(gamma_dist_def) export(generation_time_opts) export(get_dist) +export(get_distribution) export(get_generation_time) export(get_incubation_period) +export(get_parameters) +export(get_pmf) export(get_raw_result) export(get_regional_results) export(get_regions) @@ -77,6 +80,7 @@ export(growth_to_R) export(lognorm_dist_def) export(make_conf) export(map_prob_change) +export(new_dist_spec) export(obs_opts) export(opts_list) export(plot_estimates) diff --git a/NEWS.md b/NEWS.md index b178d82c7..df8c3fa15 100644 --- a/NEWS.md +++ b/NEWS.md @@ -14,6 +14,7 @@ * A new `simulate_infections` function has been added that can be used to simulate from the model from given initial conditions and parameters. By @sbfnk in #557 and reviewed by @jamesmbaazam. * The function `init_cumulative_fit()` has been deprecated. By @jamesmbaazam in #541 and reviewed by @sbfnk. * The interface to generating delay distributions has been completely overhauled. Instead of calling `dist_spec()` users now specify distributions using functions that represent the available distributions, i.e. `LogNormal()`, `Gamma()` and `Fixed()`. Uncertainty is specified using calls of the same nature, to `Normal()`. More information on the underlying design can be found in `inst/dev/design_dist.md` By @sbfnk in #504 and reviewed by @seabbs. +* The accessor functions `get_parameters()`, `get_pmf()`, and `get_distribution()` have been added to extract elements of a object. By @sbfnk in #646 and reviewed by @jamesmbaazam. * The functions `sample_approx_dist()`, `report_cases()`, and `adjust_infection_reports()` have been deprecated as the functionality they provide can now be achieved with `simulate_secondary()`. See #597 by @jamesmbaazam and reviewed by @sbfnk. ## Documentation diff --git a/R/create.R b/R/create.R index 55be17600..d3e146334 100644 --- a/R/create.R +++ b/R/create.R @@ -82,9 +82,7 @@ create_clean_reported_cases <- function(data, horizon = 0, #' @description `r lifecycle::badge("stable")` #' Creates a complete data set without NA values and appropriate indices #' -#' @param cases; data frame with a column "confirm" that may contain NA values -#' @param burn_in; integer (default 0). Number of days to remove from the -#' start of the time series be filtered out. +#' @param cases data frame with a column "confirm" that may contain NA values #' #' @return A data frame without NA values, with two columns: confirm (number) #' @importFrom data.table setDT diff --git a/R/deprecated.R b/R/deprecated.R index 5aa2c6f5d..73dec0a3d 100644 --- a/R/deprecated.R +++ b/R/deprecated.R @@ -273,8 +273,8 @@ dist_spec <- function(distribution = c( mean = Normal(mean, mean_sd), sd = Normal(sd, sd_sd) ) - params_mean <- vapply(temp_dist[[1]]$parameters, mean, numeric(1)) - params_sd <- vapply(temp_dist[[1]]$parameters, sd_dist, numeric(1)) + params_mean <- vapply(get_parameters(temp_dist), mean, numeric(1)) + params_sd <- vapply(get_parameters(temp_dist), sd_dist, numeric(1)) } else if (distribution == "normal") { params_mean <- c(mean = mean, sd = sd) params_sd <- c(mean = mean_sd, sd = sd_sd) diff --git a/R/dist_spec.R b/R/dist_spec.R index 281fb5a08..6efd2f782 100644 --- a/R/dist_spec.R +++ b/R/dist_spec.R @@ -582,10 +582,10 @@ print.dist_spec <- function(x, ...) { } else if (x[[i]]$distribution == "fixed") { ## fixed cat(indent_str, "- fixed value:\n", sep = "") - if (is.numeric(x[[i]]$parameters$value)) { - cat(indent_str, " ", x[[i]]$parameters$value, "\n", sep = "") + if (is.numeric(get_parameters(x, i)$value)) { + cat(indent_str, " ", get_parameters(x, i)$value, "\n", sep = "") } else { - .print.dist_spec(x[[i]]$parameters$value, indent = indent + 4) + .print.dist_spec(get_parameters(x, i)$value, indent = indent + 4) } } else { ## parametric @@ -595,18 +595,18 @@ print.dist_spec <- function(x, ...) { } cat(":\n") ## loop over natural parameters and print - for (param in names(x[[i]]$parameters)) { + for (param in names(get_parameters(x, i))) { cat( indent_str, " ", param, ":\n", sep = "" ) - if (is.numeric(x[[i]]$parameters[[param]])) { + if (is.numeric(get_parameters(x, i)[[param]])) { cat( indent_str, " ", - signif(x[[i]]$parameters[[param]], digits = 2), "\n", + signif(get_parameters(x, i)[[param]], digits = 2), "\n", sep = "" ) } else { - .print.dist_spec(x[[i]]$parameters[[param]], indent = indent + 4) + .print.dist_spec(get_parameters(x, i)[[param]], indent = indent + 4) } } } @@ -654,12 +654,12 @@ plot.dist_spec <- function(x, ...) { for (i in seq_along(x)) { if (x[[i]]$distribution == "nonparametric") { # Fixed distribution - pmf <- x[[i]]$pmf + pmf <- get_pmf(x, i) dist_name <- paste0("Nonparametric", " (ID: ", i, ")") } else { # Uncertain distribution c_dist <- discretise(fix_dist(extract_single_dist(x, i))) - pmf <- c_dist[[1]]$pmf + pmf <- get_pmf(c_dist) dist_name <- paste0( ifelse(is.na(dist_sd[i]), "Uncertain ", ""), x[[i]]$distribution, " (ID: ", i, ")" @@ -951,14 +951,12 @@ extract_params <- function(params, distribution) { #' @inheritParams extract_params #' @importFrom purrr walk #' @return A `dist_spec` of the given specification. -#' @keywords internal +#' @export #' @examples -#' \dontrun{ #' new_dist_spec( #' params = list(mean = 2, sd = 1, max = Inf), #' distribution = "normal" #' ) -#' } new_dist_spec <- function(params, distribution) { if (distribution == "nonparametric") { ## nonparametric distribution @@ -968,8 +966,12 @@ new_dist_spec <- function(params, distribution) { ) } else { ## process min/max first - max <- params$max - params$max <- NULL + if (is.null(params$max)) { + max <- Inf + } else { + max <- params$max + params$max <- NULL + } ## extract parameters and convert all to dist_spec params <- extract_params(params, distribution) ## fixed distribution @@ -1100,3 +1102,86 @@ convert_to_natural <- function(params, distribution) { } return(params) } + +##' Perform checks for `` `get_...` functions +##' +##' @param x A ``. +##' @param id Integer; the id of the distribution to get parameters of (if x is +##' a composite distribution). If `x` is a single distribution this is ignored +##' and can be left as `NULL`. +##' @return The id to use. +##' @keywords internal +##' @author Sebastian Funk +get_dist_spec_id <- function(x, id) { + if (!is.null(id) && id > length(x)) { + stop( + "`id` can't be greater than the number of distributions (", length(x), + ")." + ) + } + if (length(x) > 1) { + if (is.null(id)) { + stop("`id` must be specified when `x` is a composite distribution.") + } + } else { + id <- 1 + } + return(id) +} + +##' Get parameters of a parametric distribution +##' +##' @inheritParams get_dist_spec_id +##' @description `r lifecycle::badge("experimental")` +##' @return A list of parameters of the distribution. +##' @export +##' @examples +##' dist <- Gamma(shape = 3, rate = 2) +##' get_parameters(dist) +get_parameters <- function(x, id = NULL) { + if (!is(x, "dist_spec")) { + stop("Can only get parameters of a .") + } + id <- get_dist_spec_id(x, id) + if (x[[id]]$distribution == "nonparametric") { + stop("Cannot get parameters of a nonparametric distribution.") + } + return(x[[id]]$parameters) +} + +##' Get the probability mass function of a nonparametric distribution +##' +##' @inheritParams get_dist_spec_id +##' @description `r lifecycle::badge("experimental")` +##' @return The pmf of the distribution +##' @export +##' @examples +##' dist <- discretise(Gamma(shape = 3, rate = 2, max = 10)) +##' get_pmf(dist) +get_pmf <- function(x, id = NULL) { + if (!is(x, "dist_spec")) { + stop("Can only get pmf of a .") + } + id <- get_dist_spec_id(x, id) + if (x[[id]]$distribution != "nonparametric") { + stop("Cannot get pmf of a parametric distribution.") + } + return(x[[id]]$pmf) +} + +##' Get the distribution of a [dist_spec()] +##' +##' @inheritParams get_dist_spec_id +##' @description `r lifecycle::badge("experimental")` +##' @return A character string naming the distribution (or "nonparametric") +##' @export +##' @examples +##' dist <- Gamma(shape = 3, rate = 2, max = 10) +##' get_distribution(dist) +get_distribution <- function(x, id = NULL) { + if (!is(x, "dist_spec")) { + stop("Can only get distribution of a .") + } + id <- get_dist_spec_id(x, id) + return(x[[id]]$distribution) +} diff --git a/R/estimate_truncation.R b/R/estimate_truncation.R index f16fc807c..f09ff5383 100644 --- a/R/estimate_truncation.R +++ b/R/estimate_truncation.R @@ -105,7 +105,7 @@ #' # illustrative purposes only. #' out <- epinow( #' example_truncated[[5]], -#' truncation = est$dist +#' truncation = trunc_opts(est$dist) #' ) #' plot(out) #' options(old_opts) @@ -291,9 +291,12 @@ estimate_truncation <- function(data, max_truncation, trunc_max = 10, parameters <- purrr::map(seq_along(params_mean), function(id) { Normal(params_mean[id], params_sd[id]) }) - names(parameters) <- natural_params(truncation[[1]]$distribution) - out$dist <- truncation - out$dist[[1]]$parameters <- parameters + names(parameters) <- natural_params(get_distribution(truncation)) + parameters$max <- max(truncation) + out$dist <- new_dist_spec( + params = parameters, + distribution = get_distribution(truncation) + ) # summarise reconstructed observations recon_obs <- extract_stan_param(fit, "recon_obs", diff --git a/_pkgdown.yml b/_pkgdown.yml index 2fa85415d..2abaf1338 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -126,7 +126,11 @@ reference: - apply_tolerance - collapse - discretise - - contains("dist") + - contains("_dist") + - contains("dist_") + - get_parameters + - get_pmf + - get_distribution - title: Simulate desc: Functions to help with simulating data or mapping to reported cases contents: diff --git a/data-raw/truncated.R b/data-raw/truncated.R index 30c1e7356..ae552d781 100644 --- a/data-raw/truncated.R +++ b/data-raw/truncated.R @@ -11,7 +11,7 @@ library("EpiNow2") #' @keywords internal apply_truncation <- function(index, data, dist) { set.seed(index) - if (dist[[1]]$distribution == 0) { + if (get_distribution(dist) == "lognormal") { dfunc <- dlnorm } else { dfunc <- dgamma @@ -20,12 +20,12 @@ apply_truncation <- function(index, data, dist) { dfunc( seq_len(max(dist) + 1), rnorm(1, - dist[[1]]$parameters$meanlog[[1]]$parameters$mean, - dist[[1]]$parameters$meanlog[[1]]$parameters$sd + get_parameters(get_parameters(dist)$meanlog)$mean, + get_parameters(get_parameters(dist)$meanlog)$sd ), rnorm(1, - dist[[1]]$parameters$sdlog[[1]]$parameters$mean, - dist[[1]]$parameters$sdlog[[1]]$parameters$sd + get_parameters(get_parameters(dist)$sdlog)$mean, + get_parameters(get_parameters(dist)$sdlog)$sd ) ) ) diff --git a/man/create_complete_cases.Rd b/man/create_complete_cases.Rd index 9d98bfd11..46e5c4a76 100644 --- a/man/create_complete_cases.Rd +++ b/man/create_complete_cases.Rd @@ -7,10 +7,7 @@ create_complete_cases(cases) } \arguments{ -\item{cases;}{data frame with a column "confirm" that may contain NA values} - -\item{burn_in;}{integer (default 0). Number of days to remove from the -start of the time series be filtered out.} +\item{cases}{data frame with a column "confirm" that may contain NA values} } \value{ A data frame without NA values, with two columns: confirm (number) diff --git a/man/estimate_truncation.Rd b/man/estimate_truncation.Rd index 5d5220f23..22b90a2ad 100644 --- a/man/estimate_truncation.Rd +++ b/man/estimate_truncation.Rd @@ -142,7 +142,7 @@ plot(est) # illustrative purposes only. out <- epinow( example_truncated[[5]], - truncation = est$dist + truncation = trunc_opts(est$dist) ) plot(out) options(old_opts) diff --git a/man/get_dist_spec_id.Rd b/man/get_dist_spec_id.Rd new file mode 100644 index 000000000..06bd0624d --- /dev/null +++ b/man/get_dist_spec_id.Rd @@ -0,0 +1,25 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/dist_spec.R +\name{get_dist_spec_id} +\alias{get_dist_spec_id} +\title{Perform checks for \verb{} \code{get_...} functions} +\usage{ +get_dist_spec_id(x, id) +} +\arguments{ +\item{x}{A \verb{}.} + +\item{id}{Integer; the id of the distribution to get parameters of (if x is +a composite distribution). If \code{x} is a single distribution this is ignored +and can be left as \code{NULL}.} +} +\value{ +The id to use. +} +\description{ +Perform checks for \verb{} \code{get_...} functions +} +\author{ +Sebastian Funk +} +\keyword{internal} diff --git a/man/get_distribution.Rd b/man/get_distribution.Rd new file mode 100644 index 000000000..904cefe3c --- /dev/null +++ b/man/get_distribution.Rd @@ -0,0 +1,25 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/dist_spec.R +\name{get_distribution} +\alias{get_distribution} +\title{Get the distribution of a \code{\link[=dist_spec]{dist_spec()}}} +\usage{ +get_distribution(x, id = NULL) +} +\arguments{ +\item{x}{A \verb{}.} + +\item{id}{Integer; the id of the distribution to get parameters of (if x is +a composite distribution). If \code{x} is a single distribution this is ignored +and can be left as \code{NULL}.} +} +\value{ +A character string naming the distribution (or "nonparametric") +} +\description{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#experimental}{\figure{lifecycle-experimental.svg}{options: alt='[Experimental]'}}}{\strong{[Experimental]}} +} +\examples{ +dist <- Gamma(shape = 3, rate = 2, max = 10) +get_distribution(dist) +} diff --git a/man/get_parameters.Rd b/man/get_parameters.Rd new file mode 100644 index 000000000..dd6dc034d --- /dev/null +++ b/man/get_parameters.Rd @@ -0,0 +1,25 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/dist_spec.R +\name{get_parameters} +\alias{get_parameters} +\title{Get parameters of a parametric distribution} +\usage{ +get_parameters(x, id = NULL) +} +\arguments{ +\item{x}{A \verb{}.} + +\item{id}{Integer; the id of the distribution to get parameters of (if x is +a composite distribution). If \code{x} is a single distribution this is ignored +and can be left as \code{NULL}.} +} +\value{ +A list of parameters of the distribution. +} +\description{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#experimental}{\figure{lifecycle-experimental.svg}{options: alt='[Experimental]'}}}{\strong{[Experimental]}} +} +\examples{ +dist <- Gamma(shape = 3, rate = 2) +get_parameters(dist) +} diff --git a/man/get_pmf.Rd b/man/get_pmf.Rd new file mode 100644 index 000000000..4f1090bb3 --- /dev/null +++ b/man/get_pmf.Rd @@ -0,0 +1,25 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/dist_spec.R +\name{get_pmf} +\alias{get_pmf} +\title{Get the probability mass function of a nonparametric distribution} +\usage{ +get_pmf(x, id = NULL) +} +\arguments{ +\item{x}{A \verb{}.} + +\item{id}{Integer; the id of the distribution to get parameters of (if x is +a composite distribution). If \code{x} is a single distribution this is ignored +and can be left as \code{NULL}.} +} +\value{ +The pmf of the distribution +} +\description{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#experimental}{\figure{lifecycle-experimental.svg}{options: alt='[Experimental]'}}}{\strong{[Experimental]}} +} +\examples{ +dist <- discretise(Gamma(shape = 3, rate = 2, max = 10)) +get_pmf(dist) +} diff --git a/man/new_dist_spec.Rd b/man/new_dist_spec.Rd index b0a3fb553..1c649873f 100644 --- a/man/new_dist_spec.Rd +++ b/man/new_dist_spec.Rd @@ -20,11 +20,8 @@ This will convert all parameters to natural parameters before generating a \code{dist_spec}. If they have uncertainty this will be done using sampling. } \examples{ -\dontrun{ new_dist_spec( params = list(mean = 2, sd = 1, max = Inf), distribution = "normal" ) } -} -\keyword{internal} diff --git a/tests/testthat/test-dist.R b/tests/testthat/test-dist.R index 084f34859..826bb04e6 100644 --- a/tests/testthat/test-dist.R +++ b/tests/testthat/test-dist.R @@ -7,8 +7,8 @@ test_that("distributions are the same in R and stan", { lognormal_dist <- do.call(LogNormal, args) gamma_dist <- do.call(Gamma, args) - lognormal_params <- unname(as.numeric(lognormal_dist[[1]]$parameters)) - gamma_params <- unname(as.numeric(gamma_dist[[1]]$parameters)) + lognormal_params <- unname(as.numeric(get_parameters(lognormal_dist))) + gamma_params <- unname(as.numeric(get_parameters(gamma_dist))) pmf_r_lognormal <- discretise(lognormal_dist)[[1]]$pmf pmf_r_gamma <- discretise(gamma_dist)[[1]]$pmf diff --git a/tests/testthat/test-dist_spec.R b/tests/testthat/test-dist_spec.R index 992e746a2..c961c0c3e 100644 --- a/tests/testthat/test-dist_spec.R +++ b/tests/testthat/test-dist_spec.R @@ -1,4 +1,3 @@ - test_that("dist_spec returns correct output for fixed lognormal distribution", { result <- discretise(LogNormal(meanlog = 5, sdlog = 1, max = 19)) expect_null(result[[1]]$parameters) @@ -219,16 +218,16 @@ test_that("composite delay distributions can be disassembled", { test_that("delay distributions can be specified in different ways", { expect_equal( - unname(as.numeric(LogNormal(mean = 4, sd = 1)[[1]]$parameters)), + unname(as.numeric(get_parameters(LogNormal(mean = 4, sd = 1)))), c(1.4, 0.25), tolerance = 0.1 ) expect_equal( - round(discretise(LogNormal(mean = 4, sd = 1, max = 10))[[1]]$pmf, 2), + round(get_pmf(discretise(LogNormal(mean = 4, sd = 1, max = 10))), 2), c(0.00, 0.00, 0.07, 0.27, 0.35, 0.21, 0.07, 0.02, 0.00, 0.00, 0.00) ) expect_equal( - unname(as.numeric(Gamma(mean = 4, sd = 1)[[1]]$parameters)), + unname(as.numeric(get_parameters(Gamma(mean = 4, sd = 1)))), c(16, 4), tolerance = 0.1 ) @@ -238,37 +237,62 @@ test_that("delay distributions can be specified in different ways", { ) expect_equal( unname(as.numeric( - Gamma( - shape = Normal(16, 2), rate = Normal(4, 1) - )[[1]]$parameters$shape[[1]]$parameters + get_parameters(get_parameters( + c( + Gamma( + shape = Normal(12, 3), rate = Normal(3, 0.5) + ), + Gamma( + shape = Normal(16, 2), rate = Normal(4, 1) + ) + ), 2 + )$shape) )), c(16, 2) ) expect_equal( unname(as.numeric( - Gamma( - shape = Normal(16, 2), rate = Normal(4, 1) - )[[1]]$parameters$rate[[1]]$parameters + get_parameters(get_parameters( + Gamma( + shape = Normal(16, 2), rate = Normal(4, 1) + ) + )$rate) )), c(4, 1) ) expect_equal( - unname(as.numeric(Normal(mean = 4, sd = 1)[[1]]$parameters)), c(4, 1) + unname(as.numeric(get_parameters(Normal(mean = 4, sd = 1)))), c(4, 1) ) expect_equal( round(discretise(Normal(mean = 4, sd = 1, max = 5))[[1]]$pmf, 2), c(0.00, 0.01, 0.09, 0.26, 0.38, 0.26) ) expect_equal(discretise(Fixed(value = 3))[[1]]$pmf, c(0, 0, 0, 1)) - expect_equal(Fixed(value = 3.5)[[1]]$parameters$value, 3.5) + expect_equal(get_parameters(Fixed(value = 3.5))$value, 3.5) expect_equal( - NonParametric(c(0.1, 0.3, 0.2, 0.4))[[1]]$pmf, + get_pmf(NonParametric(c(0.1, 0.3, 0.2, 0.4))), c(0.1, 0.3, 0.2, 0.4) ) expect_equal( - round(NonParametric(c(0.1, 0.3, 0.2, 0.1, 0.1))[[1]]$pmf, 2), + round(get_pmf(NonParametric(c(0.1, 0.3, 0.2, 0.1, 0.1))), 2), c(0.12, 0.37, 0.25, 0.12, 0.12) ) + expect_equal( + get_distribution(NonParametric(c(0.1, 0.3, 0.2, 0.1, 0.1))), + "nonparametric" + )}) + +test_that("get functions report errors", { + expect_error(get_parameters("test"), "only get parameters") + expect_error(get_distribution(Gamma(mean = 4, sd = 1), 2), "can't be greater") + expect_error(get_pmf(Gamma(mean = 4, sd = 1)), "parametric") + expect_error( + get_parameters(NonParametric(c(0.1, 0.3, 0.2, 0.1, 0.1))), + "nonparametric" + ) + expect_error(get_parameters(c( + Gamma(mean = 4, sd = 1), Gamma(mean = 4, sd = 1) + )), "must be specified") }) test_that("deprecated functions are deprecated", {