Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simple ensemble #4

Merged
merged 25 commits into from
Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
6672a15
Create simple_ensemble.R
Feb 17, 2023
2a37506
add example-data for function testing
Feb 17, 2023
f676190
fix weighted vs unweighted ensemble code
Apr 24, 2023
cf1a459
minor fixes to simple_ensemble() function
Apr 25, 2023
10f1db5
add unweighted ensemble testing protocol and test data
Apr 25, 2023
4420a9f
Remove hard coding of output_type, output_id, value columns
Apr 28, 2023
d6cad09
Add additional tests for simple_ensemble()
Apr 28, 2023
3900a91
clean up simple_ensemble() testing file
May 2, 2023
935c420
remove extraneous test data
May 2, 2023
15a9299
remove old testing file
May 2, 2023
7588b8a
fix file path to test data
May 24, 2023
a88c693
progress on refactoring simple_ensemble to use cli for messaging and …
elray1 Jun 5, 2023
39f601e
warn if 0 rows
elray1 Jun 6, 2023
0794377
add package imports, validations for weights
elray1 Jun 6, 2023
783e936
misc updates
elray1 Jun 7, 2023
6671da2
remove stale example data, update data example in unit tests
elray1 Jun 7, 2023
dbe05e1
updates to description
elray1 Jun 8, 2023
79515c0
fix authors field in DESCRIPTION
elray1 Jun 8, 2023
ac7e964
add magrittr to package imports
elray1 Jun 8, 2023
2a99105
Merge branch 'main' into simple_ensemble
elray1 Jun 8, 2023
a8e6ece
Update R/simple_ensemble.R
elray1 Jun 8, 2023
4bc4d1d
updates in response to anna's pr comments, weights_col_name as argument
elray1 Jun 8, 2023
38d1a29
update docs for simple_ensemble
elray1 Jun 8, 2023
2a92d7a
remove hub_connection as argument to simple_ensemble
elray1 Jun 8, 2023
483b960
remove doc for removed parameter value_col from simple_ensemble
elray1 Jun 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,16 +1,32 @@
Package: hubEnsembles
Title: What the Package Does (One Line, Title Case)
Title: Ensemble methods for combining hub model outputs.
Version: 0.0.0.9000
Authors@R:
Authors@R: c(
person("Anna", "Krystalli", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0002-2378-4915"))
Description: What the package does (one paragraph).
comment = c(ORCID = "0000-0002-2378-4915")),
person(given = "Evan L",
family = "Ray",
role = c("aut")),
person(given = "Li",
family = "Shandross",
role = c("aut")))
Description: Functions for combining model outputs (e.g. predictions or
estimates) from multiple models into an aggregated ensemble model output.
License: MIT + file LICENSE
Suggests:
testthat (>= 3.0.0)
Config/testthat/edition: 3
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.2
RoxygenNote: 7.2.3
URL: https://github.com/Infectious-Disease-Modeling-Hubs/hubEnsembles
BugReports: https://github.com/Infectious-Disease-Modeling-Hubs/hubEnsembles/issues
Imports:
cli,
dplyr,
hubUtils,
magrittr,
matrixStats,
rlang
Remotes:
Infectious-Disease-Modeling-Hubs/hubUtils
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
# Generated by roxygen2: do not edit by hand

export("%>%")
importFrom(magrittr,"%>%")
140 changes: 140 additions & 0 deletions R/simple_ensemble.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
#' Compute ensemble model outputs by summarizing component model outputs for
#' each combination of model task, output type, and output type id. Supported
#' output types include `mean`, `median`, `quantile`, `cdf`, and `pmf`.
#'
#' @param model_outputs an object of class `model_output_df` with component
#' model outputs (e.g., predictions).
#' @param weights an optional `data.frame` with component model weights. If
#' provided, it should have columns `model_id`, `weight`, and optionally,
#' additional columns corresponding to task id variables, `output_type`, or
#' `output_type_id`, if weights are specific to values of those variables. The
#' default is `NULL`, in which case an equally-weighted ensemble is calculated.
#' @param agg_fun a function or character string name of a function to use for
#' aggregating component model outputs into the ensemble outputs. See the
#' details for more information.
#' @param agg_args a named list of any additional arguments that will be passed
#' to `agg_fun`.
#' @param model_id `character` string with the identifier to use for the
#' ensemble model.
#' @param task_id_cols, output_type_col, output_type_id_col, value_col
#' `character` vectors with the names of the columns in `model_outputs` for
#' the output's type, additional identifying information, and value of the
#' model output.
#'
#' @details The default for `agg_fun` is `mean`, in which case the ensemble's
#' output is the average of the component model outputs within each group
#' defined by a combination of values in the task id columns, output type, and
#' output type id. The provided `agg_fun` should have an argument `x` for the
elray1 marked this conversation as resolved.
Show resolved Hide resolved
#' vector of numeric values to summarize, and for weighted methods, an
#' argument `w` with a numeric vector of weights. For weighted methods,
#' `agg_fun = "mean"` and `agg_fun = "median"` are translated to use
#' `matrixStats::weightedMean` and `matrixStats::weightedMedian` respectively.
#'
#' @return a data.frame with columns `team_abbr`, `model_abbr`, one column for
elray1 marked this conversation as resolved.
Show resolved Hide resolved
#' each task id variable, `output_type`, `output_id`, and `value`. Note that
#' any additional columns in the input `model_outputs` are dropped.
simple_ensemble <- function(model_outputs, weights = NULL,
agg_fun = "mean", agg_args = list(),
model_id = "hub-ensemble",
task_id_cols = NULL,
output_type_col = "output_type",
output_type_id_col = "output_type_id",
hub_connection = NULL) {
model_out_cols <- colnames(model_outputs)
elray1 marked this conversation as resolved.
Show resolved Hide resolved
if (!is.data.frame(model_outputs)) {
cli::cli_abort(c("x" = "{.arg model_outputs} must be a `data.frame`."))
}

non_task_cols <- c("model_id", output_type_col, output_type_id_col, "value")
if (is.null(task_id_cols)) {
task_id_cols <- model_out_cols[!model_out_cols %in% non_task_cols]
}

req_col_names <- c(non_task_cols, task_id_cols)
if (!all(req_col_names %in% model_out_cols)) {
cli::cli_abort(c(
"x" = "{.arg model_outputs} did not have all required columns
{.val {req_col_names}}."
))
}

## Validations above this point to be relocated to hubUtils
# hubUtils::validate_model_output_df(model_outputs)

if (nrow(model_outputs) == 0) {
cli::cli_warn(c("!" = "{.arg model_outputs} has zero rows."))
}

valid_types <- c("mean", "median", "quantile", "cdf", "pmf")
unique_types <- unique(model_outputs[[output_type_col]])
invalid_types <- unique_types[!unique_types %in% valid_types]
if (length(invalid_types) > 0) {
cli::cli_abort(c(
"x" = "{.arg model_outputs} contains unsupported output type.",
"i" = "Included output type{?s}: {.val {invalid_types}}.",
elray1 marked this conversation as resolved.
Show resolved Hide resolved
"i" = "Supported output types: {.val {valid_types}}."
))
}

if (is.null(weights)) {
agg_args <- c(agg_args, list(x = quote(.data[["value"]])))
} else {
req_weight_cols <- c("model_id", "weight")
if (!all(req_weight_cols %in% colnames(weights))) {
cli::cli_abort(c(
"x" = "{.arg weights} did not include required columns
{.val {req_weight_cols}}."
))
}

weight_by_cols <- colnames(weights)[colnames(weights) != "weight"]

if ("value" %in% weight_by_cols) {
cli::cli_abort(c(
elray1 marked this conversation as resolved.
Show resolved Hide resolved
"x" = "{.arg weights} included a column named {.val {\"value\"}},
which is not allowed."
))
}

invalid_cols <- weight_by_cols[!weight_by_cols %in% colnames(model_outputs)]
if (!all(weight_by_cols %in% colnames(model_outputs))) {
cli::cli_abort(c(
elray1 marked this conversation as resolved.
Show resolved Hide resolved
"x" = "{.arg weights} included {length(invalid_cols)} column{?s} that
{?was/were} not present in {.arg model_outputs}:
{.val {invalid_cols}}"
))
}

if ("weight" %in% colnames(model_outputs)) {
weight_col_name <- paste0("weight_", rlang::hash(colnames(model_outputs)))
weights <- weights %>% dplyr::rename(!!weight_col_name := "weight")
} else {
weight_col_name <- "weight"
}

model_outputs <- model_outputs %>%
dplyr::left_join(weights, by = weight_by_cols)

if (is.character(agg_fun)) {
if (agg_fun == "mean") {
agg_fun <- matrixStats::weightedMean
} else if (agg_fun == "median") {
agg_fun <- matrixStats::weightedMedian
}
}

agg_args <- c(agg_args, list(x = quote(.data[["value"]]),
w = quote(.data[[weight_col_name]])))
}

group_by_cols <- c(task_id_cols, output_type_col, output_type_id_col)
ensemble_model_outputs <- model_outputs %>%
dplyr::group_by(dplyr::across(dplyr::all_of(group_by_cols))) %>%
dplyr::summarize(value = do.call(agg_fun, args = agg_args)) %>%
dplyr::mutate(model_id = model_id, .before = 1) %>%
dplyr::ungroup()

# hubUtils::as_model_output_df(ensemble_model_outputs)

return(ensemble_model_outputs)
}
14 changes: 14 additions & 0 deletions R/utils-pipe.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#' Pipe operator
#'
#' See \code{magrittr::\link[magrittr:pipe]{\%>\%}} for details.
#'
#' @name %>%
#' @rdname pipe
#' @keywords internal
#' @export
#' @importFrom magrittr %>%
#' @usage lhs \%>\% rhs
#' @param lhs A value or the magrittr placeholder.
#' @param rhs A function call using the magrittr semantics.
#' @return The result of calling `rhs(lhs)`.
NULL
20 changes: 20 additions & 0 deletions man/pipe.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

65 changes: 65 additions & 0 deletions man/simple_ensemble.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading