diff --git a/NEWS.md b/NEWS.md index 0ac5e7c..6df2b89 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # hubEnsembles (development version) +* Fix bug in `simple_ensemble()` where `all.equal()` was sometimes throwing an error (#134) + # hubEnsembles 0.1.8 * Update README to point to hubverse R-universe diff --git a/R/simple_ensemble.R b/R/simple_ensemble.R index 655d0c9..203b10a 100644 --- a/R/simple_ensemble.R +++ b/R/simple_ensemble.R @@ -94,7 +94,7 @@ simple_ensemble <- function(model_out_tbl, weights = NULL, } # don't interpolate when calling `matrixStats::weightedMedian` - if (isTRUE(all.equal(agg_fun, matrixStats::weightedMedian))) { + if (identical(agg_fun, matrixStats::weightedMedian)) { agg_args <- c(agg_args, list(interpolate = FALSE)) } diff --git a/tests/testthat/test-simple_ensemble.R b/tests/testthat/test-simple_ensemble.R index 3d74274..740e417 100644 --- a/tests/testthat/test-simple_ensemble.R +++ b/tests/testthat/test-simple_ensemble.R @@ -281,3 +281,17 @@ test_that("duplicate forecast values still result in correct weighted median", { expect_equal(weighted_median_expected, weighted_median_actual) }) + +test_that("simple_ensemble accepts custom functions without error", { + geometric_mean <- function(x) { + n <- length(x) + return(prod(x)^(1 / n)) + } + + model_outputs |> + hubEnsembles::simple_ensemble( + agg_fun = geometric_mean, + model_id = "simple-ensemble-geometric" + ) |> + expect_no_error() +})