Skip to content

Commit

Permalink
Placement of the eval_time argument (#857)
Browse files Browse the repository at this point in the history
  • Loading branch information
hfrick authored Feb 28, 2024
1 parent 4b1c857 commit a8f0772
Show file tree
Hide file tree
Showing 19 changed files with 170 additions and 117 deletions.
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@

* For iterative optimization routines, `autoplot()` will use integer breaks when `type = "performance"` or `type = "parameters"`.

## Breaking Change

* Several functions gain an `eval_time` argument for the evaluation time of dynamic metrics for censored regression. The placement of the argument breaks passing-by-position for one or more other arguments to `fit_best.tune_results()`, `show_best.tune_results()`, and the developer-focused `check_initial()` (#857).


# tune 1.1.2

* `last_fit()` now works with the 3-way validation split objects from `rsample::initial_validation_split()`. `last_fit()` and `fit_best()` now have a new argument `add_validation_set` to include or exclude the validation set in the dataset used to fit the model (#701).
Expand Down
12 changes: 9 additions & 3 deletions R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,13 @@ bayes_msg <- "`initial` should be a positive integer or the results of [tune_gri
#' @param wflow A `workflow` object.
#' @param resamples An `rset` object.
#' @param ctrl A `control_grid` object.
check_initial <- function(x, pset, wflow, resamples, metrics, ctrl, eval_time,
check_initial <- function(x,
pset,
wflow,
resamples,
metrics,
eval_time,
ctrl,
checks = "grid") {
if (is.null(x)) {
rlang::abort(bayes_msg)
Expand All @@ -424,9 +430,9 @@ check_initial <- function(x, pset, wflow, resamples, metrics, ctrl, eval_time,
resamples = resamples,
grid = x,
metrics = metrics,
eval_time = eval_time,
param_info = pset,
control = parsnip::condense_control(grid_ctrl, control_grid()),
eval_time = eval_time
control = parsnip::condense_control(grid_ctrl, control_grid())
)

if (ctrl$verbose) {
Expand Down
2 changes: 1 addition & 1 deletion R/fit_best.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ fit_best.default <- function(x, ...) {
#' @rdname fit_best
fit_best.tune_results <- function(x,
metric = NULL,
eval_time = NULL,
parameters = NULL,
verbose = FALSE,
eval_time = NULL,
add_validation_set = NULL,
...) {
if (length(list(...))) {
Expand Down
28 changes: 20 additions & 8 deletions R/last_fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ last_fit.model_fit <- function(object, ...) {
#' @export
#' @rdname last_fit
last_fit.model_spec <- function(object, preprocessor, split, ..., metrics = NULL,
control = control_last_fit(), eval_time = NULL,
eval_time = NULL, control = control_last_fit(),
add_validation_set = FALSE) {
if (rlang::is_missing(preprocessor) || !is_preprocessor(preprocessor)) {
rlang::abort(paste(
Expand All @@ -141,30 +141,42 @@ last_fit.model_spec <- function(object, preprocessor, split, ..., metrics = NULL
wflow <- add_formula(wflow, preprocessor)
}

last_fit_workflow(wflow, split, metrics, control, eval_time,
add_validation_set)
last_fit_workflow(
wflow,
split = split,
metrics = metrics,
eval_time = eval_time,
control = control,
add_validation_set = add_validation_set
)
}


#' @rdname last_fit
#' @export
last_fit.workflow <- function(object, split, ..., metrics = NULL,
control = control_last_fit(), eval_time = NULL,
eval_time = NULL, control = control_last_fit(),
add_validation_set = FALSE) {
empty_ellipses(...)

control <- parsnip::condense_control(control, control_last_fit())

last_fit_workflow(object, split, metrics, control, eval_time,
add_validation_set)
last_fit_workflow(
object,
split = split,
metrics = metrics,
eval_time = eval_time,
control = control,
add_validation_set = add_validation_set
)
}


last_fit_workflow <- function(object,
split,
metrics,
control,
eval_time = NULL,
control,
add_validation_set = FALSE,
...,
call = rlang::caller_env()) {
Expand Down Expand Up @@ -192,8 +204,8 @@ last_fit_workflow <- function(object,
workflow = object,
resamples = resamples,
metrics = metrics,
control = control,
eval_time = eval_time,
control = control,
rng = rng,
call = call
)
Expand Down
6 changes: 3 additions & 3 deletions R/plots.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ autoplot.tune_results <-
function(object,
type = c("marginals", "parameters", "performance"),
metric = NULL,
width = NULL,
eval_time = NULL,
width = NULL,
call = rlang::current_env(),
...) {
type <- match.arg(type)
Expand Down Expand Up @@ -108,7 +108,7 @@ autoplot.tune_results <-
p <- plot_param_vs_iter(object, call)
} else {
if (type == "performance") {
p <- plot_perf_vs_iter(object, metric, width, eval_time = eval_time, call)
p <- plot_perf_vs_iter(object, metric, eval_time = eval_time, width, call)
} else {
if (use_regular_grid_plot(object)) {
p <- plot_regular_grid(object, metric = metric, eval_time = eval_time, call, ...)
Expand Down Expand Up @@ -278,7 +278,7 @@ process_autoplot_metrics <- function(x, metric, eval_time) {

# ------------------------------------------------------------------------------

plot_perf_vs_iter <- function(x, metric = NULL, width = NULL, eval_time = NULL,
plot_perf_vs_iter <- function(x, metric = NULL, eval_time = NULL, width = NULL,
call = rlang::caller_env()) {
if (is.null(width)) {
width <- max(x$.iter) / 75
Expand Down
20 changes: 10 additions & 10 deletions R/resample.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ fit_resamples.model_spec <- function(object,
resamples,
...,
metrics = NULL,
control = control_resamples(),
eval_time = NULL) {
eval_time = NULL,
control = control_resamples()) {
if (rlang::is_missing(preprocessor) || !is_preprocessor(preprocessor)) {
rlang::abort(paste(
"To tune a model spec, you must preprocess",
Expand All @@ -99,8 +99,8 @@ fit_resamples.model_spec <- function(object,
wflow,
resamples = resamples,
metrics = metrics,
control = control,
eval_time = eval_time
eval_time = eval_time,
control = control
)
}

Expand All @@ -111,8 +111,8 @@ fit_resamples.workflow <- function(object,
resamples,
...,
metrics = NULL,
control = control_resamples(),
eval_time = NULL) {
eval_time = NULL,
control = control_resamples()) {
empty_ellipses(...)

control <- parsnip::condense_control(control, control_resamples())
Expand All @@ -122,8 +122,8 @@ fit_resamples.workflow <- function(object,
workflow = object,
resamples = resamples,
metrics = metrics,
control = control,
eval_time = eval_time,
control = control,
rng = TRUE
)
.stash_last_result(res)
Expand All @@ -132,8 +132,8 @@ fit_resamples.workflow <- function(object,

# ------------------------------------------------------------------------------

resample_workflow <- function(workflow, resamples, metrics, control,
eval_time = NULL, rng, call = caller_env()) {
resample_workflow <- function(workflow, resamples, metrics, eval_time = NULL,
control, rng, call = caller_env()) {
check_no_tuning(workflow)

# `NULL` is the signal that we have no grid to tune with
Expand All @@ -145,9 +145,9 @@ resample_workflow <- function(workflow, resamples, metrics, control,
resamples = resamples,
grid = grid,
metrics = metrics,
eval_time = eval_time,
pset = pset,
control = control,
eval_time = eval_time,
rng = rng,
call = call
)
Expand Down
9 changes: 7 additions & 2 deletions R/select_best.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,12 @@ show_best.default <- function(x, ...) {

#' @export
#' @rdname show_best
show_best.tune_results <- function(x, metric = NULL, n = 5, eval_time = NULL, ..., call = rlang::current_env()) {
show_best.tune_results <- function(x,
metric = NULL,
eval_time = NULL,
n = 5,
...,
call = rlang::current_env()) {
rlang::check_dots_empty()

metric_info <- choose_metric(x, metric, call = call)
Expand Down Expand Up @@ -141,7 +146,7 @@ select_by_pct_loss.default <- function(x, ...) {

#' @export
#' @rdname show_best
select_by_pct_loss.tune_results <- function(x, ..., metric = NULL, limit = 2, eval_time = NULL) {
select_by_pct_loss.tune_results <- function(x, ..., metric = NULL, eval_time = NULL, limit = 2) {
metric_info <- choose_metric(x, metric)
metric <- metric_info$metric

Expand Down
72 changes: 48 additions & 24 deletions R/tune_bayes.R
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,10 @@ tune_bayes.model_spec <- function(object,
iter = 10,
param_info = NULL,
metrics = NULL,
eval_time = NULL,
objective = exp_improve(),
initial = 5,
control = control_bayes(),
eval_time = NULL) {
control = control_bayes()) {
if (rlang::is_missing(preprocessor) || !is_preprocessor(preprocessor)) {
rlang::abort(paste(
"To tune a model spec, you must preprocess",
Expand All @@ -222,9 +222,15 @@ tune_bayes.model_spec <- function(object,

tune_bayes_workflow(
wflow,
resamples = resamples, iter = iter, param_info = param_info,
metrics = metrics, objective = objective, initial = initial,
control = control, eval_time = eval_time, ...
resamples = resamples,
iter = iter,
param_info = param_info,
metrics = metrics,
eval_time = eval_time,
objective = objective,
initial = initial,
control = control,
...
)
}

Expand All @@ -238,30 +244,43 @@ tune_bayes.workflow <-
iter = 10,
param_info = NULL,
metrics = NULL,
eval_time = NULL,
objective = exp_improve(),
initial = 5,
control = control_bayes(),
eval_time = NULL) {
control = control_bayes()) {

# set `seed` so that calling `control_bayes()` doesn't alter RNG state (#721)
control <- parsnip::condense_control(control, control_bayes(seed = 1))

res <-
tune_bayes_workflow(
object,
resamples = resamples, iter = iter, param_info = param_info,
metrics = metrics, objective = objective, initial = initial,
control = control, eval_time = eval_time, ...
resamples = resamples,
iter = iter,
param_info = param_info,
metrics = metrics,
eval_time = eval_time,
objective = objective,
initial = initial,
control = control,
...
)
.stash_last_result(res)
res
}

tune_bayes_workflow <-
function(object, resamples, iter = 10, param_info = NULL, metrics = NULL,
objective = exp_improve(),
initial = 5, control, eval_time = NULL, ...,
call = caller_env()) {
tune_bayes_workflow <- function(object,
resamples,
iter = 10,
param_info = NULL,
metrics = NULL,
eval_time = NULL,
objective = exp_improve(),
initial = 5,
control,
...,
call = caller_env()) {

start_time <- proc.time()[3]

initialize_catalog(control = control)
Expand All @@ -287,8 +306,13 @@ tune_bayes_workflow <-
check_backend_options(control$backend_options)

unsummarized <- check_initial(
initial, param_info, object, resamples,
metrics, control, eval_time,
initial,
pset = param_info,
wflow = object,
resamples = resamples,
metrics = metrics,
eval_time = eval_time,
ctrl = control,
checks = "bayes"
)

Expand Down Expand Up @@ -423,9 +447,9 @@ tune_bayes_workflow <-
resamples = resamples,
candidates = candidates,
metrics = metrics,
eval_time = eval_time,
control = control,
param_info = param_info,
eval_time = eval_time
param_info = param_info
)

check_time(start_time, control$time_limit)
Expand Down Expand Up @@ -557,7 +581,7 @@ encode_set <- function(x, pset, as_matrix = FALSE, ...) {
x
}

fit_gp <- function(dat, pset, metric, control, eval_time = NULL, ...) {
fit_gp <- function(dat, pset, metric, eval_time = NULL, control, ...) {
dat <- dat %>% dplyr::filter(.metric == metric)

if (!is.null(eval_time)) {
Expand Down Expand Up @@ -738,8 +762,8 @@ initial_info <- function(stats, metrics, maximize, eval_time) {
# ------------------------------------------------------------------------------


more_results <- function(object, resamples, candidates, metrics, control,
param_info, eval_time = NULL) {
more_results <- function(object, resamples, candidates, metrics,
eval_time = NULL, control, param_info) {
tune_log(control, split = NULL, task = "Estimating performance", type = "info")

candidates <- candidates[, !(names(candidates) %in% c(".mean", ".sd", "objective"))]
Expand All @@ -753,8 +777,8 @@ more_results <- function(object, resamples, candidates, metrics, control,
param_info = param_info,
grid = candidates,
metrics = metrics,
control = control,
eval_time = eval_time
eval_time = eval_time,
control = control
),
silent = TRUE
)
Expand Down
Loading

0 comments on commit a8f0772

Please sign in to comment.