From 2696fba59ce3522ff11b6c3f895be347f2eef028 Mon Sep 17 00:00:00 2001 From: Mike Mahoney Date: Fri, 1 Jul 2022 14:55:40 -0400 Subject: [PATCH 1/3] Add repeats argument (#327) --- R/vfold.R | 35 +++++++++++++++++++++++++++++++--- man/group_vfold_cv.Rd | 4 ++++ tests/testthat/_snaps/vfold.md | 12 ++++++++++++ tests/testthat/test-vfold.R | 24 +++++++++++++++++++++++ 4 files changed, 72 insertions(+), 3 deletions(-) diff --git a/R/vfold.R b/R/vfold.R index 56d418bc..4c85a504 100644 --- a/R/vfold.R +++ b/R/vfold.R @@ -78,6 +78,11 @@ vfold_cv <- function(data, v = 10, repeats = 1, strata = strata, breaks = breaks, pool = pool ) } else { + if (v == nrow(data)) { + rlang::abort( + glue::glue("Repeated resampling when `v` is {v} would create identical resamples") + ) + } for (i in 1:repeats) { tmp <- vfold_splits(data = data, v = v, strata = strata, pool = pool) tmp$id2 <- tmp$id @@ -174,14 +179,38 @@ vfold_splits <- function(data, v = 10, strata = NULL, breaks = 4, pool = 0.1) { #' v = 5, #' balance = "observations" #' ) +#' group_vfold_cv(ames, group = Neighborhood, v = 5, repeats = 2) #' #' @export -group_vfold_cv <- function(data, group = NULL, v = NULL, balance = c("groups", "observations"), ...) { +group_vfold_cv <- function(data, group = NULL, v = NULL, balance = c("groups", "observations"), repeats = 1, ...) { group <- validate_group({{ group }}, data) balance <- rlang::arg_match(balance) - split_objs <- group_vfold_splits(data = data, group = group, v = v, balance = balance) + if (repeats == 1) { + split_objs <- group_vfold_splits(data = data, group = group, v = v, balance = balance) + } else { + if (is.null(v)) { + rlang::abort( + glue::glue("Repeated resampling when `v` is `NULL` would create identical resamples") + ) + } + if (v == length(unique(getElement(data, group)))) { + rlang::abort( + glue::glue("Repeated resampling when `v` is {v} would create identical resamples") + ) + } + for (i in 1:repeats) { + tmp <- group_vfold_splits(data = data, group = group, v = v, balance = balance) + tmp$id2 <- tmp$id + tmp$id <- names0(repeats, "Repeat")[i] + split_objs <- if (i == 1) { + tmp + } else { + rbind(split_objs, tmp) + } + } + } ## We remove the holdout indices since it will save space and we can ## derive them later when they are needed. @@ -213,7 +242,7 @@ group_vfold_splits <- function(data, group, v = NULL, balance) { if (is.null(v)) { v <- max_v } else { - check_v(v = v, max_v = max_v, rows = "rows", call = rlang::caller_env()) + check_v(v = v, max_v = max_v, rows = "groups", call = rlang::caller_env()) } indices <- make_groups(data, group, v, balance) diff --git a/man/group_vfold_cv.Rd b/man/group_vfold_cv.Rd index 8e922c70..4d5c8d35 100644 --- a/man/group_vfold_cv.Rd +++ b/man/group_vfold_cv.Rd @@ -9,6 +9,7 @@ group_vfold_cv( group = NULL, v = NULL, balance = c("groups", "observations"), + repeats = 1, ... ) } @@ -26,6 +27,8 @@ will be set to the number of unique values in the group.} groups be combined into folds? Should be one of \code{"groups"} or \code{"observations"}.} +\item{repeats}{The number of times to repeat the V-fold partitioning.} + \item{...}{Not currently used.} } \value{ @@ -55,5 +58,6 @@ group_vfold_cv( v = 5, balance = "observations" ) +group_vfold_cv(ames, group = Neighborhood, v = 5, repeats = 2) \dontshow{\}) # examplesIf} } diff --git a/tests/testthat/_snaps/vfold.md b/tests/testthat/_snaps/vfold.md index 32499cdc..62dd9843 100644 --- a/tests/testthat/_snaps/vfold.md +++ b/tests/testthat/_snaps/vfold.md @@ -15,6 +15,10 @@ The number of rows is less than `v = 500` +--- + + Repeated resampling when `v` is 150 would create identical resamples + # printing Code @@ -35,6 +39,14 @@ 9 Fold09 10 Fold10 +# grouping -- bad args + + Repeated resampling when `v` is 4 would create identical resamples + +--- + + Repeated resampling when `v` is `NULL` would create identical resamples + # grouping -- other balance methods Code diff --git a/tests/testthat/test-vfold.R b/tests/testthat/test-vfold.R index 86d4d737..0f19f9ef 100644 --- a/tests/testthat/test-vfold.R +++ b/tests/testthat/test-vfold.R @@ -79,6 +79,7 @@ test_that("bad args", { expect_error(vfold_cv(iris, strata = c("Species", "Sepal.Width"))) expect_snapshot_error(vfold_cv(iris, v = -500)) expect_snapshot_error(vfold_cv(iris, v = 500)) + expect_snapshot_error(vfold_cv(iris, v = 150, repeats = 2)) }) test_that("printing", { @@ -104,6 +105,8 @@ test_that("grouping -- bad args", { expect_error(group_vfold_cv(warpbreaks, group = "tensio")) expect_error(group_vfold_cv(warpbreaks)) expect_error(group_vfold_cv(warpbreaks, group = "tension", v = 10)) + expect_snapshot_error(group_vfold_cv(dat1, c, v = 4, repeats = 4)) + expect_snapshot_error(group_vfold_cv(dat1, c, repeats = 4)) }) @@ -219,6 +222,27 @@ test_that("grouping -- other balance methods", { }) +test_that("grouping -- repeated", { + set.seed(11) + rs2 <- group_vfold_cv(dat1, c, v = 3, repeats = 4) + sizes2 <- dim_rset(rs2) + + same_data <- + purrr::map_lgl(rs2$splits, function(x) { + all.equal(x$data, dat1) + }) + expect_true(all(same_data)) + + good_holdout <- purrr::map_lgl( + rs2$splits, + function(x) { + length(intersect(x$in_ind, x$out_id)) == 0 + } + ) + expect_true(all(good_holdout)) + +}) + test_that("grouping -- printing", { expect_snapshot(group_vfold_cv(warpbreaks, "tension")) }) From 1451551bc6babcf547c331e822e9dea11d1484d0 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Fri, 1 Jul 2022 17:15:05 -0600 Subject: [PATCH 2/3] Update NEWS --- NEWS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NEWS.md b/NEWS.md index 109f465c..8ffb9606 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,8 @@ * Added arguments to control how `group_vfold_cv()` combines groups. Use `balance = "groups"` to assign (roughly) the same number of groups to each fold, or `balance = "observations"` to assign (roughly) the same number of observations to each fold. +* Added a `repeats` argument to `group_vfold_cv()` (#330). + * Added new functions for grouped resampling: `group_mc_cv()` (#313), `group_initial_split()` and `group_validation_split()` (#315), and `group_bootstraps()` (#316). * Added a new function, `reverse_splits()`, to swap analysis and assessment splits (#319, #284). From 62a6f2ff539e808fb95026267f940c92de22ab24 Mon Sep 17 00:00:00 2001 From: Julia Silge Date: Fri, 1 Jul 2022 17:15:15 -0600 Subject: [PATCH 3/3] Take out extra glue() --- R/vfold.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/vfold.R b/R/vfold.R index 4c85a504..5171eb4f 100644 --- a/R/vfold.R +++ b/R/vfold.R @@ -192,7 +192,7 @@ group_vfold_cv <- function(data, group = NULL, v = NULL, balance = c("groups", " } else { if (is.null(v)) { rlang::abort( - glue::glue("Repeated resampling when `v` is `NULL` would create identical resamples") + "Repeated resampling when `v` is `NULL` would create identical resamples" ) } if (v == length(unique(getElement(data, group)))) {