-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add collect_metrics()
argument to pivot output
#839
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
Package: tune | ||
Title: Tidy Tuning Tools | ||
Version: 1.1.2.9018 | ||
Version: 1.1.2.9019 | ||
Authors@R: c( | ||
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre"), | ||
comment = c(ORCID = "0000-0003-2402-136X")), | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,17 +11,26 @@ | |
#' used to filter the predicted values before processing. This tibble should | ||
#' only have columns for each tuning parameter identifier (e.g. `"my_param"` | ||
#' if `tune("my_param")` was used). | ||
#' @param type One of `"long"` (the default) or `"wide"`. When `type = "long"`, | ||
#' output has columns `.metric` and one of `.estimate` or `mean`. | ||
#' `.estimate`/`mean` gives the values for the `.metric`. When `type = "wide"`, | ||
#' each metric has its own column and the `n` and `std_err` columns are removed, | ||
#' if they exist. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't have strong opinions here and the original implementation in #689 had the same behavior, so I did it this way, but if we wanted we could have columns like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think excluding them is fine; we can add them later if someone asks for them. |
||
#' | ||
#' @param ... Not currently used. | ||
#' @return A tibble. The column names depend on the results and the mode of the | ||
#' model. | ||
#' | ||
#' For [collect_metrics()] and [collect_predictions()], when unsummarized, | ||
#' there are columns for each tuning parameter (using the `id` from [tune()], | ||
#' if any). | ||
#' [collect_metrics()] also has columns `.metric`, and `.estimator`. When the | ||
#' results are summarized, there are columns for `mean`, `n`, and `std_err`. | ||
#' When not summarized, the additional columns for the resampling identifier(s) | ||
#' and `.estimate`. | ||
#' | ||
#' [collect_metrics()] also has columns `.metric`, and `.estimator` by default. | ||
#' For [collect_metrics()] methods that have a `type` argument, supplying | ||
#' `type = "wide"` will pivot the output such that each metric has its own | ||
#' column. When the results are summarized, there are columns for `mean`, `n`, | ||
#' and `std_err`. When not summarized, the additional columns for the resampling | ||
#' identifier(s) and `.estimate`. | ||
#' | ||
#' For [collect_predictions()], there are additional columns for the resampling | ||
#' identifier(s), columns for the predicted values (e.g., `.pred`, | ||
|
@@ -445,19 +454,41 @@ collect_metrics.default <- function(x, ...) { | |
|
||
#' @export | ||
#' @rdname collect_predictions | ||
collect_metrics.tune_results <- function(x, summarize = TRUE, ...) { | ||
collect_metrics.tune_results <- function(x, summarize = TRUE, type = c("long", "wide"), ...) { | ||
rlang::arg_match0(type, values = c("long", "wide")) | ||
|
||
if (inherits(x, "last_fit")) { | ||
return(x$.metrics[[1]]) | ||
res <- x$.metrics[[1]] | ||
} else { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not stoked about this design pattern.😞 But wasn't clear to me that there's a super clean way to write this. |
||
if (summarize) { | ||
res <- estimate_tune_results(x) | ||
} else { | ||
res <- collector(x, coll_col = ".metrics") | ||
} | ||
} | ||
|
||
if (summarize) { | ||
res <- estimate_tune_results(x) | ||
} else { | ||
res <- collector(x, coll_col = ".metrics") | ||
if (identical(type, "wide")) { | ||
res <- pivot_metrics(x, res) | ||
} | ||
|
||
res | ||
} | ||
|
||
pivot_metrics <- function(x, x_metrics) { | ||
params <- .get_tune_parameter_names(x) | ||
res <- paste_param_by(x_metrics) | ||
|
||
tidyr::pivot_wider( | ||
res, | ||
id_cols = c( | ||
dplyr::any_of(c(params, ".config", ".iter", ".eval_time")), | ||
starts_with("id") | ||
), | ||
names_from = .metric, | ||
values_from = dplyr::any_of(c(".estimate", "mean")) | ||
) | ||
} | ||
|
||
collector <- function(x, coll_col = ".predictions") { | ||
is_iterative <- any(colnames(x) == ".iter") | ||
if (is_iterative) { | ||
|
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm conflicted on the name and values for this argument. Very much open to suggestions. :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that a logical called
pivot_wider
is good. It implies that it is already long.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to check, we decided that it's unlikely that we'll ever deviate from something binary? No other forms of "wide" metrics anticipated?