From 8ee5bf13cfec646da348a2a702037bd4c03bf10b Mon Sep 17 00:00:00 2001 From: DavisVaughan Date: Thu, 3 Sep 2020 14:01:10 -0400 Subject: [PATCH 1/3] Construct a "manual" rset for usage in `last_fit()` --- NAMESPACE | 1 + NEWS.md | 2 ++ R/last_fit.R | 23 +++++++++++++++-------- tests/testthat/test-last-fit.R | 2 +- 4 files changed, 19 insertions(+), 9 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 5f39cfd79..d289bdfb8 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -41,6 +41,7 @@ S3method(parameters,workflow) S3method(predict,conf_bound) S3method(predict,exp_improve) S3method(predict,prob_improve) +S3method(pretty,manual_rset) S3method(print,control_grid) S3method(print,prob_improve) S3method(print,tune_results) diff --git a/NEWS.md b/NEWS.md index 5ec5cc0ed..194d04e6b 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,8 @@ ## Bug Fixes +* `last_fit()` no longer accidentally adjusts the random seed (#264). + * Fixed two bugs in the acquisition function calculations. # tune 0.1.1 diff --git a/R/last_fit.R b/R/last_fit.R index 8d993214d..f156eddaf 100644 --- a/R/last_fit.R +++ b/R/last_fit.R @@ -122,13 +122,6 @@ last_fit.workflow <- function(object, split, ..., metrics = NULL) { last_fit_workflow(object, split, metrics) } -split_to_rset <- function(x) { - prop <- length(x$in_id)/nrow(x$data) - res <- rsample::mc_cv(x$data, times = 1, prop = prop) - res$splits[[1]] <- x - res -} - last_fit_workflow <- function(object, split, metrics) { extr <- function(x) x @@ -140,7 +133,6 @@ last_fit_workflow <- function(object, split, metrics) { metrics = metrics, control = ctrl ) - res$id[[1]] <- "train/test split" res$.workflow <- res$.extracts[[1]][[1]] res$.extracts <- NULL class(res) <- c("last_fit", class(res)) @@ -148,3 +140,18 @@ last_fit_workflow <- function(object, split, metrics) { res } +# Fake an rset for `fit_resamples()` +split_to_rset <- function(x) { + splits <- list(x) + ids <- "train/test split" + new_manual_rset(splits, ids) +} + +new_manual_rset <- function(splits, ids) { + rsample::new_rset(splits, ids, subclass = c("manual_rset", "rset")) +} + +#' @export +pretty.manual_rset <- function(x, ...) { + "Manual resampling" +} diff --git a/tests/testthat/test-last-fit.R b/tests/testthat/test-last-fit.R index 0b26caff0..467045718 100644 --- a/tests/testthat/test-last-fit.R +++ b/tests/testthat/test-last-fit.R @@ -32,7 +32,7 @@ test_that("recipe method", { test_that("split_to_rset", { res <- tune:::split_to_rset(split) - expect_true(inherits(res, "mc_cv")) + expect_true(inherits(res, "manual_rset")) expect_true(nrow(res) == 1) expect_true(nrow(res) == 1) From f772c3a36735b81c6bc343c21bbcc61e5b698386 Mon Sep 17 00:00:00 2001 From: DavisVaughan Date: Mon, 14 Sep 2020 08:28:59 -0400 Subject: [PATCH 2/3] Use dev rsample --- DESCRIPTION | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 3ddfafda3..ea17d894d 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -25,7 +25,7 @@ Imports: cli (>= 2.0.0), crayon, yardstick, - rsample, + rsample (>= 0.0.7.9000), tidyr, GPfit, foreach, @@ -48,3 +48,5 @@ LazyData: true Roxygen: list(markdown = TRUE) RoxygenNote: 7.1.1 Language: en-US +Remotes: + tidymodels/rsample From b81346121f6b8d6f54ea9cd44b246e2559deb50a Mon Sep 17 00:00:00 2001 From: DavisVaughan Date: Mon, 14 Sep 2020 08:29:20 -0400 Subject: [PATCH 3/3] Use `rsample::manual_rset()` --- NAMESPACE | 1 - R/last_fit.R | 20 +++----------------- tests/testthat/test-last-fit.R | 14 -------------- 3 files changed, 3 insertions(+), 32 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index d289bdfb8..5f39cfd79 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -41,7 +41,6 @@ S3method(parameters,workflow) S3method(predict,conf_bound) S3method(predict,exp_improve) S3method(predict,prob_improve) -S3method(pretty,manual_rset) S3method(print,control_grid) S3method(print,prob_improve) S3method(print,tune_results) diff --git a/R/last_fit.R b/R/last_fit.R index f156eddaf..85c4abcfe 100644 --- a/R/last_fit.R +++ b/R/last_fit.R @@ -126,10 +126,12 @@ last_fit_workflow <- function(object, split, metrics) { extr <- function(x) x ctrl <- control_resamples(save_pred = TRUE, extract = extr) + splits <- list(split) + resamples <- rsample::manual_rset(splits, ids = "train/test split") res <- fit_resamples( object, - resamples = split_to_rset(split), + resamples = resamples, metrics = metrics, control = ctrl ) @@ -139,19 +141,3 @@ last_fit_workflow <- function(object, split, metrics) { class(res) <- unique(class(res)) res } - -# Fake an rset for `fit_resamples()` -split_to_rset <- function(x) { - splits <- list(x) - ids <- "train/test split" - new_manual_rset(splits, ids) -} - -new_manual_rset <- function(splits, ids) { - rsample::new_rset(splits, ids, subclass = c("manual_rset", "rset")) -} - -#' @export -pretty.manual_rset <- function(x, ...) { - "Manual resampling" -} diff --git a/tests/testthat/test-last-fit.R b/tests/testthat/test-last-fit.R index 467045718..89b44ad5b 100644 --- a/tests/testthat/test-last-fit.R +++ b/tests/testthat/test-last-fit.R @@ -29,20 +29,6 @@ test_that("recipe method", { expect_equal(res$.predictions[[1]]$.pred, unname(test_pred)) }) -test_that("split_to_rset", { - - res <- tune:::split_to_rset(split) - expect_true(inherits(res, "manual_rset")) - expect_true(nrow(res) == 1) - expect_true(nrow(res) == 1) - - res <- linear_reg() %>% set_engine("lm") %>% last_fit(f, split) - expect_true(is.list(res$.workflow)) - expect_true(inherits(res$.workflow[[1]], "workflow")) - expect_true(is.list(res$.predictions)) - expect_true(inherits(res$.predictions[[1]], "tbl_df")) -}) - test_that("collect metrics of last fit", { res <- linear_reg() %>% set_engine("lm") %>% last_fit(f, split)