From 6672a1517338164dbc1db452607adc186ae3f78f Mon Sep 17 00:00:00 2001 From: Github Actions CI Date: Fri, 17 Feb 2023 12:30:06 -0500 Subject: [PATCH 01/24] Create simple_ensemble.R based on Issue #1 draft code --- R/simple_ensemble.R | 78 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 R/simple_ensemble.R diff --git a/R/simple_ensemble.R b/R/simple_ensemble.R new file mode 100644 index 0000000..b8c02e4 --- /dev/null +++ b/R/simple_ensemble.R @@ -0,0 +1,78 @@ +#' Compute ensemble predictions by summarizing component predictions for each +#' combination of model task, output type, and type id. Supported output types +#' include `mean`, `median`, `quantile`, `cdf`, and `category`. +#' +#' @param predictions a `data.frame` with component model predictions. It should +#' have columns `team_abbr`, `model_abbr`, one column for each task id +#' variable, `output_type`, `output_type_id`, and `value`. +#' @param task_id_vars an optional `character` vector naming the columns of +#' `predictions` that correspond to task id variables. The default is `NULL`, +#' in which case the task id variables are looked up from the `hub_con` if one +#' is provided. If neither `task_id_vars` nor `hub_con` are provided, all +#' columns in `predictions` _other than_ `team_abbr`, `model_abbr`, +#' `output_type`, `output_id`, and `value` will be used as task id +#' variables. +#' @param hub_con an optional hub connection object; see `hubUtils::connect_hub` +#' @param weights an optional `data.frame` with component model weights. If +#' provided, it should have columns `team_name`, `model_abbr`, `weight`, +#' and optionally, additional columns corresponding to task id variables, +#' `output_type`, or `output_id`, if weights are specific to values of +#' those variables. The default is `NULL`, in which case an equally-weighted +#' ensemble is calculated. +#' @param agg_fun a function or character string name of a function to use for +#' aggregating component predictions into the ensemble prediction. The default +#' is `mean`, in which case the ensemble prediction is the simple average of +#' the component model predictions. The provided function should have an +#' argument `x` for the vector of numeric values to summarize, and for weighted +#' methods, an argument `w` with a numeric vector of weights. +#' @param agg_args a named list of any additional arguments that will be passed +#' to `agg_fun`. +#' @param team_abbr, model_abbr `character` strings with the name of the team +#' and model to use for the ensemble predictions. +#' +#' @return a data.frame with columns `team_abbr`, `model_abbr`, one column for +#' each task id variable, `output_type`, `output_id`, and `value`. + +simple_ensemble <- function(predictions, task_id_vars = NULL, hub_con = NULL, + weights = NULL, agg_fun = mean, agg_args = list(), + team_abbr = "Hub", model_abbr = "ensemble") { + + # require(matrixStats) + + if (is.null(task_id_vars) && is.null(hub_con)) { + temp <- colnames(predictions) + task_id_vars <- temp[!temp %in% c("team_abbr", "model_abbr", "output_type", "output_id", "value")] + } else if (is.null(task_id_vars)) { + # task_id variables looked up from `hub_con` + } + + col_names <- c("team_abbr", "model_abbr", task_id_vars, "output_type", "output_id", "value") + if ((length(predictions) == 0) || !all(names(predictions) %in% col_names)) { + stop("predictions did not have required columns", call. = FALSE) + } else if (!all(names(predictions) == col_names) && names(predictions) %in% col_names) { + predictions <- relocate(predictions, all_of(col_names)) + } + + if (!all(predictions$output_type %in% c("mean", "median", "quantile", "cdf", "category"))) # throw warning or error + + if (!all(names(weights) %in% c("team_abbr", "model_abbr", "weight"))) { + stop("weights did not have required columns", call. = FALSE) + } + + if (is.null(weights)) { + weights <- predictions %>% + distinct(team_abbr, model_abbr) %>% + mutate(weight = 1/n()) + } + + if (agg_fun == "mean") agg_fun = "weightedMean" + if (agg_fun == "median") agg_fun = "weightedMedian" + + + predictions %>% + dplyr::left_join(weights) %>% + dplyr::group_by(across(all_of(c(task_id_vars, "output_type", "output_id")))) %>% + dplyr::summarize(value = do.call(agg_fun, args = c(agg_args, list(x=value, w=weight)))) %>% + dplyr::mutate(team_abbr = team_abbr, model_abbr = model_abbr, .before = 1) + # do we want to have the horizon column before target? +} From 2a37506fcd613851428b2cad35933dfe4098e00e Mon Sep 17 00:00:00 2001 From: Github Actions CI Date: Fri, 17 Feb 2023 12:30:50 -0500 Subject: [PATCH 02/24] add example-data for function testing --- .../2022-10-01-simple_hub-baseline.csv | 25 ++ .../2022-10-08-simple_hub-baseline.csv | 277 ++++++++++++++++++ .../2022-10-08-team1-goodmodel.csv | 24 ++ 3 files changed, 326 insertions(+) create mode 100644 inst/example-data/2022-10-01-simple_hub-baseline.csv create mode 100644 inst/example-data/2022-10-08-simple_hub-baseline.csv create mode 100644 inst/example-data/2022-10-08-team1-goodmodel.csv diff --git a/inst/example-data/2022-10-01-simple_hub-baseline.csv b/inst/example-data/2022-10-01-simple_hub-baseline.csv new file mode 100644 index 0000000..82c1d62 --- /dev/null +++ b/inst/example-data/2022-10-01-simple_hub-baseline.csv @@ -0,0 +1,25 @@ +"origin_date","target","horizon","location","type","type_id","value" +2022-10-01,"wk inc flu hosp",1,"US","mean",NA,150 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.01",135 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.025",137 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.05",139 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.1",140 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.15",141 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.2",141 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.25",142 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.3",143 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.35",144 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.4",145 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.45",147 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.5",148 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.55",149 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.6",150 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.65",152 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.7",155 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.75",161 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.8",165 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.85",170 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.9",175 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.95",176 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.975",176 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.99",205 diff --git a/inst/example-data/2022-10-08-simple_hub-baseline.csv b/inst/example-data/2022-10-08-simple_hub-baseline.csv new file mode 100644 index 0000000..98ad58c --- /dev/null +++ b/inst/example-data/2022-10-08-simple_hub-baseline.csv @@ -0,0 +1,277 @@ +"origin_date","target","horizon","location","type","type_id","value" +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.99",205 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.99",205 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.99",205 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.99",205 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.99",205 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.99",205 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.99",205 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.99",205 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.99",205 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.99",205 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.99",205 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.99",205 diff --git a/inst/example-data/2022-10-08-team1-goodmodel.csv b/inst/example-data/2022-10-08-team1-goodmodel.csv new file mode 100644 index 0000000..3e6f875 --- /dev/null +++ b/inst/example-data/2022-10-08-team1-goodmodel.csv @@ -0,0 +1,24 @@ +"origin_date","target","horizon","location","type","type_id","value" +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.99",205 From f67619090099ca1c9ca4c69ad497185866a3c9a9 Mon Sep 17 00:00:00 2001 From: Github Actions CI Date: Mon, 24 Apr 2023 15:28:38 -0400 Subject: [PATCH 03/24] fix weighted vs unweighted ensemble code --- R/simple_ensemble.R | 36 +++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/R/simple_ensemble.R b/R/simple_ensemble.R index b8c02e4..c96c660 100644 --- a/R/simple_ensemble.R +++ b/R/simple_ensemble.R @@ -53,26 +53,32 @@ simple_ensemble <- function(predictions, task_id_vars = NULL, hub_con = NULL, predictions <- relocate(predictions, all_of(col_names)) } - if (!all(predictions$output_type %in% c("mean", "median", "quantile", "cdf", "category"))) # throw warning or error + if (!all(predictions$output_type %in% c("mean", "median", "quantile", "cdf", "category"))) stop("Predictions contains unsupported output type") # throw warning or error - if (!all(names(weights) %in% c("team_abbr", "model_abbr", "weight"))) { - stop("weights did not have required columns", call. = FALSE) - } - if (is.null(weights)) { - weights <- predictions %>% - distinct(team_abbr, model_abbr) %>% - mutate(weight = 1/n()) - } - - if (agg_fun == "mean") agg_fun = "weightedMean" - if (agg_fun == "median") agg_fun = "weightedMedian" + if (agg_fun == "mean") agg_fun = "mean" + if (agg_fun == "median") agg_fun = "median" - - predictions %>% + ensemble_predictions <- predictions %>% + dplyr::group_by(across(all_of(c(task_id_vars, "output_type", "output_id")))) %>% + dplyr::summarize(value = do.call(agg_fun, args = c(agg_args, list(x=value)))) %>% + dplyr::mutate(team_abbr = team_abbr, model_abbr = model_abbr, .before = 1) + # do we want to have the horizon column before target? + } else { + if (!all(names(weights) %in% c("team_abbr", "model_abbr", "weight"))) { + stop("weights did not have required columns", call. = FALSE) + } + + if (agg_fun == "mean") agg_fun = "weightedMean" + if (agg_fun == "median") agg_fun = "weightedMedian" + + ensemble_predictions <- predictions %>% dplyr::left_join(weights) %>% dplyr::group_by(across(all_of(c(task_id_vars, "output_type", "output_id")))) %>% - dplyr::summarize(value = do.call(agg_fun, args = c(agg_args, list(x=value, w=weight)))) %>% + dplyr::summarize(value = do.call(agg_fun, args = c(agg_args, list(x=value, w=weights)))) %>% dplyr::mutate(team_abbr = team_abbr, model_abbr = model_abbr, .before = 1) # do we want to have the horizon column before target? + } + + return (ensemble_predictions) } From cf1a4593e1851115828534b0c34ecfb20d1b22db Mon Sep 17 00:00:00 2001 From: Github Actions CI Date: Tue, 25 Apr 2023 18:48:29 -0400 Subject: [PATCH 04/24] minor fixes to simple_ensemble() function --- R/simple_ensemble.R | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/R/simple_ensemble.R b/R/simple_ensemble.R index c96c660..dad2fe4 100644 --- a/R/simple_ensemble.R +++ b/R/simple_ensemble.R @@ -34,7 +34,7 @@ #' each task id variable, `output_type`, `output_id`, and `value`. simple_ensemble <- function(predictions, task_id_vars = NULL, hub_con = NULL, - weights = NULL, agg_fun = mean, agg_args = list(), + weights = NULL, agg_fun = "mean", agg_args = list(), team_abbr = "Hub", model_abbr = "ensemble") { # require(matrixStats) @@ -62,7 +62,8 @@ simple_ensemble <- function(predictions, task_id_vars = NULL, hub_con = NULL, ensemble_predictions <- predictions %>% dplyr::group_by(across(all_of(c(task_id_vars, "output_type", "output_id")))) %>% dplyr::summarize(value = do.call(agg_fun, args = c(agg_args, list(x=value)))) %>% - dplyr::mutate(team_abbr = team_abbr, model_abbr = model_abbr, .before = 1) + dplyr::mutate(team_abbr = team_abbr, model_abbr = model_abbr, .before = 1) %>% + dplyr::ungroup() # do we want to have the horizon column before target? } else { if (!all(names(weights) %in% c("team_abbr", "model_abbr", "weight"))) { @@ -76,7 +77,8 @@ simple_ensemble <- function(predictions, task_id_vars = NULL, hub_con = NULL, dplyr::left_join(weights) %>% dplyr::group_by(across(all_of(c(task_id_vars, "output_type", "output_id")))) %>% dplyr::summarize(value = do.call(agg_fun, args = c(agg_args, list(x=value, w=weights)))) %>% - dplyr::mutate(team_abbr = team_abbr, model_abbr = model_abbr, .before = 1) + dplyr::mutate(team_abbr = team_abbr, model_abbr = model_abbr, .before = 1) %>% + dplyr::ungroup() # do we want to have the horizon column before target? } From 10f1db55c70999d4b4b42d5112100a110cbc2643 Mon Sep 17 00:00:00 2001 From: Github Actions CI Date: Tue, 25 Apr 2023 18:48:57 -0400 Subject: [PATCH 05/24] add unweighted ensemble testing protocol and test data --- .../2022-10-01-simple_hub-baseline.csv | 25 ++ .../2022-10-08-simple_hub-baseline.csv | 277 ++++++++++++++++++ .../test-data/2022-10-08-team1-goodmodel.csv | 24 ++ tests/testthat/test-simple-ensemble.r | 81 +++++ tests/testthat/test-simple_ensemble.R | 91 ++++++ 5 files changed, 498 insertions(+) create mode 100644 tests/testthat/test-data/2022-10-01-simple_hub-baseline.csv create mode 100644 tests/testthat/test-data/2022-10-08-simple_hub-baseline.csv create mode 100644 tests/testthat/test-data/2022-10-08-team1-goodmodel.csv create mode 100644 tests/testthat/test-simple-ensemble.r create mode 100644 tests/testthat/test-simple_ensemble.R diff --git a/tests/testthat/test-data/2022-10-01-simple_hub-baseline.csv b/tests/testthat/test-data/2022-10-01-simple_hub-baseline.csv new file mode 100644 index 0000000..82c1d62 --- /dev/null +++ b/tests/testthat/test-data/2022-10-01-simple_hub-baseline.csv @@ -0,0 +1,25 @@ +"origin_date","target","horizon","location","type","type_id","value" +2022-10-01,"wk inc flu hosp",1,"US","mean",NA,150 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.01",135 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.025",137 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.05",139 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.1",140 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.15",141 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.2",141 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.25",142 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.3",143 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.35",144 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.4",145 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.45",147 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.5",148 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.55",149 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.6",150 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.65",152 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.7",155 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.75",161 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.8",165 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.85",170 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.9",175 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.95",176 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.975",176 +2022-10-01,"wk inc flu hosp",1,"US","quantile","0.99",205 diff --git a/tests/testthat/test-data/2022-10-08-simple_hub-baseline.csv b/tests/testthat/test-data/2022-10-08-simple_hub-baseline.csv new file mode 100644 index 0000000..98ad58c --- /dev/null +++ b/tests/testthat/test-data/2022-10-08-simple_hub-baseline.csv @@ -0,0 +1,277 @@ +"origin_date","target","horizon","location","type","type_id","value" +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.99",205 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",2,"US","quantile","0.99",205 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",3,"US","quantile","0.99",205 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",4,"US","quantile","0.99",205 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",1,"04","quantile","0.99",205 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",2,"04","quantile","0.99",205 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",3,"04","quantile","0.99",205 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",4,"04","quantile","0.99",205 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",1,"01","quantile","0.99",205 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",2,"01","quantile","0.99",205 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",3,"01","quantile","0.99",205 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",4,"01","quantile","0.99",205 diff --git a/tests/testthat/test-data/2022-10-08-team1-goodmodel.csv b/tests/testthat/test-data/2022-10-08-team1-goodmodel.csv new file mode 100644 index 0000000..3e6f875 --- /dev/null +++ b/tests/testthat/test-data/2022-10-08-team1-goodmodel.csv @@ -0,0 +1,24 @@ +"origin_date","target","horizon","location","type","type_id","value" +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.01",135 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.025",137 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.05",139 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.1",140 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.15",141 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.2",141 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.25",142 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.3",143 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.35",144 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.4",145 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.45",147 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.5",148 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.55",149 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.6",150 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.65",152 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.7",155 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.75",161 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.8",165 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.85",170 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.9",175 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.95",176 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.975",176 +2022-10-08,"wk inc flu hosp",1,"US","quantile","0.99",205 diff --git a/tests/testthat/test-simple-ensemble.r b/tests/testthat/test-simple-ensemble.r new file mode 100644 index 0000000..ebdedbb --- /dev/null +++ b/tests/testthat/test-simple-ensemble.r @@ -0,0 +1,81 @@ +context("simple_ensemble") +library(hubEnsembles) +library(matrixStats) + +tmp_dat <- readr::read_csv("test-data/minimal-forecast.csv") + +test_that("invalid method argument throws error", { + expect_error( + build_quantile_ensemble(tmp_dat, + method="weighted mean", + model_name = "example") + ) +}) + +test_that("medians and means correctly calculated", { + fdat <- expand.grid( + stringsAsFactors = FALSE, + model = letters[1:4], + location = c("222", "888"), + horizon = 1, + temporal_resolution = "wk", + target_variable = "inc death", + target_end_date = as.Date("2021-12-25"), + type = "quantile", + quantile = c(.1, .5, .9)) + + fdat$value[fdat$location == "222" & fdat$quantile == .1] <- v2.1 <- c(10, 30, 15, 20) + fdat$value[fdat$location == "222" & fdat$quantile == .5] <- v2.5 <- c(40, 40, 45, 50) + fdat$value[fdat$location == "222" & fdat$quantile == .9] <- v2.9 <- c(60, 70, 75, 80) + fdat$value[fdat$location == "888" & fdat$quantile == .1] <- v8.1 <- c(100, 300, 400, 250) + fdat$value[fdat$location == "888" & fdat$quantile == .5] <- v8.5 <- c(150, 325, 500, 300) + fdat$value[fdat$location == "888" & fdat$quantile == .9] <- v8.9 <- c(250, 350, 500, 350) + + median_vals <- sapply(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), median) + mean_vals <- sapply(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), mean) + + fweight <- tibble(model = letters[1:4], weight = 0.1*(1:4)) + + weighted_median_vals <- map(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), weightedMedian, w = fweight$weight) + weighted_mean_vals <- map(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), weightedMean, w = fweight$weight) + + median_actual <- build_quantile_ensemble( + forecast_data = fdat, weights_df = NULL, method = "median", model_name = "median_ens", forecast_date = "2021-12-20") + mean_actual <- build_quantile_ensemble( + forecast_data = fdat, weights_df = NULL, method = "mean", model_name = "mean_ens", forecast_date = "2021-12-20") + + weighted_median_actual <- build_quantile_ensemble( + forecast_data = fdat, weights_df = fweight, method = "median", model_name = "weighted_median_ens", forecast_date = "2021-12-20") + weighted_mean_actual <- build_quantile_ensemble( + forecast_data = fdat, weights_df = fweight, method = "mean", model_name = "weighted_mean_ens", forecast_date = "2021-12-20") + + + median_expected <- mean_expected <- weighted_median_expected <- weighted_mean_expected <- tibble::tibble( + location = rep(c("222", "888"), each = 3), + horizon = 1, + temporal_resolution = "wk", + target_variable = "inc death", + target_end_date = as.Date("2021-12-25"), + type = "quantile", + quantile = rep(c(.1, .5, .9), 2), + forecast_count = 4, + value = 0, + model = NA, + forecast_date = as.Date("2021-12-20")) + + median_expected$value <- sapply(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), median) + median_expected$model <- "median_ens" + mean_expected$value <- sapply(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), mean) + mean_expected$model <- "mean_ens" + + weighted_mean_expected$value <- map_dbl(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), weightedMean, w = fweight$weight) + weighted_mean_expected$model <- "weighted_mean_ens" + weighted_median_expected$value <- map_dbl(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), weightedMedian, w = fweight$weight) + weighted_median_expected$model <- "weighted_median_ens" + + expect_equal(median_actual, median_expected) + expect_equal(mean_actual, mean_expected) + + expect_equal(weighted_median_actual, weighted_median_expected) + expect_equal(weighted_mean_actual, weighted_mean_expected) +}) diff --git a/tests/testthat/test-simple_ensemble.R b/tests/testthat/test-simple_ensemble.R new file mode 100644 index 0000000..8456c69 --- /dev/null +++ b/tests/testthat/test-simple_ensemble.R @@ -0,0 +1,91 @@ +library(matrixStats) + +test_that("multiplication works", { + expect_equal(2 * 2, 4) +}) + +#tmp_dat <- readr::read_csv("test-data/minimal-forecast.csv") + +baseline_1001 <- readr::read_csv("inst/test-data/2022-10-01-simple_hub-baseline.csv") %>% + mutate(team_abbr = "simple_hub", model_abbr = "baseline", .before = origin_date) +baseline_1008 <- readr::read_csv("inst/test-data/2022-10-08-simple_hub-baseline.csv") %>% + mutate(team_abbr = "simple_hub", model_abbr = "baseline", .before = origin_date) +team1_1008 <- readr::read_csv("inst/test-data/2022-10-08-team1-goodmodel.csv") %>% + mutate(team_abbr = "team_1", model_abbr = "goodmodel1", .before = origin_date) %>% + mutate(value = value + 2) + +test_that("invalid method argument throws error", { + expect_error( + simple_ensemble( + baseline_1008, + agg_fun="linear pool", + model_abbr = "example" + ) + ) +}) + +test_that("medians and means correctly calculated", { + fdat <- expand.grid( + stringsAsFactors = FALSE, + model_abbr = letters[1:4], + location = c("222", "888"), + horizon = 1, #week + target = "inc death", + target_date = as.Date("2021-12-25"), + output_type = "quantile", + output_id = c(.1, .5, .9)) + + fdat <- cbind(team_abbr = letters[23:26], fdat) + fdat$value[fdat$location == "222" & fdat$output_id == .1] <- v2.1 <- c(10, 30, 15, 20) + fdat$value[fdat$location == "222" & fdat$output_id == .5] <- v2.5 <- c(40, 40, 45, 50) + fdat$value[fdat$location == "222" & fdat$output_id == .9] <- v2.9 <- c(60, 70, 75, 80) + fdat$value[fdat$location == "888" & fdat$output_id == .1] <- v8.1 <- c(100, 300, 400, 250) + fdat$value[fdat$location == "888" & fdat$output_id == .5] <- v8.5 <- c(150, 325, 500, 300) + fdat$value[fdat$location == "888" & fdat$output_id == .9] <- v8.9 <- c(250, 350, 500, 350) + + median_vals <- sapply(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), median) + mean_vals <- sapply(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), mean) + + fweight <- tibble(model_abbr = letters[1:4], weight = 0.1*(1:4)) + + weighted_median_vals <- map(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), weightedMedian, w = fweight$weight) + weighted_mean_vals <- map(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), weightedMean, w = fweight$weight) + + median_actual <- simple_ensemble( + predictions = fdat, weights = NULL, agg_fun = "median", model_abbr = "median_ens") + mean_actual <- simple_ensemble( + predictions = fdat, weights = NULL, agg_fun = "mean", model_abbr = "mean_ens") + + weighted_median_actual <- simple_ensemble( + predictions = fdat, weights = fweight, agg_fun = "median", model_abbr = "weighted_median_ens") + weighted_mean_actual <- simple_ensemble( + predictions = fdat, weights = fweight, agg_fun = "mean", model_abbr = "weighted_mean_ens") + + + median_expected <- mean_expected <- weighted_median_expected <- weighted_mean_expected <- tibble::tibble( + team_abbr = "Hub", + model_abbr = NA, + location = rep(c("222", "888"), each = 3), + horizon = 1, #week + target = "inc death", + target_date = as.Date("2021-12-25"), + output_type = "quantile", + output_id = rep(c(.1, .5, .9), 2), + value = 0) + + median_expected$value <- sapply(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), median) + median_expected$model_abbr <- "median_ens" + mean_expected$value <- sapply(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), mean) + mean_expected$model_abbr <- "mean_ens" + + weighted_mean_expected$value <- map_dbl(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), weightedMean, w = fweight$weight) + weighted_mean_expected$model_abbr <- "weighted_mean_ens" + weighted_median_expected$value <- map_dbl(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), weightedMedian, w = fweight$weight) + weighted_median_expected$model_abbr <- "weighted_median_ens" + + expect_equal(median_actual, median_expected) + expect_equal(mean_actual, mean_expected) + + expect_equal(weighted_median_actual, weighted_median_expected) + expect_equal(weighted_mean_actual, weighted_mean_expected) +}) \ No newline at end of file From 4420a9f35567de2f25077977d998618743c8ae55 Mon Sep 17 00:00:00 2001 From: Github Actions CI Date: Fri, 28 Apr 2023 15:24:44 -0400 Subject: [PATCH 06/24] Remove hard coding of output_type, output_id, value columns --- R/simple_ensemble.R | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/R/simple_ensemble.R b/R/simple_ensemble.R index dad2fe4..3fb8ab9 100644 --- a/R/simple_ensemble.R +++ b/R/simple_ensemble.R @@ -29,38 +29,41 @@ #' to `agg_fun`. #' @param team_abbr, model_abbr `character` strings with the name of the team #' and model to use for the ensemble predictions. +#' @param output_type_col, output_id_col, value_col `character` strings with the +#' names of the columns in `predictions` for the output's type, additional +#' identifying information, and value (of the prediction) #' #' @return a data.frame with columns `team_abbr`, `model_abbr`, one column for #' each task id variable, `output_type`, `output_id`, and `value`. simple_ensemble <- function(predictions, task_id_vars = NULL, hub_con = NULL, weights = NULL, agg_fun = "mean", agg_args = list(), - team_abbr = "Hub", model_abbr = "ensemble") { + team_abbr = "Hub", model_abbr = "ensemble", output_type_col = "output_type", output_id_col = "output_id", value_col = "value") { # require(matrixStats) if (is.null(task_id_vars) && is.null(hub_con)) { temp <- colnames(predictions) - task_id_vars <- temp[!temp %in% c("team_abbr", "model_abbr", "output_type", "output_id", "value")] + task_id_vars <- temp[!temp %in% c("team_abbr", "model_abbr", output_type_col, output_id_col, value_col)] } else if (is.null(task_id_vars)) { # task_id variables looked up from `hub_con` } - col_names <- c("team_abbr", "model_abbr", task_id_vars, "output_type", "output_id", "value") + col_names <- c("team_abbr", "model_abbr", task_id_vars, output_type_col, output_id_col, value_col) if ((length(predictions) == 0) || !all(names(predictions) %in% col_names)) { stop("predictions did not have required columns", call. = FALSE) - } else if (!all(names(predictions) == col_names) && names(predictions) %in% col_names) { + } else if (!all(names(predictions) == col_names) && all(names(predictions) %in% col_names)) { predictions <- relocate(predictions, all_of(col_names)) } - if (!all(predictions$output_type %in% c("mean", "median", "quantile", "cdf", "category"))) stop("Predictions contains unsupported output type") # throw warning or error + if (!all(pull(predictions[temp == output_type_col], 1) %in% c("mean", "median", "quantile", "cdf", "category"))) stop("Predictions contains unsupported output type") # throw warning or error if (is.null(weights)) { if (agg_fun == "mean") agg_fun = "mean" if (agg_fun == "median") agg_fun = "median" ensemble_predictions <- predictions %>% - dplyr::group_by(across(all_of(c(task_id_vars, "output_type", "output_id")))) %>% + dplyr::group_by(across(all_of(c(task_id_vars, output_type_col, output_id_col)))) %>% dplyr::summarize(value = do.call(agg_fun, args = c(agg_args, list(x=value)))) %>% dplyr::mutate(team_abbr = team_abbr, model_abbr = model_abbr, .before = 1) %>% dplyr::ungroup() @@ -75,8 +78,8 @@ simple_ensemble <- function(predictions, task_id_vars = NULL, hub_con = NULL, ensemble_predictions <- predictions %>% dplyr::left_join(weights) %>% - dplyr::group_by(across(all_of(c(task_id_vars, "output_type", "output_id")))) %>% - dplyr::summarize(value = do.call(agg_fun, args = c(agg_args, list(x=value, w=weights)))) %>% + dplyr::group_by(across(all_of(c(task_id_vars, output_type_col, output_id_col)))) %>% + dplyr::summarize(value = do.call(agg_fun, args = c(agg_args, list(x=value, w=weight)))) %>% dplyr::mutate(team_abbr = team_abbr, model_abbr = model_abbr, .before = 1) %>% dplyr::ungroup() # do we want to have the horizon column before target? From d6cad091c0bafa984200330d707bc38cf22496ae Mon Sep 17 00:00:00 2001 From: Github Actions CI Date: Fri, 28 Apr 2023 16:02:09 -0400 Subject: [PATCH 07/24] Add additional tests for simple_ensemble() --- tests/testthat/test-simple_ensemble.R | 35 +++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/tests/testthat/test-simple_ensemble.R b/tests/testthat/test-simple_ensemble.R index 8456c69..c7e6099 100644 --- a/tests/testthat/test-simple_ensemble.R +++ b/tests/testthat/test-simple_ensemble.R @@ -14,13 +14,38 @@ team1_1008 <- readr::read_csv("inst/test-data/2022-10-08-team1-goodmodel.csv") % mutate(team_abbr = "team_1", model_abbr = "goodmodel1", .before = origin_date) %>% mutate(value = value + 2) +test_that("non-default column names are preserved in output data frame", { + output_names <- baseline_1008 %>% + simple_ensemble( + agg_fun="mean", + model_abbr = "example", + output_type_col = "type", + output_id_col = "type_id" + ) %>% + names() + expect_equal(names(baseline_1008), output_names) +}) + +test_that("invalid output type throws error", { + expect_error( + baseline_1008 %>% + rename(output_type = type, output_id=type_id) %>% + mutate(output_type="sample") %>% + simple_ensemble( + agg_fun="mean", + model_abbr = "example" + ) + ) +}) + test_that("invalid method argument throws error", { expect_error( - simple_ensemble( - baseline_1008, - agg_fun="linear pool", - model_abbr = "example" - ) + baseline_1008 %>% + rename(output_type = type, output_id=type_id) %>% + simple_ensemble( + agg_fun="linear pool", + model_abbr = "example" + ) ) }) From 3900a91bbcb853a9c7ace0ff1365c29a2f0a7a02 Mon Sep 17 00:00:00 2001 From: Github Actions CI Date: Tue, 2 May 2023 11:02:38 -0400 Subject: [PATCH 08/24] clean up simple_ensemble() testing file --- tests/testthat/test-simple_ensemble.R | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/tests/testthat/test-simple_ensemble.R b/tests/testthat/test-simple_ensemble.R index c7e6099..e8fe157 100644 --- a/tests/testthat/test-simple_ensemble.R +++ b/tests/testthat/test-simple_ensemble.R @@ -1,21 +1,10 @@ library(matrixStats) -test_that("multiplication works", { - expect_equal(2 * 2, 4) -}) - -#tmp_dat <- readr::read_csv("test-data/minimal-forecast.csv") - -baseline_1001 <- readr::read_csv("inst/test-data/2022-10-01-simple_hub-baseline.csv") %>% - mutate(team_abbr = "simple_hub", model_abbr = "baseline", .before = origin_date) -baseline_1008 <- readr::read_csv("inst/test-data/2022-10-08-simple_hub-baseline.csv") %>% +pred <- readr::read_csv("inst/test-data/2022-10-08-simple_hub-baseline.csv") %>% mutate(team_abbr = "simple_hub", model_abbr = "baseline", .before = origin_date) -team1_1008 <- readr::read_csv("inst/test-data/2022-10-08-team1-goodmodel.csv") %>% - mutate(team_abbr = "team_1", model_abbr = "goodmodel1", .before = origin_date) %>% - mutate(value = value + 2) test_that("non-default column names are preserved in output data frame", { - output_names <- baseline_1008 %>% + output_names <- pred %>% simple_ensemble( agg_fun="mean", model_abbr = "example", @@ -23,12 +12,12 @@ test_that("non-default column names are preserved in output data frame", { output_id_col = "type_id" ) %>% names() - expect_equal(names(baseline_1008), output_names) + expect_equal(names(pred), output_names) }) test_that("invalid output type throws error", { expect_error( - baseline_1008 %>% + pred %>% rename(output_type = type, output_id=type_id) %>% mutate(output_type="sample") %>% simple_ensemble( @@ -40,7 +29,7 @@ test_that("invalid output type throws error", { test_that("invalid method argument throws error", { expect_error( - baseline_1008 %>% + pred %>% rename(output_type = type, output_id=type_id) %>% simple_ensemble( agg_fun="linear pool", From 935c4201ce226a2b40686938038d501170446ae0 Mon Sep 17 00:00:00 2001 From: Github Actions CI Date: Tue, 2 May 2023 11:07:02 -0400 Subject: [PATCH 09/24] remove extraneous test data --- .../2022-10-01-simple_hub-baseline.csv | 25 ------------------- .../test-data/2022-10-08-team1-goodmodel.csv | 24 ------------------ 2 files changed, 49 deletions(-) delete mode 100644 tests/testthat/test-data/2022-10-01-simple_hub-baseline.csv delete mode 100644 tests/testthat/test-data/2022-10-08-team1-goodmodel.csv diff --git a/tests/testthat/test-data/2022-10-01-simple_hub-baseline.csv b/tests/testthat/test-data/2022-10-01-simple_hub-baseline.csv deleted file mode 100644 index 82c1d62..0000000 --- a/tests/testthat/test-data/2022-10-01-simple_hub-baseline.csv +++ /dev/null @@ -1,25 +0,0 @@ -"origin_date","target","horizon","location","type","type_id","value" -2022-10-01,"wk inc flu hosp",1,"US","mean",NA,150 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.01",135 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.025",137 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.05",139 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.1",140 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.15",141 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.2",141 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.25",142 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.3",143 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.35",144 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.4",145 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.45",147 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.5",148 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.55",149 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.6",150 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.65",152 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.7",155 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.75",161 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.8",165 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.85",170 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.9",175 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.95",176 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.975",176 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.99",205 diff --git a/tests/testthat/test-data/2022-10-08-team1-goodmodel.csv b/tests/testthat/test-data/2022-10-08-team1-goodmodel.csv deleted file mode 100644 index 3e6f875..0000000 --- a/tests/testthat/test-data/2022-10-08-team1-goodmodel.csv +++ /dev/null @@ -1,24 +0,0 @@ -"origin_date","target","horizon","location","type","type_id","value" -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.99",205 From 15a9299435faf735a8eec92af6168eb8b11a377c Mon Sep 17 00:00:00 2001 From: Github Actions CI Date: Tue, 2 May 2023 11:07:33 -0400 Subject: [PATCH 10/24] remove old testing file --- tests/testthat/test-simple-ensemble.r | 81 --------------------------- 1 file changed, 81 deletions(-) delete mode 100644 tests/testthat/test-simple-ensemble.r diff --git a/tests/testthat/test-simple-ensemble.r b/tests/testthat/test-simple-ensemble.r deleted file mode 100644 index ebdedbb..0000000 --- a/tests/testthat/test-simple-ensemble.r +++ /dev/null @@ -1,81 +0,0 @@ -context("simple_ensemble") -library(hubEnsembles) -library(matrixStats) - -tmp_dat <- readr::read_csv("test-data/minimal-forecast.csv") - -test_that("invalid method argument throws error", { - expect_error( - build_quantile_ensemble(tmp_dat, - method="weighted mean", - model_name = "example") - ) -}) - -test_that("medians and means correctly calculated", { - fdat <- expand.grid( - stringsAsFactors = FALSE, - model = letters[1:4], - location = c("222", "888"), - horizon = 1, - temporal_resolution = "wk", - target_variable = "inc death", - target_end_date = as.Date("2021-12-25"), - type = "quantile", - quantile = c(.1, .5, .9)) - - fdat$value[fdat$location == "222" & fdat$quantile == .1] <- v2.1 <- c(10, 30, 15, 20) - fdat$value[fdat$location == "222" & fdat$quantile == .5] <- v2.5 <- c(40, 40, 45, 50) - fdat$value[fdat$location == "222" & fdat$quantile == .9] <- v2.9 <- c(60, 70, 75, 80) - fdat$value[fdat$location == "888" & fdat$quantile == .1] <- v8.1 <- c(100, 300, 400, 250) - fdat$value[fdat$location == "888" & fdat$quantile == .5] <- v8.5 <- c(150, 325, 500, 300) - fdat$value[fdat$location == "888" & fdat$quantile == .9] <- v8.9 <- c(250, 350, 500, 350) - - median_vals <- sapply(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), median) - mean_vals <- sapply(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), mean) - - fweight <- tibble(model = letters[1:4], weight = 0.1*(1:4)) - - weighted_median_vals <- map(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), weightedMedian, w = fweight$weight) - weighted_mean_vals <- map(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), weightedMean, w = fweight$weight) - - median_actual <- build_quantile_ensemble( - forecast_data = fdat, weights_df = NULL, method = "median", model_name = "median_ens", forecast_date = "2021-12-20") - mean_actual <- build_quantile_ensemble( - forecast_data = fdat, weights_df = NULL, method = "mean", model_name = "mean_ens", forecast_date = "2021-12-20") - - weighted_median_actual <- build_quantile_ensemble( - forecast_data = fdat, weights_df = fweight, method = "median", model_name = "weighted_median_ens", forecast_date = "2021-12-20") - weighted_mean_actual <- build_quantile_ensemble( - forecast_data = fdat, weights_df = fweight, method = "mean", model_name = "weighted_mean_ens", forecast_date = "2021-12-20") - - - median_expected <- mean_expected <- weighted_median_expected <- weighted_mean_expected <- tibble::tibble( - location = rep(c("222", "888"), each = 3), - horizon = 1, - temporal_resolution = "wk", - target_variable = "inc death", - target_end_date = as.Date("2021-12-25"), - type = "quantile", - quantile = rep(c(.1, .5, .9), 2), - forecast_count = 4, - value = 0, - model = NA, - forecast_date = as.Date("2021-12-20")) - - median_expected$value <- sapply(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), median) - median_expected$model <- "median_ens" - mean_expected$value <- sapply(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), mean) - mean_expected$model <- "mean_ens" - - weighted_mean_expected$value <- map_dbl(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), weightedMean, w = fweight$weight) - weighted_mean_expected$model <- "weighted_mean_ens" - weighted_median_expected$value <- map_dbl(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), weightedMedian, w = fweight$weight) - weighted_median_expected$model <- "weighted_median_ens" - - expect_equal(median_actual, median_expected) - expect_equal(mean_actual, mean_expected) - - expect_equal(weighted_median_actual, weighted_median_expected) - expect_equal(weighted_mean_actual, weighted_mean_expected) -}) From 7588b8a53eeed8431aad2174bc8f36948106e3b8 Mon Sep 17 00:00:00 2001 From: Github Actions CI Date: Wed, 24 May 2023 10:09:39 -0400 Subject: [PATCH 11/24] fix file path to test data --- tests/testthat/test-simple_ensemble.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testthat/test-simple_ensemble.R b/tests/testthat/test-simple_ensemble.R index e8fe157..e116a83 100644 --- a/tests/testthat/test-simple_ensemble.R +++ b/tests/testthat/test-simple_ensemble.R @@ -1,6 +1,6 @@ library(matrixStats) -pred <- readr::read_csv("inst/test-data/2022-10-08-simple_hub-baseline.csv") %>% +pred <- readr::read_csv("inst/example-data/2022-10-08-simple_hub-baseline.csv") %>% mutate(team_abbr = "simple_hub", model_abbr = "baseline", .before = origin_date) test_that("non-default column names are preserved in output data frame", { From a88c6939438b18ffa4113fea0a012326fc92b5a4 Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Mon, 5 Jun 2023 17:11:01 -0400 Subject: [PATCH 12/24] progress on refactoring simple_ensemble to use cli for messaging and streamline logic --- R/simple_ensemble.R | 127 +++++++++++++++++++++++++------------------- R/utils-pipe.R | 14 +++++ 2 files changed, 85 insertions(+), 56 deletions(-) create mode 100644 R/utils-pipe.R diff --git a/R/simple_ensemble.R b/R/simple_ensemble.R index 3fb8ab9..18201c4 100644 --- a/R/simple_ensemble.R +++ b/R/simple_ensemble.R @@ -1,18 +1,15 @@ #' Compute ensemble predictions by summarizing component predictions for each -#' combination of model task, output type, and type id. Supported output types -#' include `mean`, `median`, `quantile`, `cdf`, and `category`. +#' combination of model task, output type, and output type id. Supported output +#' types include `mean`, `median`, `quantile`, `cdf`, and `pmf`. #' #' @param predictions a `data.frame` with component model predictions. It should #' have columns `team_abbr`, `model_abbr`, one column for each task id #' variable, `output_type`, `output_type_id`, and `value`. -#' @param task_id_vars an optional `character` vector naming the columns of +#' @param task_id_cols an optional `character` vector naming the columns of #' `predictions` that correspond to task id variables. The default is `NULL`, -#' in which case the task id variables are looked up from the `hub_con` if one -#' is provided. If neither `task_id_vars` nor `hub_con` are provided, all -#' columns in `predictions` _other than_ `team_abbr`, `model_abbr`, -#' `output_type`, `output_id`, and `value` will be used as task id -#' variables. -#' @param hub_con an optional hub connection object; see `hubUtils::connect_hub` +#' in which case all columns in `predictions` _other than_ `team_abbr`, +#' `model_abbr`, and the specified `output_type_col`, `output_id_col`, and +#' `value_col` are used as task id variables. #' @param weights an optional `data.frame` with component model weights. If #' provided, it should have columns `team_name`, `model_abbr`, `weight`, #' and optionally, additional columns corresponding to task id variables, @@ -23,67 +20,85 @@ #' aggregating component predictions into the ensemble prediction. The default #' is `mean`, in which case the ensemble prediction is the simple average of #' the component model predictions. The provided function should have an -#' argument `x` for the vector of numeric values to summarize, and for weighted -#' methods, an argument `w` with a numeric vector of weights. +#' argument `x` for the vector of numeric values to summarize, and for +#' weighted methods, an argument `w` with a numeric vector of weights. #' @param agg_args a named list of any additional arguments that will be passed #' to `agg_fun`. #' @param team_abbr, model_abbr `character` strings with the name of the team #' and model to use for the ensemble predictions. -#' @param output_type_col, output_id_col, value_col `character` strings with the -#' names of the columns in `predictions` for the output's type, additional +#' @param output_type_col, output_id_col, value_col `character` strings with the +#' names of the columns in `predictions` for the output's type, additional #' identifying information, and value (of the prediction) #' #' @return a data.frame with columns `team_abbr`, `model_abbr`, one column for #' each task id variable, `output_type`, `output_id`, and `value`. +simple_ensemble <- function(predictions, task_id_cols = NULL, + weights = NULL, agg_fun = "mean", agg_args = list(), + team_abbr = "Hub", model_abbr = "ensemble", + output_type_col = "output_type", + output_id_col = "output_id", + value_col = "value") { + if (!is.data.frame(predictions)) { + cli::cli_abort(c("x" = "{.arg predictions} must be a `data.frame`.")) + } -simple_ensemble <- function(predictions, task_id_vars = NULL, hub_con = NULL, - weights = NULL, agg_fun = "mean", agg_args = list(), - team_abbr = "Hub", model_abbr = "ensemble", output_type_col = "output_type", output_id_col = "output_id", value_col = "value") { + if (is.null(task_id_cols)) { + cols <- colnames(predictions) + non_task_cols <- c("team_abbr", "model_abbr", output_type_col, + output_id_col, value_col) + task_id_cols <- cols[!cols %in% non_task_cols] + } - # require(matrixStats) - - if (is.null(task_id_vars) && is.null(hub_con)) { - temp <- colnames(predictions) - task_id_vars <- temp[!temp %in% c("team_abbr", "model_abbr", output_type_col, output_id_col, value_col)] - } else if (is.null(task_id_vars)) { - # task_id variables looked up from `hub_con` + col_names <- c("team_abbr", "model_abbr", task_id_cols, output_type_col, + output_id_col, value_col) + if (!all(names(predictions) %in% col_names)) { + cli::cli_abort(c( + "x" = "{.arg predictions} did not have all required columns + {.val {col_names}}." + )) } - - col_names <- c("team_abbr", "model_abbr", task_id_vars, output_type_col, output_id_col, value_col) - if ((length(predictions) == 0) || !all(names(predictions) %in% col_names)) { - stop("predictions did not have required columns", call. = FALSE) - } else if (!all(names(predictions) == col_names) && all(names(predictions) %in% col_names)) { - predictions <- relocate(predictions, all_of(col_names)) + + valid_types <- c("mean", "median", "quantile", "cdf", "pmf") + unique_types <- unique(predictions[[output_type_col]]) + invalid_types <- unique_types[!unique_types %in% valid_types] + if (length(invalid_types) > 0) { + cli::cli_abort(c( + "x" = "{.arg predictions} contains unsupported output type.", + "i" = "Included type{?s}: {.val {invalid_types}}.", + "i" = "Supported output types: {.val {valid_types}}." + )) } - - if (!all(pull(predictions[temp == output_type_col], 1) %in% c("mean", "median", "quantile", "cdf", "category"))) stop("Predictions contains unsupported output type") # throw warning or error if (is.null(weights)) { - if (agg_fun == "mean") agg_fun = "mean" - if (agg_fun == "median") agg_fun = "median" - - ensemble_predictions <- predictions %>% - dplyr::group_by(across(all_of(c(task_id_vars, output_type_col, output_id_col)))) %>% - dplyr::summarize(value = do.call(agg_fun, args = c(agg_args, list(x=value)))) %>% - dplyr::mutate(team_abbr = team_abbr, model_abbr = model_abbr, .before = 1) %>% - dplyr::ungroup() - # do we want to have the horizon column before target? + if (agg_fun == "mean") agg_fun <- mean + if (agg_fun == "median") agg_fun <- median + + agg_args <- c(agg_args, list(x = "value")) } else { - if (!all(names(weights) %in% c("team_abbr", "model_abbr", "weight"))) { - stop("weights did not have required columns", call. = FALSE) + req_weight_cols <- c("model_abbr", "team_abbr", "weight") + if (!isTRUE(all.equal(sort(names(weights)), req_weight_cols))) { + cli::cli_abort(c( + "x" = "{.arg weights} did not have required columns + {.val {req_weight_cols}}." + )) } - - if (agg_fun == "mean") agg_fun = "weightedMean" - if (agg_fun == "median") agg_fun = "weightedMedian" - - ensemble_predictions <- predictions %>% - dplyr::left_join(weights) %>% - dplyr::group_by(across(all_of(c(task_id_vars, output_type_col, output_id_col)))) %>% - dplyr::summarize(value = do.call(agg_fun, args = c(agg_args, list(x=value, w=weight)))) %>% - dplyr::mutate(team_abbr = team_abbr, model_abbr = model_abbr, .before = 1) %>% - dplyr::ungroup() - # do we want to have the horizon column before target? - } - - return (ensemble_predictions) + + predictions <- predictions %>% + dplyr::left_join(weights, by = c("team_abbr", "model_abbr")) + + if (agg_fun == "mean") agg_fun <- matrixStats::weightedMean + if (agg_fun == "median") agg_fun <- matrixStats::weightedMedian + + agg_args <- c(agg_args, list(x = "value", w = "weight")) + } + + ensemble_predictions <- predictions %>% + dplyr::group_by(dplyr::across(dplyr::all_of(c(task_id_cols, output_type_col, + output_id_col)))) %>% + dplyr::summarize(value = do.call(agg_fun, args = agg_args)) %>% + dplyr::mutate(team_abbr = team_abbr, model_abbr = model_abbr, + .before = 1) %>% + dplyr::ungroup() + + return(ensemble_predictions) } diff --git a/R/utils-pipe.R b/R/utils-pipe.R new file mode 100644 index 0000000..fd0b1d1 --- /dev/null +++ b/R/utils-pipe.R @@ -0,0 +1,14 @@ +#' Pipe operator +#' +#' See \code{magrittr::\link[magrittr:pipe]{\%>\%}} for details. +#' +#' @name %>% +#' @rdname pipe +#' @keywords internal +#' @export +#' @importFrom magrittr %>% +#' @usage lhs \%>\% rhs +#' @param lhs A value or the magrittr placeholder. +#' @param rhs A function call using the magrittr semantics. +#' @return The result of calling `rhs(lhs)`. +NULL From 39f601ed8ec044adb55b8a81716ab2d1e7d5e3f4 Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Mon, 5 Jun 2023 22:58:00 -0400 Subject: [PATCH 13/24] warn if 0 rows --- R/simple_ensemble.R | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/R/simple_ensemble.R b/R/simple_ensemble.R index 18201c4..9dcd3b4 100644 --- a/R/simple_ensemble.R +++ b/R/simple_ensemble.R @@ -42,6 +42,10 @@ simple_ensemble <- function(predictions, task_id_cols = NULL, cli::cli_abort(c("x" = "{.arg predictions} must be a `data.frame`.")) } + if (nrow(predictions) == 0) { + cli::cli_warn(c("!" = "{.arg predictions} has zero rows.")) + } + if (is.null(task_id_cols)) { cols <- colnames(predictions) non_task_cols <- c("team_abbr", "model_abbr", output_type_col, From 0794377e29ae50462c3b3fde27d754492f24d0c7 Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Tue, 6 Jun 2023 12:07:05 -0400 Subject: [PATCH 14/24] add package imports, validations for weights --- DESCRIPTION | 5 ++++ R/simple_ensemble.R | 62 ++++++++++++++++++++++++++++++++------------- 2 files changed, 50 insertions(+), 17 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 90bc71c..4dd6a11 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -14,3 +14,8 @@ Roxygen: list(markdown = TRUE) RoxygenNote: 7.2.2 URL: https://github.com/Infectious-Disease-Modeling-Hubs/hubEnsembles BugReports: https://github.com/Infectious-Disease-Modeling-Hubs/hubEnsembles/issues +Imports: + cli, + dplyr, + matrixStats, + rlang diff --git a/R/simple_ensemble.R b/R/simple_ensemble.R index 9dcd3b4..45afcf8 100644 --- a/R/simple_ensemble.R +++ b/R/simple_ensemble.R @@ -32,35 +32,38 @@ #' #' @return a data.frame with columns `team_abbr`, `model_abbr`, one column for #' each task id variable, `output_type`, `output_id`, and `value`. -simple_ensemble <- function(predictions, task_id_cols = NULL, - weights = NULL, agg_fun = "mean", agg_args = list(), - team_abbr = "Hub", model_abbr = "ensemble", +simple_ensemble <- function(predictions, weights = NULL, + agg_fun = "mean", agg_args = list(), + team_abbr = "hub", model_abbr = "ensemble", + task_id_cols = NULL, output_type_col = "output_type", - output_id_col = "output_id", - value_col = "value") { + output_id_col = "output_type_id", + hub_connection = NULL) { if (!is.data.frame(predictions)) { cli::cli_abort(c("x" = "{.arg predictions} must be a `data.frame`.")) } - if (nrow(predictions) == 0) { - cli::cli_warn(c("!" = "{.arg predictions} has zero rows.")) - } - if (is.null(task_id_cols)) { cols <- colnames(predictions) non_task_cols <- c("team_abbr", "model_abbr", output_type_col, - output_id_col, value_col) + output_id_col, "value") task_id_cols <- cols[!cols %in% non_task_cols] } col_names <- c("team_abbr", "model_abbr", task_id_cols, output_type_col, - output_id_col, value_col) - if (!all(names(predictions) %in% col_names)) { + output_id_col, "value") + if (!all(colnames(predictions) %in% col_names)) { cli::cli_abort(c( "x" = "{.arg predictions} did not have all required columns {.val {col_names}}." )) } + + ## Validations above this point to be relocated to hubUtils + + if (nrow(predictions) == 0) { + cli::cli_warn(c("!" = "{.arg predictions} has zero rows.")) + } valid_types <- c("mean", "median", "quantile", "cdf", "pmf") unique_types <- unique(predictions[[output_type_col]]) @@ -68,7 +71,7 @@ simple_ensemble <- function(predictions, task_id_cols = NULL, if (length(invalid_types) > 0) { cli::cli_abort(c( "x" = "{.arg predictions} contains unsupported output type.", - "i" = "Included type{?s}: {.val {invalid_types}}.", + "i" = "Included output type{?s}: {.val {invalid_types}}.", "i" = "Supported output types: {.val {valid_types}}." )) } @@ -80,20 +83,45 @@ simple_ensemble <- function(predictions, task_id_cols = NULL, agg_args <- c(agg_args, list(x = "value")) } else { req_weight_cols <- c("model_abbr", "team_abbr", "weight") - if (!isTRUE(all.equal(sort(names(weights)), req_weight_cols))) { + if (!all(req_weight_cols %in% colnames(weights))) { cli::cli_abort(c( - "x" = "{.arg weights} did not have required columns + "x" = "{.arg weights} did not include required columns {.val {req_weight_cols}}." )) } + weight_by_cols <- colnames(weights)[colnames(weights) != "weight"] + + if ("value" %in% weight_by_cols) { + cli::cli_abort(c( + "x" = "{.arg weights} included a column named {.val {\"value\"}}, + which is not allowed." + )) + } + + invalid_cols <- weight_by_cols[!weight_by_cols %in% colnames(predictions)] + if (!all(weight_by_cols %in% colnames(predictions))) { + cli::cli_abort(c( + "x" = "{.arg weights} included {length(invalid_cols)} column{?s} that + {?was/were} not present in {.arg predictions}: + {.val {invalid_cols}}" + )) + } + + if ("weight" %in% colnames(predictions)) { + weight_col_name <- paste0("weight_", rlang::hash(colnames(predictions))) + weights <- weights %>% dplyr::rename(!!weight_col_name := "weight") + } else { + weight_col_name <- "weight" + } + predictions <- predictions %>% - dplyr::left_join(weights, by = c("team_abbr", "model_abbr")) + dplyr::left_join(weights, by = weight_by_cols) if (agg_fun == "mean") agg_fun <- matrixStats::weightedMean if (agg_fun == "median") agg_fun <- matrixStats::weightedMedian - agg_args <- c(agg_args, list(x = "value", w = "weight")) + agg_args <- c(agg_args, list(x = "value", w = weight_col_name)) } ensemble_predictions <- predictions %>% From 783e936e6d36b0d119d94a15e780b4ad19363841 Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Wed, 7 Jun 2023 14:48:01 -0400 Subject: [PATCH 15/24] misc updates --- DESCRIPTION | 2 +- NAMESPACE | 2 + R/simple_ensemble.R | 136 ++++----- .../2022-10-01-simple_hub-baseline.csv | 2 +- .../2022-10-08-simple_hub-baseline.csv | 2 +- .../2022-10-08-team1-goodmodel.csv | 2 +- man/pipe.Rd | 20 ++ man/simple_ensemble.Rd | 65 ++++ .../2022-10-08-simple_hub-baseline.csv | 277 ------------------ tests/testthat/test-simple_ensemble.R | 153 +++++----- 10 files changed, 245 insertions(+), 416 deletions(-) create mode 100644 man/pipe.Rd create mode 100644 man/simple_ensemble.Rd delete mode 100644 tests/testthat/test-data/2022-10-08-simple_hub-baseline.csv diff --git a/DESCRIPTION b/DESCRIPTION index 4dd6a11..1abb6bd 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -11,7 +11,7 @@ Suggests: Config/testthat/edition: 3 Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.2 +RoxygenNote: 7.2.3 URL: https://github.com/Infectious-Disease-Modeling-Hubs/hubEnsembles BugReports: https://github.com/Infectious-Disease-Modeling-Hubs/hubEnsembles/issues Imports: diff --git a/NAMESPACE b/NAMESPACE index 6ae9268..ea39623 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,2 +1,4 @@ # Generated by roxygen2: do not edit by hand +export("%>%") +importFrom(magrittr,"%>%") diff --git a/R/simple_ensemble.R b/R/simple_ensemble.R index 45afcf8..a2763c0 100644 --- a/R/simple_ensemble.R +++ b/R/simple_ensemble.R @@ -1,88 +1,85 @@ -#' Compute ensemble predictions by summarizing component predictions for each -#' combination of model task, output type, and output type id. Supported output -#' types include `mean`, `median`, `quantile`, `cdf`, and `pmf`. +#' Compute ensemble model outputs by summarizing component model outputs for +#' each combination of model task, output type, and output type id. Supported +#' output types include `mean`, `median`, `quantile`, `cdf`, and `pmf`. #' -#' @param predictions a `data.frame` with component model predictions. It should -#' have columns `team_abbr`, `model_abbr`, one column for each task id -#' variable, `output_type`, `output_type_id`, and `value`. -#' @param task_id_cols an optional `character` vector naming the columns of -#' `predictions` that correspond to task id variables. The default is `NULL`, -#' in which case all columns in `predictions` _other than_ `team_abbr`, -#' `model_abbr`, and the specified `output_type_col`, `output_id_col`, and -#' `value_col` are used as task id variables. +#' @param model_outputs an object of class `model_output_df` with component +#' model outputs (e.g., predictions). #' @param weights an optional `data.frame` with component model weights. If -#' provided, it should have columns `team_name`, `model_abbr`, `weight`, -#' and optionally, additional columns corresponding to task id variables, -#' `output_type`, or `output_id`, if weights are specific to values of -#' those variables. The default is `NULL`, in which case an equally-weighted -#' ensemble is calculated. +#' provided, it should have columns `model_id`, `weight`, and optionally, +#' additional columns corresponding to task id variables, `output_type`, or +#' `output_type_id`, if weights are specific to values of those variables. The +#' default is `NULL`, in which case an equally-weighted ensemble is calculated. #' @param agg_fun a function or character string name of a function to use for -#' aggregating component predictions into the ensemble prediction. The default -#' is `mean`, in which case the ensemble prediction is the simple average of -#' the component model predictions. The provided function should have an -#' argument `x` for the vector of numeric values to summarize, and for -#' weighted methods, an argument `w` with a numeric vector of weights. +#' aggregating component model outputs into the ensemble outputs. See the +#' details for more information. #' @param agg_args a named list of any additional arguments that will be passed #' to `agg_fun`. -#' @param team_abbr, model_abbr `character` strings with the name of the team -#' and model to use for the ensemble predictions. -#' @param output_type_col, output_id_col, value_col `character` strings with the -#' names of the columns in `predictions` for the output's type, additional -#' identifying information, and value (of the prediction) +#' @param model_id `character` string with the identifier to use for the +#' ensemble model. +#' @param task_id_cols, output_type_col, output_type_id_col, value_col +#' `character` vectors with the names of the columns in `model_outputs` for +#' the output's type, additional identifying information, and value of the +#' model output. +#' +#' @details The default for `agg_fun` is `mean`, in which case the ensemble's +#' output is the average of the component model outputs within each group +#' defined by a combination of values in the task id columns, output type, and +#' output type id. The provided `agg_fun` should have an argument `x` for the +#' vector of numeric values to summarize, and for weighted methods, an +#' argument `w` with a numeric vector of weights. For weighted methods, +#' `agg_fun = "mean"` and `agg_fun = "median"` are translated to use +#' `matrixStats::weightedMean` and `matrixStats::weightedMedian` respectively. #' #' @return a data.frame with columns `team_abbr`, `model_abbr`, one column for -#' each task id variable, `output_type`, `output_id`, and `value`. -simple_ensemble <- function(predictions, weights = NULL, +#' each task id variable, `output_type`, `output_id`, and `value`. Note that +#' any additional columns in the input `model_outputs` are dropped. +simple_ensemble <- function(model_outputs, weights = NULL, agg_fun = "mean", agg_args = list(), - team_abbr = "hub", model_abbr = "ensemble", + model_id = "hub-ensemble", task_id_cols = NULL, output_type_col = "output_type", - output_id_col = "output_type_id", + output_type_id_col = "output_type_id", hub_connection = NULL) { - if (!is.data.frame(predictions)) { - cli::cli_abort(c("x" = "{.arg predictions} must be a `data.frame`.")) + model_out_cols <- colnames(model_outputs) + if (!is.data.frame(model_outputs)) { + cli::cli_abort(c("x" = "{.arg model_outputs} must be a `data.frame`.")) } + non_task_cols <- c("model_id", output_type_col, output_type_id_col, "value") if (is.null(task_id_cols)) { - cols <- colnames(predictions) - non_task_cols <- c("team_abbr", "model_abbr", output_type_col, - output_id_col, "value") - task_id_cols <- cols[!cols %in% non_task_cols] + task_id_cols <- model_out_cols[!model_out_cols %in% non_task_cols] } - col_names <- c("team_abbr", "model_abbr", task_id_cols, output_type_col, - output_id_col, "value") - if (!all(colnames(predictions) %in% col_names)) { + req_col_names <- c(non_task_cols, task_id_cols) + if (!all(req_col_names %in% model_out_cols)) { cli::cli_abort(c( - "x" = "{.arg predictions} did not have all required columns - {.val {col_names}}." + "x" = "{.arg model_outputs} did not have all required columns + {.val {req_col_names}}." )) } - + ## Validations above this point to be relocated to hubUtils + # hubUtils::validate_model_output_df(model_outputs) - if (nrow(predictions) == 0) { - cli::cli_warn(c("!" = "{.arg predictions} has zero rows.")) + if (nrow(model_outputs) == 0) { + cli::cli_warn(c("!" = "{.arg model_outputs} has zero rows.")) } valid_types <- c("mean", "median", "quantile", "cdf", "pmf") - unique_types <- unique(predictions[[output_type_col]]) + unique_types <- unique(model_outputs[[output_type_col]]) invalid_types <- unique_types[!unique_types %in% valid_types] if (length(invalid_types) > 0) { cli::cli_abort(c( - "x" = "{.arg predictions} contains unsupported output type.", + "x" = "{.arg model_outputs} contains unsupported output type.", "i" = "Included output type{?s}: {.val {invalid_types}}.", "i" = "Supported output types: {.val {valid_types}}." )) } if (is.null(weights)) { - if (agg_fun == "mean") agg_fun <- mean - if (agg_fun == "median") agg_fun <- median - - agg_args <- c(agg_args, list(x = "value")) + agg_args <- c(agg_args, list(x = quote(.data[["value"]]))) } else { - req_weight_cols <- c("model_abbr", "team_abbr", "weight") + req_weight_cols <- c("model_id", "weight") if (!all(req_weight_cols %in% colnames(weights))) { cli::cli_abort(c( "x" = "{.arg weights} did not include required columns @@ -99,38 +96,45 @@ simple_ensemble <- function(predictions, weights = NULL, )) } - invalid_cols <- weight_by_cols[!weight_by_cols %in% colnames(predictions)] - if (!all(weight_by_cols %in% colnames(predictions))) { + invalid_cols <- weight_by_cols[!weight_by_cols %in% colnames(model_outputs)] + if (!all(weight_by_cols %in% colnames(model_outputs))) { cli::cli_abort(c( "x" = "{.arg weights} included {length(invalid_cols)} column{?s} that - {?was/were} not present in {.arg predictions}: + {?was/were} not present in {.arg model_outputs}: {.val {invalid_cols}}" )) } - if ("weight" %in% colnames(predictions)) { - weight_col_name <- paste0("weight_", rlang::hash(colnames(predictions))) + if ("weight" %in% colnames(model_outputs)) { + weight_col_name <- paste0("weight_", rlang::hash(colnames(model_outputs))) weights <- weights %>% dplyr::rename(!!weight_col_name := "weight") } else { weight_col_name <- "weight" } - predictions <- predictions %>% + model_outputs <- model_outputs %>% dplyr::left_join(weights, by = weight_by_cols) - if (agg_fun == "mean") agg_fun <- matrixStats::weightedMean - if (agg_fun == "median") agg_fun <- matrixStats::weightedMedian + if (is.character(agg_fun)) { + if (agg_fun == "mean") { + agg_fun <- matrixStats::weightedMean + } else if (agg_fun == "median") { + agg_fun <- matrixStats::weightedMedian + } + } - agg_args <- c(agg_args, list(x = "value", w = weight_col_name)) + agg_args <- c(agg_args, list(x = quote(.data[["value"]]), + w = quote(.data[[weight_col_name]]))) } - ensemble_predictions <- predictions %>% - dplyr::group_by(dplyr::across(dplyr::all_of(c(task_id_cols, output_type_col, - output_id_col)))) %>% + group_by_cols <- c(task_id_cols, output_type_col, output_type_id_col) + ensemble_model_outputs <- model_outputs %>% + dplyr::group_by(dplyr::across(dplyr::all_of(group_by_cols))) %>% dplyr::summarize(value = do.call(agg_fun, args = agg_args)) %>% - dplyr::mutate(team_abbr = team_abbr, model_abbr = model_abbr, - .before = 1) %>% + dplyr::mutate(model_id = model_id, .before = 1) %>% dplyr::ungroup() - return(ensemble_predictions) + # hubUtils::as_model_output_df(ensemble_model_outputs) + + return(ensemble_model_outputs) } diff --git a/inst/example-data/2022-10-01-simple_hub-baseline.csv b/inst/example-data/2022-10-01-simple_hub-baseline.csv index 82c1d62..703831b 100644 --- a/inst/example-data/2022-10-01-simple_hub-baseline.csv +++ b/inst/example-data/2022-10-01-simple_hub-baseline.csv @@ -1,4 +1,4 @@ -"origin_date","target","horizon","location","type","type_id","value" +"origin_date","target","horizon","location","output_type","output_type_id","value" 2022-10-01,"wk inc flu hosp",1,"US","mean",NA,150 2022-10-01,"wk inc flu hosp",1,"US","quantile","0.01",135 2022-10-01,"wk inc flu hosp",1,"US","quantile","0.025",137 diff --git a/inst/example-data/2022-10-08-simple_hub-baseline.csv b/inst/example-data/2022-10-08-simple_hub-baseline.csv index 98ad58c..7387c87 100644 --- a/inst/example-data/2022-10-08-simple_hub-baseline.csv +++ b/inst/example-data/2022-10-08-simple_hub-baseline.csv @@ -1,4 +1,4 @@ -"origin_date","target","horizon","location","type","type_id","value" +"origin_date","target","horizon","location","output_type","output_type_id","value" 2022-10-08,"wk inc flu hosp",1,"US","quantile","0.01",135 2022-10-08,"wk inc flu hosp",1,"US","quantile","0.025",137 2022-10-08,"wk inc flu hosp",1,"US","quantile","0.05",139 diff --git a/inst/example-data/2022-10-08-team1-goodmodel.csv b/inst/example-data/2022-10-08-team1-goodmodel.csv index 3e6f875..71108b3 100644 --- a/inst/example-data/2022-10-08-team1-goodmodel.csv +++ b/inst/example-data/2022-10-08-team1-goodmodel.csv @@ -1,4 +1,4 @@ -"origin_date","target","horizon","location","type","type_id","value" +"origin_date","target","horizon","location","output_type","output_type_id","value" 2022-10-08,"wk inc flu hosp",1,"US","quantile","0.01",135 2022-10-08,"wk inc flu hosp",1,"US","quantile","0.025",137 2022-10-08,"wk inc flu hosp",1,"US","quantile","0.05",139 diff --git a/man/pipe.Rd b/man/pipe.Rd new file mode 100644 index 0000000..a648c29 --- /dev/null +++ b/man/pipe.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils-pipe.R +\name{\%>\%} +\alias{\%>\%} +\title{Pipe operator} +\usage{ +lhs \%>\% rhs +} +\arguments{ +\item{lhs}{A value or the magrittr placeholder.} + +\item{rhs}{A function call using the magrittr semantics.} +} +\value{ +The result of calling \code{rhs(lhs)}. +} +\description{ +See \code{magrittr::\link[magrittr:pipe]{\%>\%}} for details. +} +\keyword{internal} diff --git a/man/simple_ensemble.Rd b/man/simple_ensemble.Rd new file mode 100644 index 0000000..b231d55 --- /dev/null +++ b/man/simple_ensemble.Rd @@ -0,0 +1,65 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/simple_ensemble.R +\name{simple_ensemble} +\alias{simple_ensemble} +\title{Compute ensemble model outputs by summarizing component model outputs for +each combination of model task, output type, and output type id. Supported +output types include \code{mean}, \code{median}, \code{quantile}, \code{cdf}, and \code{pmf}.} +\usage{ +simple_ensemble( + model_outputs, + weights = NULL, + agg_fun = "mean", + agg_args = list(), + model_id = "hub-ensemble", + task_id_cols = NULL, + output_type_col = "output_type", + output_type_id_col = "output_type_id", + hub_connection = NULL +) +} +\arguments{ +\item{model_outputs}{an object of class \code{model_output_df} with component +model outputs (e.g., predictions).} + +\item{weights}{an optional \code{data.frame} with component model weights. If +provided, it should have columns \code{model_id}, \code{weight}, and optionally, +additional columns corresponding to task id variables, \code{output_type}, or +\code{output_type_id}, if weights are specific to values of those variables. The +default is \code{NULL}, in which case an equally-weighted ensemble is calculated.} + +\item{agg_fun}{a function or character string name of a function to use for +aggregating component model outputs into the ensemble outputs. See the +details for more information.} + +\item{agg_args}{a named list of any additional arguments that will be passed +to \code{agg_fun}.} + +\item{model_id}{\code{character} string with the identifier to use for the +ensemble model.} + +\item{task_id_cols, }{output_type_col, output_type_id_col, value_col +\code{character} vectors with the names of the columns in \code{model_outputs} for +the output's type, additional identifying information, and value of the +model output.} +} +\value{ +a data.frame with columns \code{team_abbr}, \code{model_abbr}, one column for +each task id variable, \code{output_type}, \code{output_id}, and \code{value}. Note that +any additional columns in the input \code{model_outputs} are dropped. +} +\description{ +Compute ensemble model outputs by summarizing component model outputs for +each combination of model task, output type, and output type id. Supported +output types include \code{mean}, \code{median}, \code{quantile}, \code{cdf}, and \code{pmf}. +} +\details{ +The default for \code{agg_fun} is \code{mean}, in which case the ensemble's +output is the average of the component model outputs within each group +defined by a combination of values in the task id columns, output type, and +output type id. The provided \code{agg_fun} should have an argument \code{x} for the +vector of numeric values to summarize, and for weighted methods, an +argument \code{w} with a numeric vector of weights. For weighted methods, +\code{agg_fun = "mean"} and \code{agg_fun = "median"} are translated to use +\code{matrixStats::weightedMean} and \code{matrixStats::weightedMedian} respectively. +} diff --git a/tests/testthat/test-data/2022-10-08-simple_hub-baseline.csv b/tests/testthat/test-data/2022-10-08-simple_hub-baseline.csv deleted file mode 100644 index 98ad58c..0000000 --- a/tests/testthat/test-data/2022-10-08-simple_hub-baseline.csv +++ /dev/null @@ -1,277 +0,0 @@ -"origin_date","target","horizon","location","type","type_id","value" -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.99",205 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.99",205 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.99",205 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.99",205 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.99",205 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.99",205 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.99",205 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.99",205 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.99",205 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.99",205 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.99",205 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.99",205 diff --git a/tests/testthat/test-simple_ensemble.R b/tests/testthat/test-simple_ensemble.R index e116a83..a9f6907 100644 --- a/tests/testthat/test-simple_ensemble.R +++ b/tests/testthat/test-simple_ensemble.R @@ -1,105 +1,120 @@ library(matrixStats) +library(dplyr) -pred <- readr::read_csv("inst/example-data/2022-10-08-simple_hub-baseline.csv") %>% - mutate(team_abbr = "simple_hub", model_abbr = "baseline", .before = origin_date) +pred <- read.csv(system.file("example-data/2022-10-08-simple_hub-baseline.csv", + package = "hubEnsembles")) %>% + dplyr::mutate(model_id = "simple_hub-baseline", .before = origin_date) -test_that("non-default column names are preserved in output data frame", { +test_that("non-default columns are dropped from output", { output_names <- pred %>% + dplyr::mutate(extra_col_1 = "a", extra_col_2 = "a") %>% simple_ensemble( - agg_fun="mean", - model_abbr = "example", - output_type_col = "type", - output_id_col = "type_id" + task_id_cols = c("origin_date", "target", "horizon", "location") ) %>% names() - expect_equal(names(pred), output_names) + + expect_equal(sort(names(pred)), sort(output_names)) }) + test_that("invalid output type throws error", { expect_error( pred %>% - rename(output_type = type, output_id=type_id) %>% - mutate(output_type="sample") %>% - simple_ensemble( - agg_fun="mean", - model_abbr = "example" - ) + dplyr::mutate(output_type = "sample") %>% + simple_ensemble() ) }) + test_that("invalid method argument throws error", { expect_error( - pred %>% - rename(output_type = type, output_id=type_id) %>% - simple_ensemble( - agg_fun="linear pool", - model_abbr = "example" - ) + simple_ensemble(pred, agg_fun="linear pool") ) }) -test_that("medians and means correctly calculated", { + +test_that("(weighted) medians and means correctly calculated", { fdat <- expand.grid( stringsAsFactors = FALSE, - model_abbr = letters[1:4], + model_id = letters[1:4], location = c("222", "888"), horizon = 1, #week target = "inc death", target_date = as.Date("2021-12-25"), output_type = "quantile", - output_id = c(.1, .5, .9)) + output_type_id = c(.1, .5, .9), + value = NA_real_) - fdat <- cbind(team_abbr = letters[23:26], fdat) - fdat$value[fdat$location == "222" & fdat$output_id == .1] <- v2.1 <- c(10, 30, 15, 20) - fdat$value[fdat$location == "222" & fdat$output_id == .5] <- v2.5 <- c(40, 40, 45, 50) - fdat$value[fdat$location == "222" & fdat$output_id == .9] <- v2.9 <- c(60, 70, 75, 80) - fdat$value[fdat$location == "888" & fdat$output_id == .1] <- v8.1 <- c(100, 300, 400, 250) - fdat$value[fdat$location == "888" & fdat$output_id == .5] <- v8.5 <- c(150, 325, 500, 300) - fdat$value[fdat$location == "888" & fdat$output_id == .9] <- v8.9 <- c(250, 350, 500, 350) + fdat$value[fdat$location == "222" & fdat$output_type_id == .1] <- v2.1 <- + c(10, 30, 15, 20) + fdat$value[fdat$location == "222" & fdat$output_type_id == .5] <- v2.5 <- + c(40, 40, 45, 50) + fdat$value[fdat$location == "222" & fdat$output_type_id == .9] <- v2.9 <- + c(60, 70, 75, 80) + fdat$value[fdat$location == "888" & fdat$output_type_id == .1] <- v8.1 <- + c(100, 300, 400, 250) + fdat$value[fdat$location == "888" & fdat$output_type_id == .5] <- v8.5 <- + c(150, 325, 500, 300) + fdat$value[fdat$location == "888" & fdat$output_type_id == .9] <- v8.9 <- + c(250, 350, 500, 350) + + fweight2 <- data.frame(model_id = letters[1:4], + location = "222", + weight = 0.1 * (1:4)) + fweight8 <- data.frame(model_id = letters[1:4], + location = "888", + weight = 0.1 * (4:1)) + fweight <- bind_rows(fweight2, fweight8) + + median_expected <- mean_expected <- + weighted_median_expected <- weighted_mean_expected <- data.frame( + model_id = "hub-ensemble", + location = rep(c("222", "888"), each = 3), + horizon = 1, + target = "inc death", + target_date = as.Date("2021-12-25"), + output_type = "quantile", + output_type_id = rep(c(.1, .5, .9), 2), + value = NA_real_) median_vals <- sapply(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), median) mean_vals <- sapply(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), mean) - fweight <- tibble(model_abbr = letters[1:4], weight = 0.1*(1:4)) + weighted_median_vals <- c( + sapply(list(v2.1, v2.5, v2.9), + matrixStats::weightedMedian, + w = fweight2$weight), + sapply(list(v8.1, v8.5, v8.9), + matrixStats::weightedMedian, + w = fweight8$weight)) + weighted_mean_vals <- c( + sapply(list(v2.1, v2.5, v2.9), + matrixStats::weightedMean, + w = fweight2$weight), + sapply(list(v8.1, v8.5, v8.9), + matrixStats::weightedMean, + w = fweight8$weight)) - weighted_median_vals <- map(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), weightedMedian, w = fweight$weight) - weighted_mean_vals <- map(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), weightedMean, w = fweight$weight) + median_expected$value <- median_vals + mean_expected$value <- mean_vals + weighted_mean_expected$value <- weighted_mean_vals + weighted_median_expected$value <- weighted_median_vals - median_actual <- simple_ensemble( - predictions = fdat, weights = NULL, agg_fun = "median", model_abbr = "median_ens") - mean_actual <- simple_ensemble( - predictions = fdat, weights = NULL, agg_fun = "mean", model_abbr = "mean_ens") - - weighted_median_actual <- simple_ensemble( - predictions = fdat, weights = fweight, agg_fun = "median", model_abbr = "weighted_median_ens") - weighted_mean_actual <- simple_ensemble( - predictions = fdat, weights = fweight, agg_fun = "mean", model_abbr = "weighted_mean_ens") + median_actual <- simple_ensemble(model_outputs = fdat, weights = NULL, + agg_fun = "median") + mean_actual <- simple_ensemble(model_outputs = fdat, weights = NULL, + agg_fun = "mean") + weighted_median_actual <- simple_ensemble(model_outputs = fdat, + weights = fweight, + agg_fun = "median") + weighted_mean_actual <- simple_ensemble(model_outputs = fdat, + weights = fweight, + agg_fun = "mean") - median_expected <- mean_expected <- weighted_median_expected <- weighted_mean_expected <- tibble::tibble( - team_abbr = "Hub", - model_abbr = NA, - location = rep(c("222", "888"), each = 3), - horizon = 1, #week - target = "inc death", - target_date = as.Date("2021-12-25"), - output_type = "quantile", - output_id = rep(c(.1, .5, .9), 2), - value = 0) - - median_expected$value <- sapply(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), median) - median_expected$model_abbr <- "median_ens" - mean_expected$value <- sapply(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), mean) - mean_expected$model_abbr <- "mean_ens" - - weighted_mean_expected$value <- map_dbl(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), weightedMean, w = fweight$weight) - weighted_mean_expected$model_abbr <- "weighted_mean_ens" - weighted_median_expected$value <- map_dbl(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), weightedMedian, w = fweight$weight) - weighted_median_expected$model_abbr <- "weighted_median_ens" - - expect_equal(median_actual, median_expected) - expect_equal(mean_actual, mean_expected) - - expect_equal(weighted_median_actual, weighted_median_expected) - expect_equal(weighted_mean_actual, weighted_mean_expected) -}) \ No newline at end of file + expect_equal(as.data.frame(median_actual), median_expected) + expect_equal(as.data.frame(mean_actual), mean_expected) + + expect_equal(as.data.frame(weighted_median_actual), weighted_median_expected) + expect_equal(as.data.frame(weighted_mean_actual), weighted_mean_expected) +}) From 6671da285e7e112cee35843983c33c7fb5b2d383 Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Wed, 7 Jun 2023 15:14:17 -0400 Subject: [PATCH 16/24] remove stale example data, update data example in unit tests --- .../2022-10-01-simple_hub-baseline.csv | 25 -- .../2022-10-08-simple_hub-baseline.csv | 277 ------------------ .../2022-10-08-team1-goodmodel.csv | 24 -- tests/testthat/test-simple_ensemble.R | 92 +++--- 4 files changed, 48 insertions(+), 370 deletions(-) delete mode 100644 inst/example-data/2022-10-01-simple_hub-baseline.csv delete mode 100644 inst/example-data/2022-10-08-simple_hub-baseline.csv delete mode 100644 inst/example-data/2022-10-08-team1-goodmodel.csv diff --git a/inst/example-data/2022-10-01-simple_hub-baseline.csv b/inst/example-data/2022-10-01-simple_hub-baseline.csv deleted file mode 100644 index 703831b..0000000 --- a/inst/example-data/2022-10-01-simple_hub-baseline.csv +++ /dev/null @@ -1,25 +0,0 @@ -"origin_date","target","horizon","location","output_type","output_type_id","value" -2022-10-01,"wk inc flu hosp",1,"US","mean",NA,150 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.01",135 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.025",137 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.05",139 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.1",140 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.15",141 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.2",141 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.25",142 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.3",143 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.35",144 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.4",145 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.45",147 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.5",148 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.55",149 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.6",150 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.65",152 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.7",155 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.75",161 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.8",165 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.85",170 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.9",175 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.95",176 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.975",176 -2022-10-01,"wk inc flu hosp",1,"US","quantile","0.99",205 diff --git a/inst/example-data/2022-10-08-simple_hub-baseline.csv b/inst/example-data/2022-10-08-simple_hub-baseline.csv deleted file mode 100644 index 7387c87..0000000 --- a/inst/example-data/2022-10-08-simple_hub-baseline.csv +++ /dev/null @@ -1,277 +0,0 @@ -"origin_date","target","horizon","location","output_type","output_type_id","value" -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.99",205 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",2,"US","quantile","0.99",205 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",3,"US","quantile","0.99",205 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",4,"US","quantile","0.99",205 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",1,"04","quantile","0.99",205 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",2,"04","quantile","0.99",205 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",3,"04","quantile","0.99",205 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",4,"04","quantile","0.99",205 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",1,"01","quantile","0.99",205 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",2,"01","quantile","0.99",205 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",3,"01","quantile","0.99",205 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",4,"01","quantile","0.99",205 diff --git a/inst/example-data/2022-10-08-team1-goodmodel.csv b/inst/example-data/2022-10-08-team1-goodmodel.csv deleted file mode 100644 index 71108b3..0000000 --- a/inst/example-data/2022-10-08-team1-goodmodel.csv +++ /dev/null @@ -1,24 +0,0 @@ -"origin_date","target","horizon","location","output_type","output_type_id","value" -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.01",135 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.025",137 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.05",139 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.1",140 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.15",141 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.2",141 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.25",142 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.3",143 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.35",144 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.4",145 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.45",147 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.5",148 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.55",149 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.6",150 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.65",152 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.7",155 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.75",161 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.8",165 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.85",170 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.9",175 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.95",176 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.975",176 -2022-10-08,"wk inc flu hosp",1,"US","quantile","0.99",205 diff --git a/tests/testthat/test-simple_ensemble.R b/tests/testthat/test-simple_ensemble.R index a9f6907..d23fae4 100644 --- a/tests/testthat/test-simple_ensemble.R +++ b/tests/testthat/test-simple_ensemble.R @@ -1,25 +1,61 @@ library(matrixStats) library(dplyr) -pred <- read.csv(system.file("example-data/2022-10-08-simple_hub-baseline.csv", - package = "hubEnsembles")) %>% - dplyr::mutate(model_id = "simple_hub-baseline", .before = origin_date) +# set up simple data for test cases +model_outputs <- expand.grid( + stringsAsFactors = FALSE, + model_id = letters[1:4], + location = c("222", "888"), + horizon = 1, #week + target = "inc death", + target_date = as.Date("2021-12-25"), + output_type = "quantile", + output_type_id = c(.1, .5, .9), + value = NA_real_) + +v2.1 <- model_outputs$value[model_outputs$location == "222" & + model_outputs$output_type_id == .1] <- + c(10, 30, 15, 20) +v2.5 <- model_outputs$value[model_outputs$location == "222" & + model_outputs$output_type_id == .5] <- + c(40, 40, 45, 50) +v2.9 <- model_outputs$value[model_outputs$location == "222" & + model_outputs$output_type_id == .9] <- + c(60, 70, 75, 80) +v8.1 <- model_outputs$value[model_outputs$location == "888" & + model_outputs$output_type_id == .1] <- + c(100, 300, 400, 250) +v8.5 <- model_outputs$value[model_outputs$location == "888" & + model_outputs$output_type_id == .5] <- + c(150, 325, 500, 300) +v8.9 <- model_outputs$value[model_outputs$location == "888" & + model_outputs$output_type_id == .9] <- + c(250, 350, 500, 350) + +fweight2 <- data.frame(model_id = letters[1:4], + location = "222", + weight = 0.1 * (1:4)) +fweight8 <- data.frame(model_id = letters[1:4], + location = "888", + weight = 0.1 * (4:1)) +fweight <- bind_rows(fweight2, fweight8) + test_that("non-default columns are dropped from output", { - output_names <- pred %>% + output_names <- model_outputs %>% dplyr::mutate(extra_col_1 = "a", extra_col_2 = "a") %>% simple_ensemble( - task_id_cols = c("origin_date", "target", "horizon", "location") + task_id_cols = c("target_date", "target", "horizon", "location") ) %>% names() - expect_equal(sort(names(pred)), sort(output_names)) + expect_equal(sort(names(model_outputs)), sort(output_names)) }) test_that("invalid output type throws error", { expect_error( - pred %>% + model_outputs %>% dplyr::mutate(output_type = "sample") %>% simple_ensemble() ) @@ -28,44 +64,12 @@ test_that("invalid output type throws error", { test_that("invalid method argument throws error", { expect_error( - simple_ensemble(pred, agg_fun="linear pool") + simple_ensemble(model_outputs, agg_fun="linear pool") ) }) test_that("(weighted) medians and means correctly calculated", { - fdat <- expand.grid( - stringsAsFactors = FALSE, - model_id = letters[1:4], - location = c("222", "888"), - horizon = 1, #week - target = "inc death", - target_date = as.Date("2021-12-25"), - output_type = "quantile", - output_type_id = c(.1, .5, .9), - value = NA_real_) - - fdat$value[fdat$location == "222" & fdat$output_type_id == .1] <- v2.1 <- - c(10, 30, 15, 20) - fdat$value[fdat$location == "222" & fdat$output_type_id == .5] <- v2.5 <- - c(40, 40, 45, 50) - fdat$value[fdat$location == "222" & fdat$output_type_id == .9] <- v2.9 <- - c(60, 70, 75, 80) - fdat$value[fdat$location == "888" & fdat$output_type_id == .1] <- v8.1 <- - c(100, 300, 400, 250) - fdat$value[fdat$location == "888" & fdat$output_type_id == .5] <- v8.5 <- - c(150, 325, 500, 300) - fdat$value[fdat$location == "888" & fdat$output_type_id == .9] <- v8.9 <- - c(250, 350, 500, 350) - - fweight2 <- data.frame(model_id = letters[1:4], - location = "222", - weight = 0.1 * (1:4)) - fweight8 <- data.frame(model_id = letters[1:4], - location = "888", - weight = 0.1 * (4:1)) - fweight <- bind_rows(fweight2, fweight8) - median_expected <- mean_expected <- weighted_median_expected <- weighted_mean_expected <- data.frame( model_id = "hub-ensemble", @@ -100,15 +104,15 @@ test_that("(weighted) medians and means correctly calculated", { weighted_mean_expected$value <- weighted_mean_vals weighted_median_expected$value <- weighted_median_vals - median_actual <- simple_ensemble(model_outputs = fdat, weights = NULL, + median_actual <- simple_ensemble(model_outputs = model_outputs, weights = NULL, agg_fun = "median") - mean_actual <- simple_ensemble(model_outputs = fdat, weights = NULL, + mean_actual <- simple_ensemble(model_outputs = model_outputs, weights = NULL, agg_fun = "mean") - weighted_median_actual <- simple_ensemble(model_outputs = fdat, + weighted_median_actual <- simple_ensemble(model_outputs = model_outputs, weights = fweight, agg_fun = "median") - weighted_mean_actual <- simple_ensemble(model_outputs = fdat, + weighted_mean_actual <- simple_ensemble(model_outputs = model_outputs, weights = fweight, agg_fun = "mean") From dbe05e1a91c6e0d4a994c185589aeb71fa7b6727 Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Wed, 7 Jun 2023 22:02:41 -0400 Subject: [PATCH 17/24] updates to description --- DESCRIPTION | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 1abb6bd..397bd39 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,10 +1,17 @@ Package: hubEnsembles -Title: What the Package Does (One Line, Title Case) +Title: Ensemble methods for combining hub model outputs. Version: 0.0.0.9000 Authors@R: person("Anna", "Krystalli", , "annakrystalli@googlemail.com", role = c("aut", "cre"), - comment = c(ORCID = "0000-0002-2378-4915")) -Description: What the package does (one paragraph). + comment = c(ORCID = "0000-0002-2378-4915")), + person(given = "Evan L", + family = "Ray", + role = c("aut")), + person(given = "Li", + family = "Shandross", + role = c("aut")) +Description: Functions for combining model outputs (e.g. predictions or + estimates) from multiple models into an aggregated ensemble model output. License: MIT + file LICENSE Suggests: testthat (>= 3.0.0) @@ -17,5 +24,8 @@ BugReports: https://github.com/Infectious-Disease-Modeling-Hubs/hubEnsembles/iss Imports: cli, dplyr, + hubUtils, matrixStats, rlang +Remotes: + Infectious-Disease-Modeling-Hubs/hubUtils From 79515c08936a56b020330a632cffd51a9ec482ff Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Wed, 7 Jun 2023 22:40:22 -0400 Subject: [PATCH 18/24] fix authors field in DESCRIPTION --- DESCRIPTION | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 397bd39..4b76f73 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,7 +1,7 @@ Package: hubEnsembles Title: Ensemble methods for combining hub model outputs. Version: 0.0.0.9000 -Authors@R: +Authors@R: c( person("Anna", "Krystalli", , "annakrystalli@googlemail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0002-2378-4915")), person(given = "Evan L", @@ -9,7 +9,7 @@ Authors@R: role = c("aut")), person(given = "Li", family = "Shandross", - role = c("aut")) + role = c("aut"))) Description: Functions for combining model outputs (e.g. predictions or estimates) from multiple models into an aggregated ensemble model output. License: MIT + file LICENSE From ac7e964d11c9aaef27468312d718b633e638f82d Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Wed, 7 Jun 2023 22:44:28 -0400 Subject: [PATCH 19/24] add magrittr to package imports --- DESCRIPTION | 1 + 1 file changed, 1 insertion(+) diff --git a/DESCRIPTION b/DESCRIPTION index 4b76f73..04875dc 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -25,6 +25,7 @@ Imports: cli, dplyr, hubUtils, + magrittr, matrixStats, rlang Remotes: From a8e6ecee20a63ae29bd18004bf45cdeaa2330593 Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Thu, 8 Jun 2023 10:50:03 -0400 Subject: [PATCH 20/24] Update R/simple_ensemble.R Co-authored-by: Anna Krystalli --- R/simple_ensemble.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/simple_ensemble.R b/R/simple_ensemble.R index a2763c0..338ce83 100644 --- a/R/simple_ensemble.R +++ b/R/simple_ensemble.R @@ -71,7 +71,7 @@ simple_ensemble <- function(model_outputs, weights = NULL, if (length(invalid_types) > 0) { cli::cli_abort(c( "x" = "{.arg model_outputs} contains unsupported output type.", - "i" = "Included output type{?s}: {.val {invalid_types}}.", + "!" = "Included invalid output type{?s}: {.val {invalid_types}}.", "i" = "Supported output types: {.val {valid_types}}." )) } From 4bc4d1dab857e32f799777945a00a76f92407ddd Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Thu, 8 Jun 2023 12:09:19 -0400 Subject: [PATCH 21/24] updates in response to anna's pr comments, weights_col_name as argument --- R/simple_ensemble.R | 57 ++++++++++++++++----------- tests/testthat/test-simple_ensemble.R | 49 ++++++++++++++++++++--- 2 files changed, 77 insertions(+), 29 deletions(-) diff --git a/R/simple_ensemble.R b/R/simple_ensemble.R index 338ce83..6c4fbbe 100644 --- a/R/simple_ensemble.R +++ b/R/simple_ensemble.R @@ -5,10 +5,13 @@ #' @param model_outputs an object of class `model_output_df` with component #' model outputs (e.g., predictions). #' @param weights an optional `data.frame` with component model weights. If -#' provided, it should have columns `model_id`, `weight`, and optionally, -#' additional columns corresponding to task id variables, `output_type`, or -#' `output_type_id`, if weights are specific to values of those variables. The -#' default is `NULL`, in which case an equally-weighted ensemble is calculated. +#' provided, it should have a column named `model_id` and a column containing +#' model weights. Optionally, it may contain additional columns corresponding +#' to task id variables, `output_type`, or `output_type_id`, if weights are +#' specific to values of those variables. The default is `NULL`, in which case +#' an equally-weighted ensemble is calculated. +#' @param weights_col_name `character` string naming the column in `weights` +#' with model weights. Defaults to `"weights"` #' @param agg_fun a function or character string name of a function to use for #' aggregating component model outputs into the ensemble outputs. See the #' details for more information. @@ -16,35 +19,43 @@ #' to `agg_fun`. #' @param model_id `character` string with the identifier to use for the #' ensemble model. -#' @param task_id_cols, output_type_col, output_type_id_col, value_col -#' `character` vectors with the names of the columns in `model_outputs` for -#' the output's type, additional identifying information, and value of the -#' model output. +#' @param task_id_cols `character` vector with names of columns in +#' `model_outputs` that specify modeling tasks. +#' @param output_type_col `character` string with the name of the column in +#' `model_outputs` that contains the output type. +#' @param output_type_id_col `character` string with the name of the column in +#' `model_outputs` that contains the output type id. +#' @param value_col `character` string with the name of the column in +#' `model_outputs` that contains model output values. #' -#' @details The default for `agg_fun` is `mean`, in which case the ensemble's +#' @details The default for `agg_fun` is `"mean"`, in which case the ensemble's #' output is the average of the component model outputs within each group #' defined by a combination of values in the task id columns, output type, and #' output type id. The provided `agg_fun` should have an argument `x` for the #' vector of numeric values to summarize, and for weighted methods, an -#' argument `w` with a numeric vector of weights. For weighted methods, -#' `agg_fun = "mean"` and `agg_fun = "median"` are translated to use -#' `matrixStats::weightedMean` and `matrixStats::weightedMedian` respectively. +#' argument `w` with a numeric vector of weights. If it desired to use an +#' aggregation function that does not accept these arguments, a wrapper +#' would need to be written. For weighted methods, `agg_fun = "mean"` and +#' `agg_fun = "median"` are translated to use `matrixStats::weightedMean` and +#' `matrixStats::weightedMedian` respectively. #' -#' @return a data.frame with columns `team_abbr`, `model_abbr`, one column for +#' @return a data.frame with columns `model_id`, one column for #' each task id variable, `output_type`, `output_id`, and `value`. Note that #' any additional columns in the input `model_outputs` are dropped. simple_ensemble <- function(model_outputs, weights = NULL, + weights_col_name = "weight", agg_fun = "mean", agg_args = list(), model_id = "hub-ensemble", task_id_cols = NULL, output_type_col = "output_type", output_type_id_col = "output_type_id", hub_connection = NULL) { - model_out_cols <- colnames(model_outputs) if (!is.data.frame(model_outputs)) { cli::cli_abort(c("x" = "{.arg model_outputs} must be a `data.frame`.")) } + model_out_cols <- colnames(model_outputs) + non_task_cols <- c("model_id", output_type_col, output_type_id_col, "value") if (is.null(task_id_cols)) { task_id_cols <- model_out_cols[!model_out_cols %in% non_task_cols] @@ -79,7 +90,7 @@ simple_ensemble <- function(model_outputs, weights = NULL, if (is.null(weights)) { agg_args <- c(agg_args, list(x = quote(.data[["value"]]))) } else { - req_weight_cols <- c("model_id", "weight") + req_weight_cols <- c("model_id", weights_col_name) if (!all(req_weight_cols %in% colnames(weights))) { cli::cli_abort(c( "x" = "{.arg weights} did not include required columns @@ -87,7 +98,7 @@ simple_ensemble <- function(model_outputs, weights = NULL, )) } - weight_by_cols <- colnames(weights)[colnames(weights) != "weight"] + weight_by_cols <- colnames(weights)[colnames(weights) != weights_col_name] if ("value" %in% weight_by_cols) { cli::cli_abort(c( @@ -97,7 +108,7 @@ simple_ensemble <- function(model_outputs, weights = NULL, } invalid_cols <- weight_by_cols[!weight_by_cols %in% colnames(model_outputs)] - if (!all(weight_by_cols %in% colnames(model_outputs))) { + if (length(invalid_cols) > 0) { cli::cli_abort(c( "x" = "{.arg weights} included {length(invalid_cols)} column{?s} that {?was/were} not present in {.arg model_outputs}: @@ -105,11 +116,11 @@ simple_ensemble <- function(model_outputs, weights = NULL, )) } - if ("weight" %in% colnames(model_outputs)) { - weight_col_name <- paste0("weight_", rlang::hash(colnames(model_outputs))) - weights <- weights %>% dplyr::rename(!!weight_col_name := "weight") - } else { - weight_col_name <- "weight" + if (weights_col_name %in% colnames(model_outputs)) { + cli::cli_abort(c( + "x" = "The specified {.arg weights_col_name}, {.val {weights_col_name}}, + is already a column in {.arg model_outputs}." + )) } model_outputs <- model_outputs %>% @@ -124,7 +135,7 @@ simple_ensemble <- function(model_outputs, weights = NULL, } agg_args <- c(agg_args, list(x = quote(.data[["value"]]), - w = quote(.data[[weight_col_name]]))) + w = quote(.data[[weights_col_name]]))) } group_by_cols <- c(task_id_cols, output_type_col, output_type_id_col) diff --git a/tests/testthat/test-simple_ensemble.R b/tests/testthat/test-simple_ensemble.R index d23fae4..854e324 100644 --- a/tests/testthat/test-simple_ensemble.R +++ b/tests/testthat/test-simple_ensemble.R @@ -64,7 +64,16 @@ test_that("invalid output type throws error", { test_that("invalid method argument throws error", { expect_error( - simple_ensemble(model_outputs, agg_fun="linear pool") + simple_ensemble(model_outputs, agg_fun = "linear pool") + ) +}) + + +test_that("weights column already in model_outputs generates error", { + expect_error( + model_outputs %>% + dplyr::mutate(weight = "a") %>% + simple_ensemble(weights = fweight) ) }) @@ -76,9 +85,9 @@ test_that("(weighted) medians and means correctly calculated", { location = rep(c("222", "888"), each = 3), horizon = 1, target = "inc death", - target_date = as.Date("2021-12-25"), - output_type = "quantile", - output_type_id = rep(c(.1, .5, .9), 2), + target_date = as.Date("2021-12-25"), + output_type = "quantile", + output_type_id = rep(c(.1, .5, .9), 2), value = NA_real_) median_vals <- sapply(list(v2.1, v2.5, v2.9, v8.1, v8.5, v8.9), median) @@ -104,9 +113,11 @@ test_that("(weighted) medians and means correctly calculated", { weighted_mean_expected$value <- weighted_mean_vals weighted_median_expected$value <- weighted_median_vals - median_actual <- simple_ensemble(model_outputs = model_outputs, weights = NULL, + median_actual <- simple_ensemble(model_outputs = model_outputs, + weights = NULL, agg_fun = "median") - mean_actual <- simple_ensemble(model_outputs = model_outputs, weights = NULL, + mean_actual <- simple_ensemble(model_outputs = model_outputs, + weights = NULL, agg_fun = "mean") weighted_median_actual <- simple_ensemble(model_outputs = model_outputs, @@ -122,3 +133,29 @@ test_that("(weighted) medians and means correctly calculated", { expect_equal(as.data.frame(weighted_median_actual), weighted_median_expected) expect_equal(as.data.frame(weighted_mean_actual), weighted_mean_expected) }) + + +test_that("(weighted) medians and means work with alternate name for weights columns", { + weighted_median_actual <- simple_ensemble( + model_outputs = model_outputs, + weights = fweight %>% + dplyr::rename(w = weight), + weights_col_name = "w", + agg_fun = "median") + weighted_mean_actual <- simple_ensemble( + model_outputs = model_outputs, + weights = fweight %>% + dplyr::rename(w = weight), + weights_col_name = "w", + agg_fun = "mean") + + weighted_median_expected <- simple_ensemble(model_outputs = model_outputs, + weights = fweight, + agg_fun = "median") + weighted_mean_expected <- simple_ensemble(model_outputs = model_outputs, + weights = fweight, + agg_fun = "mean") + + expect_equal(weighted_mean_actual, weighted_mean_expected) + expect_equal(weighted_median_actual, weighted_median_expected) +}) From 38d1a29de2749820080f0fae5efbdec76d326d69 Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Thu, 8 Jun 2023 12:14:18 -0400 Subject: [PATCH 22/24] update docs for simple_ensemble --- man/simple_ensemble.Rd | 40 +++++++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/man/simple_ensemble.Rd b/man/simple_ensemble.Rd index b231d55..d922a62 100644 --- a/man/simple_ensemble.Rd +++ b/man/simple_ensemble.Rd @@ -9,6 +9,7 @@ output types include \code{mean}, \code{median}, \code{quantile}, \code{cdf}, an simple_ensemble( model_outputs, weights = NULL, + weights_col_name = "weight", agg_fun = "mean", agg_args = list(), model_id = "hub-ensemble", @@ -23,10 +24,14 @@ simple_ensemble( model outputs (e.g., predictions).} \item{weights}{an optional \code{data.frame} with component model weights. If -provided, it should have columns \code{model_id}, \code{weight}, and optionally, -additional columns corresponding to task id variables, \code{output_type}, or -\code{output_type_id}, if weights are specific to values of those variables. The -default is \code{NULL}, in which case an equally-weighted ensemble is calculated.} +provided, it should have a column named \code{model_id} and a column containing +model weights. Optionally, it may contain additional columns corresponding +to task id variables, \code{output_type}, or \code{output_type_id}, if weights are +specific to values of those variables. The default is \code{NULL}, in which case +an equally-weighted ensemble is calculated.} + +\item{weights_col_name}{\code{character} string naming the column in \code{weights} +with model weights. Defaults to \code{"weights"}} \item{agg_fun}{a function or character string name of a function to use for aggregating component model outputs into the ensemble outputs. See the @@ -38,13 +43,20 @@ to \code{agg_fun}.} \item{model_id}{\code{character} string with the identifier to use for the ensemble model.} -\item{task_id_cols, }{output_type_col, output_type_id_col, value_col -\code{character} vectors with the names of the columns in \code{model_outputs} for -the output's type, additional identifying information, and value of the -model output.} +\item{task_id_cols}{\code{character} vector with names of columns in +\code{model_outputs} that specify modeling tasks.} + +\item{output_type_col}{\code{character} string with the name of the column in +\code{model_outputs} that contains the output type.} + +\item{output_type_id_col}{\code{character} string with the name of the column in +\code{model_outputs} that contains the output type id.} + +\item{value_col}{\code{character} string with the name of the column in +\code{model_outputs} that contains model output values.} } \value{ -a data.frame with columns \code{team_abbr}, \code{model_abbr}, one column for +a data.frame with columns \code{model_id}, one column for each task id variable, \code{output_type}, \code{output_id}, and \code{value}. Note that any additional columns in the input \code{model_outputs} are dropped. } @@ -54,12 +66,14 @@ each combination of model task, output type, and output type id. Supported output types include \code{mean}, \code{median}, \code{quantile}, \code{cdf}, and \code{pmf}. } \details{ -The default for \code{agg_fun} is \code{mean}, in which case the ensemble's +The default for \code{agg_fun} is \code{"mean"}, in which case the ensemble's output is the average of the component model outputs within each group defined by a combination of values in the task id columns, output type, and output type id. The provided \code{agg_fun} should have an argument \code{x} for the vector of numeric values to summarize, and for weighted methods, an -argument \code{w} with a numeric vector of weights. For weighted methods, -\code{agg_fun = "mean"} and \code{agg_fun = "median"} are translated to use -\code{matrixStats::weightedMean} and \code{matrixStats::weightedMedian} respectively. +argument \code{w} with a numeric vector of weights. If it desired to use an +aggregation function that does not accept these arguments, a wrapper +would need to be written. For weighted methods, \code{agg_fun = "mean"} and +\code{agg_fun = "median"} are translated to use \code{matrixStats::weightedMean} and +\code{matrixStats::weightedMedian} respectively. } From 2a92d7af5a79002942ecf44a7c05fba9f0f086c1 Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Thu, 8 Jun 2023 12:18:17 -0400 Subject: [PATCH 23/24] remove hub_connection as argument to simple_ensemble --- R/simple_ensemble.R | 3 +-- man/simple_ensemble.Rd | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/R/simple_ensemble.R b/R/simple_ensemble.R index 6c4fbbe..c1ee0ff 100644 --- a/R/simple_ensemble.R +++ b/R/simple_ensemble.R @@ -48,8 +48,7 @@ simple_ensemble <- function(model_outputs, weights = NULL, model_id = "hub-ensemble", task_id_cols = NULL, output_type_col = "output_type", - output_type_id_col = "output_type_id", - hub_connection = NULL) { + output_type_id_col = "output_type_id") { if (!is.data.frame(model_outputs)) { cli::cli_abort(c("x" = "{.arg model_outputs} must be a `data.frame`.")) } diff --git a/man/simple_ensemble.Rd b/man/simple_ensemble.Rd index d922a62..3fd6ea4 100644 --- a/man/simple_ensemble.Rd +++ b/man/simple_ensemble.Rd @@ -15,8 +15,7 @@ simple_ensemble( model_id = "hub-ensemble", task_id_cols = NULL, output_type_col = "output_type", - output_type_id_col = "output_type_id", - hub_connection = NULL + output_type_id_col = "output_type_id" ) } \arguments{ From 483b9605c31a309ff1528dc9e434f4c60b0f9e7e Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Thu, 8 Jun 2023 12:22:18 -0400 Subject: [PATCH 24/24] remove doc for removed parameter value_col from simple_ensemble --- R/simple_ensemble.R | 2 -- man/simple_ensemble.Rd | 3 --- 2 files changed, 5 deletions(-) diff --git a/R/simple_ensemble.R b/R/simple_ensemble.R index c1ee0ff..6bf2c9e 100644 --- a/R/simple_ensemble.R +++ b/R/simple_ensemble.R @@ -25,8 +25,6 @@ #' `model_outputs` that contains the output type. #' @param output_type_id_col `character` string with the name of the column in #' `model_outputs` that contains the output type id. -#' @param value_col `character` string with the name of the column in -#' `model_outputs` that contains model output values. #' #' @details The default for `agg_fun` is `"mean"`, in which case the ensemble's #' output is the average of the component model outputs within each group diff --git a/man/simple_ensemble.Rd b/man/simple_ensemble.Rd index 3fd6ea4..d2b4caf 100644 --- a/man/simple_ensemble.Rd +++ b/man/simple_ensemble.Rd @@ -50,9 +50,6 @@ ensemble model.} \item{output_type_id_col}{\code{character} string with the name of the column in \code{model_outputs} that contains the output type id.} - -\item{value_col}{\code{character} string with the name of the column in -\code{model_outputs} that contains model output values.} } \value{ a data.frame with columns \code{model_id}, one column for