Skip to content

Commit

Permalink
Add group_bootstraps() (#316)
Browse files Browse the repository at this point in the history
* Add group_initial_split

* Add group_validation_split

* Add group_validation_split() to docs

* Add group_bootstraps()

* Add group_bootstraps to pkgdown

* Error on 0-row assessment

* Test for 0-row assessment

* Tiny edits to error message, docs

* Update NEWS

Co-authored-by: Julia Silge <[email protected]>
  • Loading branch information
mikemahoney218 and juliasilge committed Jun 30, 2022
1 parent 38a52d2 commit c257ec5
Show file tree
Hide file tree
Showing 15 changed files with 367 additions and 12 deletions.
12 changes: 12 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ S3method(vec_cast,bootstraps.data.frame)
S3method(vec_cast,bootstraps.tbl_df)
S3method(vec_cast,data.frame.apparent)
S3method(vec_cast,data.frame.bootstraps)
S3method(vec_cast,data.frame.group_bootstraps)
S3method(vec_cast,data.frame.group_mc_cv)
S3method(vec_cast,data.frame.group_vfold_cv)
S3method(vec_cast,data.frame.loo_cv)
Expand All @@ -86,6 +87,9 @@ S3method(vec_cast,data.frame.sliding_period)
S3method(vec_cast,data.frame.sliding_window)
S3method(vec_cast,data.frame.validation_split)
S3method(vec_cast,data.frame.vfold_cv)
S3method(vec_cast,group_bootstraps.data.frame)
S3method(vec_cast,group_bootstraps.group_bootstraps)
S3method(vec_cast,group_bootstraps.tbl_df)
S3method(vec_cast,group_mc_cv.data.frame)
S3method(vec_cast,group_mc_cv.group_mc_cv)
S3method(vec_cast,group_mc_cv.tbl_df)
Expand Down Expand Up @@ -118,6 +122,7 @@ S3method(vec_cast,sliding_window.sliding_window)
S3method(vec_cast,sliding_window.tbl_df)
S3method(vec_cast,tbl_df.apparent)
S3method(vec_cast,tbl_df.bootstraps)
S3method(vec_cast,tbl_df.group_bootstraps)
S3method(vec_cast,tbl_df.group_mc_cv)
S3method(vec_cast,tbl_df.group_vfold_cv)
S3method(vec_cast,tbl_df.loo_cv)
Expand All @@ -144,6 +149,7 @@ S3method(vec_ptype2,bootstraps.data.frame)
S3method(vec_ptype2,bootstraps.tbl_df)
S3method(vec_ptype2,data.frame.apparent)
S3method(vec_ptype2,data.frame.bootstraps)
S3method(vec_ptype2,data.frame.group_bootstraps)
S3method(vec_ptype2,data.frame.group_mc_cv)
S3method(vec_ptype2,data.frame.group_vfold_cv)
S3method(vec_ptype2,data.frame.loo_cv)
Expand All @@ -156,6 +162,9 @@ S3method(vec_ptype2,data.frame.sliding_period)
S3method(vec_ptype2,data.frame.sliding_window)
S3method(vec_ptype2,data.frame.validation_split)
S3method(vec_ptype2,data.frame.vfold_cv)
S3method(vec_ptype2,group_bootstraps.data.frame)
S3method(vec_ptype2,group_bootstraps.group_bootstraps)
S3method(vec_ptype2,group_bootstraps.tbl_df)
S3method(vec_ptype2,group_mc_cv.data.frame)
S3method(vec_ptype2,group_mc_cv.group_mc_cv)
S3method(vec_ptype2,group_mc_cv.tbl_df)
Expand Down Expand Up @@ -188,6 +197,7 @@ S3method(vec_ptype2,sliding_window.sliding_window)
S3method(vec_ptype2,sliding_window.tbl_df)
S3method(vec_ptype2,tbl_df.apparent)
S3method(vec_ptype2,tbl_df.bootstraps)
S3method(vec_ptype2,tbl_df.group_bootstraps)
S3method(vec_ptype2,tbl_df.group_mc_cv)
S3method(vec_ptype2,tbl_df.group_vfold_cv)
S3method(vec_ptype2,tbl_df.loo_cv)
Expand All @@ -208,6 +218,7 @@ S3method(vec_ptype2,vfold_cv.tbl_df)
S3method(vec_ptype2,vfold_cv.vfold_cv)
S3method(vec_restore,apparent)
S3method(vec_restore,bootstraps)
S3method(vec_restore,group_bootstraps)
S3method(vec_restore,group_mc_cv)
S3method(vec_restore,group_vfold_cv)
S3method(vec_restore,loo_cv)
Expand Down Expand Up @@ -235,6 +246,7 @@ export(ends_with)
export(everything)
export(form_pred)
export(gather)
export(group_bootstraps)
export(group_initial_split)
export(group_mc_cv)
export(group_validation_split)
Expand Down
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

* Added arguments to control how `group_vfold_cv()` combines groups. Use `balance = "groups"` to assign (roughly) the same number of groups to each fold, or `balance = "observations"` to assign (roughly) the same number of observations to each fold.

* Added new functions for grouped resampling: `group_mc_cv()` (#313), `group_initial_split()` and `group_validation_split()` (#315).
* Added new functions for grouped resampling: `group_mc_cv()` (#313), `group_initial_split()` and `group_validation_split()` (#315), and `group_bootstraps()` (#316).

* Added a new function, `reverse_splits()`, to swap analysis and assessment splits (#319, #284).

Expand Down
100 changes: 99 additions & 1 deletion R/boot.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#' some estimators used by the `summary` function that require the apparent
#' error rate.
#' @export
#' @return An tibble with classes `bootstraps`, `rset`, `tbl_df`, `tbl`, and
#' @return A tibble with classes `bootstraps`, `rset`, `tbl_df`, `tbl`, and
#' `data.frame`. The results include a column for the data split objects and a
#' column called `id` that has a character string with the resample identifier.
#' @examples
Expand Down Expand Up @@ -146,3 +146,101 @@ boot_splits <-
id = names0(length(split_objs), "Bootstrap")
)
}

#' Group Bootstraps
#'
#' Group bootstrapping creates splits of the data based
#' on some grouping variable (which may have more than a single row
#' associated with it). A common use of this kind of resampling is when you
#' have repeated measures of the same subject.
#' A bootstrap sample is a sample that is the same size as the original data
#' set that is made using replacement. This results in analysis samples that
#' have multiple replicates of some of the original rows of the data. The
#' assessment set is defined as the rows of the original data that were not
#' included in the bootstrap sample. This is often referred to as the
#' "out-of-bag" (OOB) sample.
#' @details The argument `apparent` enables the option of an additional
#' "resample" where the analysis and assessment data sets are the same as the
#' original data set. This can be required for some types of analysis of the
#' bootstrap results.
#'
#' @inheritParams bootstraps
#' @inheritParams make_groups
#' @export
#' @return An tibble with classes `group_bootstraps` `bootstraps`, `rset`,
#' `tbl_df`, `tbl`, and `data.frame`. The results include a column for the data
#' split objects and a column called `id` that has a character string with the
#' resample identifier.
#' @examplesIf rlang::is_installed("modeldata")
#' data(ames, package = "modeldata")
#'
#' set.seed(13)
#' group_bootstraps(ames, Neighborhood, times = 3)
#' group_bootstraps(ames, Neighborhood, times = 3, apparent = TRUE)
#'
#' @export
group_bootstraps <- function(data,
group,
times = 25,
apparent = FALSE,
...) {

rlang::check_dots_empty()

group <- validate_group({{ group }}, data)

split_objs <-
group_boot_splits(
data = data,
group = group,
times = times
)

## We remove the holdout indices since it will save space and we can
## derive them later when they are needed.
split_objs$splits <- map(split_objs$splits, rm_out)

if (apparent) {
split_objs <- bind_rows(split_objs, apparent(data))
}

boot_att <- list(
times = times,
apparent = apparent,
strata = FALSE,
group = group
)

new_rset(
splits = split_objs$splits,
ids = split_objs$id,
attrib = boot_att,
subclass = c("group_bootstraps", "bootstraps", "rset")
)
}

group_boot_splits <- function(data, group, times = 25) {

group <- getElement(data, group)
n <- nrow(data)
indices <- make_groups(data, group, times, balance = "prop", prop = 1, replace = TRUE)
indices <- lapply(indices, boot_complement, n = n)
split_objs <-
purrr::map(indices, make_splits, data = data, class = c("group_boot_split", "boot_split"))
all_assessable <- purrr::map(split_objs, function(x) nrow(assessment(x)))

if (any(all_assessable == 0)) {
rlang::abort(
c(
"Some assessment sets contained zero rows",
i = "Consider using a non-grouped resampling method"
),
call = rlang::caller_env()
)
}

list(
splits = split_objs,
id = names0(length(split_objs), "Bootstrap")
)
}
1 change: 1 addition & 0 deletions R/compat-vctrs-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ test_data <- function() {
delayedAssign("rset_subclasses", {
list(
bootstraps = bootstraps(test_data()),
group_bootstraps = group_bootstraps(test_data(), y),
vfold_cv = vfold_cv(test_data(), v = 10, repeats = 2),
group_vfold_cv = group_vfold_cv(test_data(), y),
loo_cv = loo_cv(test_data()),
Expand Down
52 changes: 52 additions & 0 deletions R/compat-vctrs.R
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,58 @@ vec_cast.data.frame.bootstraps <- function(x, to, ..., x_arg = "", to_arg = "")
df_cast(x, to, ..., x_arg = x_arg, to_arg = to_arg)
}

# ------------------------------------------------------------------------------
# group_bootstraps

#' @export
vec_restore.group_bootstraps <- function(x, to, ...) {
rset_reconstruct(x, to)
}


#' @export
vec_ptype2.group_bootstraps.group_bootstraps <- function(x, y, ..., x_arg = "", y_arg = "") {
stop_never_called("vec_ptype2.group_bootstraps.group_bootstraps")
}
#' @export
vec_ptype2.group_bootstraps.tbl_df <- function(x, y, ..., x_arg = "", y_arg = "") {
stop_never_called("vec_ptype2.group_bootstraps.tbl_df")
}
#' @export
vec_ptype2.tbl_df.group_bootstraps <- function(x, y, ..., x_arg = "", y_arg = "") {
stop_never_called("vec_ptype2.tbl_df.group_bootstraps")
}
#' @export
vec_ptype2.group_bootstraps.data.frame <- function(x, y, ..., x_arg = "", y_arg = "") {
stop_never_called("vec_ptype2.group_bootstraps.data.frame")
}
#' @export
vec_ptype2.data.frame.group_bootstraps <- function(x, y, ..., x_arg = "", y_arg = "") {
stop_never_called("vec_ptype2.data.frame.group_bootstraps")
}


#' @export
vec_cast.group_bootstraps.group_bootstraps <- function(x, to, ..., x_arg = "", to_arg = "") {
stop_incompatible_cast_rset(x, to, x_arg = x_arg, to_arg = to_arg)
}
#' @export
vec_cast.group_bootstraps.tbl_df <- function(x, to, ..., x_arg = "", to_arg = "") {
stop_incompatible_cast_rset(x, to, x_arg = x_arg, to_arg = to_arg)
}
#' @export
vec_cast.tbl_df.group_bootstraps <- function(x, to, ..., x_arg = "", to_arg = "") {
tib_cast(x, to, ..., x_arg = x_arg, to_arg = to_arg)
}
#' @export
vec_cast.group_bootstraps.data.frame <- function(x, to, ..., x_arg = "", to_arg = "") {
stop_incompatible_cast_rset(x, to, x_arg = x_arg, to_arg = to_arg)
}
#' @export
vec_cast.data.frame.group_bootstraps <- function(x, to, ..., x_arg = "", to_arg = "") {
df_cast(x, to, ..., x_arg = x_arg, to_arg = to_arg)
}

# ------------------------------------------------------------------------------
# vfold_cv

Expand Down
20 changes: 15 additions & 5 deletions R/make_groups.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,25 +105,35 @@ balance_observations <- function(data_ind, v, ...) {

}

balance_prop <- function(prop, data_ind, v, ...) {
balance_prop <- function(prop, data_ind, v, replace = FALSE, ...) {
rlang::check_dots_empty()
if (!is.numeric(prop) | prop >= 1 | prop <= 0) {
acceptable_prop <- is.numeric(prop)
acceptable_prop <- acceptable_prop && ((prop <= 1 && replace) || (prop < 1 && !replace))
acceptable_prop <- acceptable_prop && prop > 0
if (!acceptable_prop) {
rlang::abort("`prop` must be a number on (0, 1).", call = rlang::caller_env())
}
n_obs <- nrow(data_ind)

freq_table <- vec_count(data_ind$..group)

n <- nrow(freq_table)
# If sampling with replacement,
# set `n` to the number of resamples we'd need
# if we somehow got the smallest group every time
if (replace) n <- n * prop * sum(freq_table$count) / min(freq_table$count)
n <- ceiling(n)

freq_table <- purrr::map_dfr(
seq_len(v),
function(x) {
freq_table <- freq_table[sample.int(nrow(freq_table)), ]
cumulative_proportion <- cumsum(freq_table$count) / sum(freq_table$count)
work_table <- freq_table[sample.int(nrow(freq_table), n, replace = replace), ]
cumulative_proportion <- cumsum(work_table$count) / sum(freq_table$count)
crosses_target <- which(cumulative_proportion > prop)[[1]]
is_closest <- cumulative_proportion[c(crosses_target, crosses_target - 1)]
is_closest <- which.min(abs(is_closest - prop)) - 1
crosses_target <- crosses_target - is_closest
out <- freq_table[seq_len(crosses_target), ]
out <- work_table[seq_len(crosses_target), ]
out$assignment <- x
out
}
Expand Down
18 changes: 14 additions & 4 deletions R/mc.R
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,7 @@ group_mc_cv <- function(data, group, prop = 3 / 4, times = 25, ...) {
data = data,
group = group,
prop = prop,
times = times,
balance = "prop"
times = times
)

## We remove the holdout indices since it will save space and we can
Expand All @@ -193,14 +192,25 @@ group_mc_cv <- function(data, group, prop = 3 / 4, times = 25, ...) {
)
}

group_mc_splits <- function(data, group, prop = 3 / 4, times = 25, balance = "prop") {
group_mc_splits <- function(data, group, prop = 3 / 4, times = 25) {

group <- getElement(data, group)
n <- nrow(data)
indices <- make_groups(data, group, times, balance, prop = prop)
indices <- make_groups(data, group, times, balance = "prop", prop = prop, replace = FALSE)
indices <- lapply(indices, mc_complement, n = n)
split_objs <-
purrr::map(indices, make_splits, data = data, class = "grouped_mc_split")
all_assessable <- purrr::map(split_objs, function(x) nrow(assessment(x)))

if (any(all_assessable == 0)) {
rlang::abort(
c(
"Some assessment sets contained zero rows",
i = "Consider using a non-grouped resampling method"
),
call = rlang::caller_env()
)
}
list(
splits = split_objs,
id = names0(length(split_objs), "Resample")
Expand Down
1 change: 1 addition & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ reference:
- loo_cv
- mc_cv
- validation_split
- group_bootstraps
- group_mc_cv
- group_vfold_cv
- rolling_origin
Expand Down
2 changes: 1 addition & 1 deletion man/bootstraps.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit c257ec5

Please sign in to comment.