From 2f4ff973c1cb8b9cc876e90b9317314fe1fea09c Mon Sep 17 00:00:00 2001 From: Mike Mahoney Date: Wed, 7 Dec 2022 08:44:41 -0500 Subject: [PATCH 1/4] Add get_rsplit() helper --- NAMESPACE | 3 ++ R/misc.R | 60 ++++++++++++++++++++++++++++++++++++++ man/get_rsplit.Rd | 34 +++++++++++++++++++++ tests/testthat/test-misc.R | 27 +++++++++++++++++ 4 files changed, 124 insertions(+) create mode 100644 man/get_rsplit.Rd diff --git a/NAMESPACE b/NAMESPACE index 4020fccd..a3901bbd 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -14,6 +14,8 @@ S3method(complement,sliding_index_split) S3method(complement,sliding_period_split) S3method(complement,sliding_window_split) S3method(dim,rsplit) +S3method(get_rsplit,default) +S3method(get_rsplit,rset) S3method(labels,rset) S3method(labels,rsplit) S3method(labels,vfold_cv) @@ -316,6 +318,7 @@ export(ends_with) export(everything) export(form_pred) export(gather) +export(get_rsplit) export(group_bootstraps) export(group_initial_split) export(group_mc_cv) diff --git a/R/misc.R b/R/misc.R index fb4273e8..c4de1d0b 100644 --- a/R/misc.R +++ b/R/misc.R @@ -278,3 +278,63 @@ non_random_classes <- c( "rolling_origin", "validation_time_split" ) + +#' Retrieve individual rsplits objects from an rset +#' +#' @param rset The `rset` object to retrieve an rsplit from. +#' @param index An integer indicating which rsplit to retrieve: `1` for the +#' rsplit in the first row of the `rset`, `2` for the second, and so on. +#' @inheritParams rlang::args_dots_empty +#' +#' @return The rsplit object in row `index` of `rset` +#' +#' @examples +#' set.seed(123) +#' (starting_splits <- group_vfold_cv(mtcars, cyl, v = 3)) +#' get_rsplit(starting_splits, 1) +#' +#' @rdname get_rsplit +#' @export +get_rsplit <- function(rset, index, ...) { + UseMethod("get_rsplit") +} + +#' @rdname get_rsplit +#' @export +get_rsplit.rset <- function(rset, index, ...) 
{ + rlang::check_dots_empty() + + n_rows <- nrow(rset) + + acceptable_index <- length(index) == 1 && + rlang::is_integerish(index) && + index > 0 && + index <= n_rows + + if (!acceptable_index) { + msg <- ifelse( + length(index) != 1, + glue::glue("Index was of length {length(index)}."), + glue::glue("A value of {index} was provided.") + ) + + rlang::abort( + c( + glue::glue("`index` must be a length-1 integer between 1 and {n_rows}."), + x = msg + ) + ) + } + + rset$splits[[index]] + +} + +#' @rdname get_rsplit +#' @export +get_rsplit.default <- function(rset, index, ...) { + cls <- paste0("'", class(rset), "'", collapse = ", ") + rlang::abort( + paste("No `get_rsplit()` method for this class(es)", cls) + ) +} diff --git a/man/get_rsplit.Rd b/man/get_rsplit.Rd new file mode 100644 index 00000000..a142795f --- /dev/null +++ b/man/get_rsplit.Rd @@ -0,0 +1,34 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/misc.R +\name{get_rsplit} +\alias{get_rsplit} +\alias{get_rsplit.rset} +\alias{get_rsplit.default} +\title{Retrieve individual rsplits objects from an rset} +\usage{ +get_rsplit(rset, index, ...) + +\method{get_rsplit}{rset}(rset, index, ...) + +\method{get_rsplit}{default}(rset, index, ...) 
+} +\arguments{ +\item{rset}{The \code{rset} object to retrieve an rsplit from.} + +\item{index}{An integer indicating which rsplit to retrieve: \code{1} for the +rsplit in the first row of the \code{rset}, \code{2} for the second, and so on.} + +\item{...}{These dots are for future extensions and must be empty.} +} +\value{ +The rsplit object in row \code{index} of \code{rset} +} +\description{ +Retrieve individual rsplits objects from an rset +} +\examples{ +set.seed(123) +(starting_splits <- group_vfold_cv(mtcars, cyl, v = 3)) +get_rsplit(starting_splits, 1) + +} diff --git a/tests/testthat/test-misc.R b/tests/testthat/test-misc.R index afa69b31..9f19bfcd 100644 --- a/tests/testthat/test-misc.R +++ b/tests/testthat/test-misc.R @@ -141,3 +141,30 @@ test_that("reshuffle_rset is working", { expect_snapshot_error(reshuffle_rset(rset_subclasses[["manual_rset"]]$splits[[1]])) }) + +test_that("get_rsplit()", { + + val <- withr::with_seed( + 11, + validation_split(warpbreaks) + ) + + expect_identical(val$splits[[1]], get_rsplit(val, 1)) + + expect_snapshot_error( + get_rsplit(val, 3) + ) + + expect_snapshot_error( + get_rsplit(val, c(1, 2)) + ) + + expect_snapshot_error( + get_rsplit(val, 1.5) + ) + + expect_snapshot_error( + get_rsplit(warpbreaks, 1) + ) + +}) From 6f3e9a29e5ae84888c741603478c84a6a3b280ea Mon Sep 17 00:00:00 2001 From: Mike Mahoney Date: Wed, 7 Dec 2022 08:50:50 -0500 Subject: [PATCH 2/4] Add to pkgdown index --- _pkgdown.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/_pkgdown.yml b/_pkgdown.yml index b4fb6c84..5364308c 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -62,6 +62,7 @@ reference: - add_resample_id - complement - form_pred + - get_rsplit - starts_with("labels") - make_splits - make_strata From 85462cd6811acbaea33f84a52d9ded6cf7d0cdc7 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Wed, 7 Dec 2022 14:50:11 +0000 Subject: [PATCH 3/4] avoid potential awkwardness down the line of adding a method for some other class than `rset` but the 
main arg being named `rset` --- R/misc.R | 17 ++++++++--------- man/get_rsplit.Rd | 10 +++++----- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/R/misc.R b/R/misc.R index c4de1d0b..d19a96da 100644 --- a/R/misc.R +++ b/R/misc.R @@ -281,9 +281,9 @@ non_random_classes <- c( #' Retrieve individual rsplits objects from an rset #' -#' @param rset The `rset` object to retrieve an rsplit from. +#' @param x The `rset` object to retrieve an rsplit from. #' @param index An integer indicating which rsplit to retrieve: `1` for the -#' rsplit in the first row of the `rset`, `2` for the second, and so on. +#' rsplit in the first row of the rset, `2` for the second, and so on. #' @inheritParams rlang::args_dots_empty #' #' @return The rsplit object in row `index` of `rset` @@ -295,16 +295,16 @@ non_random_classes <- c( #' #' @rdname get_rsplit #' @export -get_rsplit <- function(rset, index, ...) { +get_rsplit <- function(x, index, ...) { UseMethod("get_rsplit") } #' @rdname get_rsplit #' @export -get_rsplit.rset <- function(rset, index, ...) { +get_rsplit.rset <- function(x, index, ...) { rlang::check_dots_empty() - n_rows <- nrow(rset) + n_rows <- nrow(x) acceptable_index <- length(index) == 1 && rlang::is_integerish(index) && @@ -326,14 +326,13 @@ get_rsplit.rset <- function(rset, index, ...) { ) } - rset$splits[[index]] - + x$splits[[index]] } #' @rdname get_rsplit #' @export -get_rsplit.default <- function(rset, index, ...) { - cls <- paste0("'", class(rset), "'", collapse = ", ") +get_rsplit.default <- function(x, index, ...) { + cls <- paste0("'", class(x), "'", collapse = ", ") rlang::abort( paste("No `get_rsplit()` method for this class(es)", cls) ) diff --git a/man/get_rsplit.Rd b/man/get_rsplit.Rd index a142795f..3eb81ae3 100644 --- a/man/get_rsplit.Rd +++ b/man/get_rsplit.Rd @@ -6,17 +6,17 @@ \alias{get_rsplit.default} \title{Retrieve individual rsplits objects from an rset} \usage{ -get_rsplit(rset, index, ...) +get_rsplit(x, index, ...) 
-\method{get_rsplit}{rset}(rset, index, ...) +\method{get_rsplit}{rset}(x, index, ...) -\method{get_rsplit}{default}(rset, index, ...) +\method{get_rsplit}{default}(x, index, ...) } \arguments{ -\item{rset}{The \code{rset} object to retrieve an rsplit from.} +\item{x}{The \code{rset} object to retrieve an rsplit from.} \item{index}{An integer indicating which rsplit to retrieve: \code{1} for the -rsplit in the first row of the \code{rset}, \code{2} for the second, and so on.} +rsplit in the first row of the rset, \code{2} for the second, and so on.} \item{...}{These dots are for future extensions and must be empty.} } From c3f2c795536ab6180389c32164e39c964313a42a Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Wed, 7 Dec 2022 14:51:14 +0000 Subject: [PATCH 4/4] make the snaps easier to read by adding the call that generated the error --- tests/testthat/_snaps/misc.md | 35 +++++++++++++++++++++++++++++++++++ tests/testthat/test-misc.R | 16 ++++++++-------- 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/tests/testthat/_snaps/misc.md b/tests/testthat/_snaps/misc.md index 2f7749e0..525a5567 100644 --- a/tests/testthat/_snaps/misc.md +++ b/tests/testthat/_snaps/misc.md @@ -140,3 +140,38 @@ `rset` must be an rset object +# get_rsplit() + + Code + get_rsplit(val, 3) + Condition + Error in `get_rsplit()`: + ! `index` must be a length-1 integer between 1 and 1. + x A value of 3 was provided. + +--- + + Code + get_rsplit(val, c(1, 2)) + Condition + Error in `get_rsplit()`: + ! `index` must be a length-1 integer between 1 and 1. + x Index was of length 2. + +--- + + Code + get_rsplit(val, 1.5) + Condition + Error in `get_rsplit()`: + ! `index` must be a length-1 integer between 1 and 1. + x A value of 1.5 was provided. + +--- + + Code + get_rsplit(warpbreaks, 1) + Condition + Error in `get_rsplit()`: + ! 
No `get_rsplit()` method for this class(es) 'data.frame' + diff --git a/tests/testthat/test-misc.R b/tests/testthat/test-misc.R index 9f19bfcd..9c6fa573 100644 --- a/tests/testthat/test-misc.R +++ b/tests/testthat/test-misc.R @@ -151,20 +151,20 @@ test_that("get_rsplit()", { expect_identical(val$splits[[1]], get_rsplit(val, 1)) - expect_snapshot_error( + expect_snapshot(error = TRUE, { get_rsplit(val, 3) - ) + }) - expect_snapshot_error( + expect_snapshot(error = TRUE, { get_rsplit(val, c(1, 2)) - ) + }) - expect_snapshot_error( + expect_snapshot(error = TRUE, { get_rsplit(val, 1.5) - ) + }) - expect_snapshot_error( + expect_snapshot(error = TRUE, { get_rsplit(warpbreaks, 1) - ) + }) })