Skip to content

Commit

Permalink
Use rsample::manual_rset()
Browse files Browse the repository at this point in the history
  • Loading branch information
DavisVaughan committed Sep 14, 2020
1 parent f772c3a commit b813461
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 32 deletions.
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ S3method(parameters,workflow)
S3method(predict,conf_bound)
S3method(predict,exp_improve)
S3method(predict,prob_improve)
S3method(pretty,manual_rset)
S3method(print,control_grid)
S3method(print,prob_improve)
S3method(print,tune_results)
Expand Down
20 changes: 3 additions & 17 deletions R/last_fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,12 @@ last_fit_workflow <- function(object, split, metrics) {
extr <- function(x)
x
ctrl <- control_resamples(save_pred = TRUE, extract = extr)
splits <- list(split)
resamples <- rsample::manual_rset(splits, ids = "train/test split")
res <-
fit_resamples(
object,
resamples = split_to_rset(split),
resamples = resamples,
metrics = metrics,
control = ctrl
)
Expand All @@ -139,19 +141,3 @@ last_fit_workflow <- function(object, split, metrics) {
class(res) <- unique(class(res))
res
}

# Fake an rset for `fit_resamples()`
split_to_rset <- function(x) {
splits <- list(x)
ids <- "train/test split"
new_manual_rset(splits, ids)
}

new_manual_rset <- function(splits, ids) {
rsample::new_rset(splits, ids, subclass = c("manual_rset", "rset"))
}

#' @export
pretty.manual_rset <- function(x, ...) {
"Manual resampling"
}
14 changes: 0 additions & 14 deletions tests/testthat/test-last-fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,6 @@ test_that("recipe method", {
expect_equal(res$.predictions[[1]]$.pred, unname(test_pred))
})

test_that("split_to_rset", {

res <- tune:::split_to_rset(split)
expect_true(inherits(res, "manual_rset"))
expect_true(nrow(res) == 1)
expect_true(nrow(res) == 1)

res <- linear_reg() %>% set_engine("lm") %>% last_fit(f, split)
expect_true(is.list(res$.workflow))
expect_true(inherits(res$.workflow[[1]], "workflow"))
expect_true(is.list(res$.predictions))
expect_true(inherits(res$.predictions[[1]], "tbl_df"))
})

test_that("collect metrics of last fit", {

res <- linear_reg() %>% set_engine("lm") %>% last_fit(f, split)
Expand Down

0 comments on commit b813461

Please sign in to comment.