add collect_metrics() argument to pivot output #839

Merged · 3 commits · Feb 12, 2024
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -1,6 +1,6 @@
Package: tune
Title: Tidy Tuning Tools
-Version: 1.1.2.9018
+Version: 1.1.2.9019
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0003-2402-136X")),
4 changes: 3 additions & 1 deletion NEWS.md
@@ -2,7 +2,9 @@

## New Features

-* Added a new function, `compute_metrics()`, that allows for computing new metrics after evaluating against resamples. The arguments and output formats are closely related to those from `collect_metrics()`, but this function requires that the input be generated with the control option `save_pred = TRUE` and additionally takes a `metrics` argument with a metric set for new metrics to compute. This allows for computing new performance metrics without requiring users to re-fit and re-predict from each model (#663).
+* Added a `type` argument to `collect_metrics()` to indicate the desired output format. The default, `type = "long"`, returns output as before, while `type = "wide"` pivots the output such that each metric has its own column (#839).
+
+* Added a new function, `compute_metrics()`, that allows for computing new metrics after evaluating against resamples. The arguments and output formats are closely related to those from `collect_metrics()`, but this function requires that the input be generated with the control option `save_pred = TRUE` and additionally takes a `metrics` argument with a metric set for new metrics to compute. This allows for computing new performance metrics without requiring users to re-fit and re-predict from each model. (#663)

* A method for rsample's `int_pctl()` function that will compute percentile confidence intervals on performance metrics for objects produced by `fit_resamples()`, `tune_*()`, and `last_fit()`.

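To make the new argument concrete, here is a minimal usage sketch (not part of the PR itself); it assumes `res` is a tuning result from `tune_grid()` evaluated with rmse and rsq:

```r
library(tune)

# Long format (the default): one row per metric per candidate, with
# columns .metric, .estimator, mean, n, std_err, and .config
collect_metrics(res)

# Wide format (new in this PR): one row per candidate, one column per
# metric (rmse, rsq); the n and std_err columns are dropped
collect_metrics(res, type = "wide")
```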
51 changes: 41 additions & 10 deletions R/collect.R
@@ -11,17 +11,26 @@
#' used to filter the predicted values before processing. This tibble should
#' only have columns for each tuning parameter identifier (e.g. `"my_param"`
#' if `tune("my_param")` was used).
+#' @param type One of `"long"` (the default) or `"wide"`. When `type = "long"`,
> Contributor Author: I'm conflicted on the name and values for this argument. Very much open to suggestions. :)

> Member: I think that a logical called `pivot_wider` is good. It implies that it is already long.

> Member: Just to check, we decided that it's unlikely that we'll ever deviate from something binary? No other forms of "wide" metrics anticipated?

+#'   output has columns `.metric` and one of `.estimate` or `mean`.
+#'   `.estimate`/`mean` gives the values for the `.metric`. When `type = "wide"`,
+#'   each metric has its own column and the `n` and `std_err` columns are removed,
+#'   if they exist.
> Contributor Author: I didn't have strong opinions here, and the original implementation in #689 had the same behavior, so I did it this way, but if we wanted we could have columns like `rmse`, `rmse_n`, `rmse_std_err`.

> Member: I think excluding them is fine; we can add them later if someone asks for them.

+#'
#' @param ... Not currently used.
#' @return A tibble. The column names depend on the results and the mode of the
#' model.
#'
#' For [collect_metrics()] and [collect_predictions()], when unsummarized,
#' there are columns for each tuning parameter (using the `id` from [tune()],
#' if any).
-#' [collect_metrics()] also has columns `.metric`, and `.estimator`. When the
-#' results are summarized, there are columns for `mean`, `n`, and `std_err`.
-#' When not summarized, the additional columns for the resampling identifier(s)
-#' and `.estimate`.
+#'
+#' [collect_metrics()] also has columns `.metric` and `.estimator` by default.
+#' For [collect_metrics()] methods that have a `type` argument, supplying
+#' `type = "wide"` will pivot the output such that each metric has its own
+#' column. When the results are summarized, there are columns for `mean`, `n`,
+#' and `std_err`. When not summarized, there are additional columns for the
+#' resampling identifier(s) and `.estimate`.
#'
#' For [collect_predictions()], there are additional columns for the resampling
#' identifier(s), columns for the predicted values (e.g., `.pred`,
@@ -445,19 +454,41 @@ collect_metrics.default <- function(x, ...) {

#' @export
#' @rdname collect_predictions
-collect_metrics.tune_results <- function(x, summarize = TRUE, ...) {
+collect_metrics.tune_results <- function(x, summarize = TRUE, type = c("long", "wide"), ...) {
+  rlang::arg_match0(type, values = c("long", "wide"))
+
   if (inherits(x, "last_fit")) {
-    return(x$.metrics[[1]])
+    res <- x$.metrics[[1]]
+  } else {
> Contributor Author: Not stoked about this design pattern. 😞 But wasn't clear to me that there's a super clean way to write this.

+    if (summarize) {
+      res <- estimate_tune_results(x)
+    } else {
+      res <- collector(x, coll_col = ".metrics")
+    }
   }
 
-  if (summarize) {
-    res <- estimate_tune_results(x)
-  } else {
-    res <- collector(x, coll_col = ".metrics")
+  if (identical(type, "wide")) {
+    res <- pivot_metrics(x, res)
   }
+
   res
 }
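One consequence of swapping the early `return()` for an assignment, worth noting: `last_fit()` results now flow through the same post-processing, so the `type = "wide"` pivot applies to them as well. A sketch, where `lf` is a hypothetical `last_fit()` result:

```r
# `lf` is a hypothetical last_fit() result; with the refactor above it
# reaches the `type` check instead of returning early, so this pivots too
collect_metrics(lf, type = "wide")
```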

+pivot_metrics <- function(x, x_metrics) {
+  params <- .get_tune_parameter_names(x)
+  res <- paste_param_by(x_metrics)
+
+  tidyr::pivot_wider(
+    res,
+    id_cols = c(
+      dplyr::any_of(c(params, ".config", ".iter", ".eval_time")),
+      starts_with("id")
+    ),
+    names_from = .metric,
+    values_from = dplyr::any_of(c(".estimate", "mean"))
+  )
+}
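For intuition, a self-contained sketch of what that `tidyr::pivot_wider()` call does. The tibble below is hypothetical, mimicking summarized metrics from a grid search over a neighbors parameter `K`:

```r
library(tibble)
library(tidyr)

# Hypothetical summarized metrics, shaped like estimate_tune_results()
# output: one row per metric per candidate
metrics <- tibble(
  K = c(5L, 5L, 10L, 10L),
  .metric = c("rmse", "rsq", "rmse", "rsq"),
  .estimator = "standard",
  mean = c(0.10, 0.85, 0.12, 0.80),
  n = 10L,
  std_err = 0.01,
  .config = c("Model1", "Model1", "Model2", "Model2")
)

# One row per candidate, one column per metric; .estimator, n, and
# std_err are dropped because they appear in neither id_cols nor values_from
pivot_wider(
  metrics,
  id_cols = c(K, .config),
  names_from = .metric,
  values_from = mean
)
#> # A tibble: 2 × 4
#>       K .config  rmse   rsq
#>   <int> <chr>   <dbl> <dbl>
#> 1     5 Model1   0.1   0.85
#> 2    10 Model2   0.12  0.8
```

Keeping `n` and `std_err` in the wide output, as floated in the review thread above, would amount to `values_from = c(mean, n, std_err)` (optionally with `names_glue = "{.metric}_{.value}"` for `rmse_mean`-style names); the PR deliberately drops them instead.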

collector <- function(x, coll_col = ".predictions") {
  is_iterative <- any(colnames(x) == ".iter")
  if (is_iterative) {
19 changes: 14 additions & 5 deletions man/collect_predictions.Rd

(Generated file; diff not rendered by default.)

16 changes: 16 additions & 0 deletions tests/testthat/_snaps/collect.md
@@ -135,3 +135,19 @@
      Error in `collect_metrics()`:
      ! No `collect_metrics()` exists for a <lm> object.

# `collect_metrics(type)` errors informatively with bad input

    Code
      collect_metrics(ames_grid_search, type = "boop")
    Condition
      Error in `collect_metrics()`:
      ! `type` must be one of "long" or "wide", not "boop".

---

    Code
      collect_metrics(ames_grid_search, type = NULL)
    Condition
      Error in `collect_metrics()`:
      ! `type` must be a string or character vector.

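The errors captured in these snapshots come directly from `rlang::arg_match0()`. A standalone sketch of that behavior, using a hypothetical wrapper `f()` that mirrors the validation in `collect_metrics()`:

```r
library(rlang)

# Hypothetical wrapper mirroring the `type` validation in collect_metrics()
f <- function(type = c("long", "wide")) {
  arg_match0(type, values = c("long", "wide"))
}

f("wide")      # "wide"
try(f("boop")) # Error: `type` must be one of "long" or "wide", not "boop".
try(f(NULL))   # Error: `type` must be a string or character vector.
```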
102 changes: 102 additions & 0 deletions tests/testthat/test-collect.R
@@ -247,3 +247,105 @@ test_that("`collect_metrics()` errors informatively applied to unsupported class
    collect_metrics(lm(mpg ~ disp, mtcars))
  )
})

test_that("`collect_metrics(type)` errors informatively with bad input", {
  skip_on_cran()

  expect_snapshot(
    error = TRUE,
    collect_metrics(ames_grid_search, type = "boop")
  )

  expect_snapshot(
    error = TRUE,
    collect_metrics(ames_grid_search, type = NULL)
  )
})

test_that("`pivot_metrics()`, grid search, typical metrics, summarized", {
  expect_equal(
    pivot_metrics(ames_grid_search, collect_metrics(ames_grid_search)) %>%
      dplyr::slice(),
    tibble::tibble(
      K = integer(0),
      weight_func = character(0),
      dist_power = numeric(0),
      lon = integer(0),
      lat = integer(0),
      .config = character(0),
      rmse = numeric(0),
      rsq = numeric(0)
    )
  )
})

test_that("`pivot_metrics()`, grid search, typical metrics, unsummarized", {
  expect_equal(
    pivot_metrics(
      ames_grid_search,
      collect_metrics(ames_grid_search, summarize = FALSE)
    ) %>%
      dplyr::slice(),
    tibble::tibble(
      K = integer(0),
      weight_func = character(0),
      dist_power = numeric(0),
      lon = integer(0),
      lat = integer(0),
      .config = character(0),
      id = character(0),
      rmse = numeric(0),
      rsq = numeric(0)
    )
  )
})

test_that("`pivot_metrics()`, iterative search, typical metrics, summarized", {
  expect_equal(
    pivot_metrics(ames_iter_search, collect_metrics(ames_iter_search)) %>%
      dplyr::slice(),
    tibble::tibble(
      K = integer(0),
      weight_func = character(0),
      dist_power = numeric(0),
      lon = integer(0),
      lat = integer(0),
      .config = character(0),
      .iter = integer(0),
      rmse = numeric(0),
      rsq = numeric(0)
    )
  )
})

test_that("`pivot_metrics()`, resampled fits, fairness metrics, summarized", {
  mtcars_fair <- mtcars
  mtcars_fair$vs <- as.factor(mtcars_fair$vs)
  mtcars_fair$cyl <- as.factor(mtcars_fair$cyl)
  mtcars_fair$am <- as.factor(mtcars_fair$am)
  set.seed(4400)

  ms <-
    yardstick::metric_set(
      yardstick::demographic_parity(cyl),
      yardstick::demographic_parity(am)
    )

  res <-
    fit_resamples(
      nearest_neighbor("classification"),
      vs ~ mpg + hp + cyl,
      rsample::bootstraps(mtcars_fair, 3),
      metrics = ms
    )

  expect_equal(
    pivot_metrics(res, collect_metrics(res)) %>% dplyr::slice(),
    tibble::tibble(
      .config = character(0),
      `demographic_parity(am)` = numeric(0),
      `demographic_parity(cyl)` = numeric(0)
    )
  )
})
