Skip to content

Commit

Permalink
Add get_rsplit() helper (#399)
Browse files Browse the repository at this point in the history
* Add get_rsplit() helper

* Add to pkgdown index

* avoid potential awkwardness down the line

of adding a method for some other class than `rset` but the main arg being named `rset`

* make the snaps easier to read

by adding the call that generated the error

Co-authored-by: Hannah Frick <[email protected]>
  • Loading branch information
mikemahoney218 and hfrick authored Dec 7, 2022
1 parent 3127e37 commit 489d979
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 0 deletions.
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ S3method(complement,sliding_index_split)
S3method(complement,sliding_period_split)
S3method(complement,sliding_window_split)
S3method(dim,rsplit)
S3method(get_rsplit,default)
S3method(get_rsplit,rset)
S3method(labels,rset)
S3method(labels,rsplit)
S3method(labels,vfold_cv)
Expand Down Expand Up @@ -317,6 +319,7 @@ export(ends_with)
export(everything)
export(form_pred)
export(gather)
export(get_rsplit)
export(group_bootstraps)
export(group_initial_split)
export(group_mc_cv)
Expand Down
59 changes: 59 additions & 0 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,62 @@ non_random_classes <- c(
"rolling_origin",
"validation_time_split"
)

#' Retrieve individual rsplits objects from an rset
#'
#' @param x The `rset` object to retrieve an rsplit from.
#' @param index An integer indicating which rsplit to retrieve: `1` for the
#' rsplit in the first row of the rset, `2` for the second, and so on.
#' @inheritParams rlang::args_dots_empty
#'
#' @return The rsplit object in row `index` of `rset`
#'
#' @examples
#' set.seed(123)
#' (starting_splits <- group_vfold_cv(mtcars, cyl, v = 3))
#' get_rsplit(starting_splits, 1)
#'
#' @rdname get_rsplit
#' @export
get_rsplit <- function(x, index, ...) {
UseMethod("get_rsplit")
}

#' @rdname get_rsplit
#' @export
get_rsplit.rset <- function(x, index, ...) {
rlang::check_dots_empty()

n_rows <- nrow(x)

acceptable_index <- length(index) == 1 &&
rlang::is_integerish(index) &&
index > 0 &&
index <= n_rows

if (!acceptable_index) {
msg <- ifelse(
length(index) != 1,
glue::glue("Index was of length {length(index)}."),
glue::glue("A value of {index} was provided.")
)

rlang::abort(
c(
glue::glue("`index` must be a length-1 integer between 1 and {n_rows}."),
x = msg
)
)
}

x$splits[[index]]
}

#' @rdname get_rsplit
#' @export
get_rsplit.default <- function(x, index, ...) {
cls <- paste0("'", class(x), "'", collapse = ", ")
rlang::abort(
paste("No `get_rsplit()` method for this class(es)", cls)
)
}
1 change: 1 addition & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ reference:
- add_resample_id
- complement
- form_pred
- get_rsplit
- starts_with("labels")
- make_splits
- make_strata
Expand Down
34 changes: 34 additions & 0 deletions man/get_rsplit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 35 additions & 0 deletions tests/testthat/_snaps/misc.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,38 @@

`rset` must be an rset object

# get_rsplit()

Code
get_rsplit(val, 3)
Condition
Error in `get_rsplit()`:
! `index` must be a length-1 integer between 1 and 1.
x A value of 3 was provided.

---

Code
get_rsplit(val, c(1, 2))
Condition
Error in `get_rsplit()`:
! `index` must be a length-1 integer between 1 and 1.
x Index was of length 2.

---

Code
get_rsplit(val, 1.5)
Condition
Error in `get_rsplit()`:
! `index` must be a length-1 integer between 1 and 1.
x A value of 1.5 was provided.

---

Code
get_rsplit(warpbreaks, 1)
Condition
Error in `get_rsplit()`:
! No `get_rsplit()` method for this class(es) 'data.frame'

27 changes: 27 additions & 0 deletions tests/testthat/test-misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,30 @@ test_that("reshuffle_rset is working", {
expect_snapshot_error(reshuffle_rset(rset_subclasses[["manual_rset"]]$splits[[1]]))

})

test_that("get_rsplit()", {

val <- withr::with_seed(
11,
validation_split(warpbreaks)
)

expect_identical(val$splits[[1]], get_rsplit(val, 1))

expect_snapshot(error = TRUE,{
get_rsplit(val, 3)
})

expect_snapshot(error = TRUE,{
get_rsplit(val, c(1, 2))
})

expect_snapshot(error = TRUE,{
get_rsplit(val, 1.5)
})

expect_snapshot(error = TRUE,{
get_rsplit(warpbreaks, 1)
})

})

0 comments on commit 489d979

Please sign in to comment.