hubEnsembles 0.1.3 -> 0.1.4: replace `.data$col` pronoun usage with quoted
column names / `.data[["col"]]`, and replace `purrr::map_dfr()` with
`purrr::map()` followed by `purrr::list_rbind()`.

Review note: `dplyr::group_split()` takes data-masking arguments, so a bare
string such as "output_type" is a constant rather than a column reference and
would put every row in a single group. The grouping column is therefore
selected with `dplyr::across(dplyr::all_of(...))` in both `linear_pool()` and
`validate_output_type_ids()`.

diff --git a/DESCRIPTION b/DESCRIPTION
index 65b0dbc..7187dc5 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -1,6 +1,6 @@
 Package: hubEnsembles
 Title: Ensemble methods for combining hub model outputs
-Version: 0.1.3
+Version: 0.1.4
 Authors@R: c(
     person(given = "Evan L",
            family = "Ray",
diff --git a/R/linear_pool.R b/R/linear_pool.R
index 3c3a941..212d46d 100644
--- a/R/linear_pool.R
+++ b/R/linear_pool.R
@@ -66,7 +66,6 @@
 #' all.equal(lp_from_component_qs$value, lp_qs, tolerance = 1e-3,
 #'           check.attributes=FALSE)
 #'
-#' @importFrom rlang .data
 
 linear_pool <- function(model_outputs, weights = NULL,
                         weights_col_name = "weight",
@@ -89,8 +88,8 @@ linear_pool <- function(model_outputs, weights = NULL,
 
   # calculate linear opinion pool for different types
   ensemble_model_outputs <- model_outputs_validated |>
-    dplyr::group_split(.data$output_type) |>
-    purrr::map_dfr(.f = function(split_outputs) {
+    dplyr::group_split(dplyr::across(dplyr::all_of("output_type"))) |>
+    purrr::map(.f = function(split_outputs) {
       type <- split_outputs$output_type[1]
       if (type %in% c("mean", "cdf", "pmf")) {
         simple_ensemble(split_outputs, weights = weights_validated,
@@ -107,6 +106,7 @@ linear_pool <- function(model_outputs, weights = NULL,
                              ...)
       }
     }) |>
+    purrr::list_rbind() |>
     hubUtils::as_model_out_tbl()
 
   return(ensemble_model_outputs)
diff --git a/R/linear_pool_quantile.R b/R/linear_pool_quantile.R
index 6dc63b5..86d00b8 100644
--- a/R/linear_pool_quantile.R
+++ b/R/linear_pool_quantile.R
@@ -39,13 +39,13 @@ linear_pool_quantile <- function(model_outputs, weights = NULL,
     dplyr::summarize(
       pred_qs = list(
         distfromq::make_q_fn(
-          ps = as.numeric(.data$output_type_id),
-          qs = .data$value, ...
+          ps = as.numeric(.data[["output_type_id"]]),
+          qs = .data[["value"]], ...
         )(sample_q_lvls)
       ),
       .groups = "drop"
     ) |>
-    tidyr::unnest(.data$pred_qs) |>
+    tidyr::unnest("pred_qs") |>
     dplyr::group_by(dplyr::across(dplyr::all_of(task_id_cols))) |>
     dplyr::summarize(
       output_type_id = list(quantile_levels),
@@ -54,7 +54,7 @@ linear_pool_quantile <- function(model_outputs, weights = NULL,
     ) |>
     tidyr::unnest(cols = tidyselect::all_of(c("output_type_id", "value"))) |>
     dplyr::mutate(model_id = model_id, .before = 1) |>
-    dplyr::mutate(output_type = "quantile", .before = .data$output_type_id) |>
+    dplyr::mutate(output_type = "quantile", .before = "output_type_id") |>
     dplyr::ungroup()
 
   return(quantile_outputs)
diff --git a/R/validate_output_type_ids.R b/R/validate_output_type_ids.R
index 33fb20f..30adb71 100644
--- a/R/validate_output_type_ids.R
+++ b/R/validate_output_type_ids.R
@@ -12,17 +12,17 @@
 #'   are `mean`, `quantile`, `cdf`, `pmf`, and `sample`.
 #'
 #' @return no return value
-#'
 #' @noRd
+#'
+#' @importFrom rlang .data
 
 validate_output_type_ids <- function(model_outputs, task_id_cols) {
   same_output_id <- model_outputs |>
-    dplyr::filter(.data$output_type %in% c("cdf", "pmf", "quantile")) |>
-    dplyr::group_by(.data$model_id, dplyr::across(dplyr::all_of(task_id_cols)),
-                    .data$output_type) |>
-    dplyr::summarize(output_type_id_list = list(sort(.data$output_type_id))) |>
+    dplyr::filter(.data[["output_type"]] %in% c("cdf", "pmf", "quantile")) |>
+    dplyr::group_by(dplyr::across(c(dplyr::all_of(task_id_cols), "model_id", "output_type"))) |>
+    dplyr::summarize(output_type_id_list = list(sort(.data[["output_type_id"]]))) |>
     dplyr::ungroup() |>
-    dplyr::group_split(dplyr::across(dplyr::all_of(task_id_cols)), .data$output_type) |>
+    dplyr::group_split(dplyr::across(dplyr::all_of(c(task_id_cols, "output_type")))) |>
     purrr::map(.f = function(split_outputs) {
       length(unique(split_outputs$output_type_id_list)) == 1
     }) |>
diff --git a/tests/testthat/test-linear_pool.R b/tests/testthat/test-linear_pool.R
index 4ad8b8a..a123962 100644
--- a/tests/testthat/test-linear_pool.R
+++ b/tests/testthat/test-linear_pool.R
@@ -78,7 +78,7 @@ test_that("component model outputs and resulting ensemble model outputs have ide
     c(250, 350, 500, 350)
 
   expected_output_type_ids <- data.frame(quantile_outputs) |>
-    dplyr::pull(output_type_id) |>
+    dplyr::pull("output_type_id") |>
     unique() |>
     sort()
 
@@ -87,7 +87,7 @@ test_that("component model outputs and resulting ensemble model outputs have ide
                 weights_col_name = NULL,
                 model_id = "hub-ensemble",
                 task_id_cols = NULL) |>
-    dplyr::pull(output_type_id) |>
+    dplyr::pull("output_type_id") |>
     unique() |>
     sort()
 
diff --git a/tests/testthat/test-simple_ensemble.R b/tests/testthat/test-simple_ensemble.R
index e62d0c5..38f0260 100644
--- a/tests/testthat/test-simple_ensemble.R
+++ b/tests/testthat/test-simple_ensemble.R
@@ -97,7 +97,7 @@ test_that("component model outputs and resulting ensemble model outputs have ide
                     weights_col_name = NULL,
                     model_id = "hub-ensemble",
                     task_id_cols = NULL) |>
-    dplyr::pull(output_type_id) |>
+    dplyr::pull("output_type_id") |>
     unique() |>
     sort()
 
@@ -227,11 +227,11 @@ test_that("(weighted) medians and means correctly calculated", {
 
 test_that("(weighted) medians and means work with alternate name for weights columns", {
   weighted_median_actual <- model_outputs |>
-    simple_ensemble(weights = fweight %>% dplyr::rename(w = weight),
+    simple_ensemble(weights = fweight %>% dplyr::rename(w = "weight"),
                     weights_col_name = "w",
                     agg_fun = "median")
   weighted_mean_actual <- model_outputs |>
-    simple_ensemble(weights = fweight %>% dplyr::rename(w = weight),
+    simple_ensemble(weights = fweight %>% dplyr::rename(w = "weight"),
                     weights_col_name = "w",
                     agg_fun = "mean")
 