Skip to content

Commit

Permalink
Add a first draft of stratification with groups
Browse files Browse the repository at this point in the history
Addresses #317
  • Loading branch information
mikemahoney218 committed Aug 19, 2022
1 parent 585b8fc commit b98e0ed
Show file tree
Hide file tree
Showing 9 changed files with 215 additions and 18 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# rsample (development version)

* `group_vfold_cv()` now supports stratification, so long as `balance = "groups"` (the default). Strata must be constant within each group (#317).

* `group_bootstraps()` now warns if resampling returns any empty assessment sets (previously had been an error) (#356) (#357).

* `bootstraps()` now warns if resampling returns any empty assessment sets (previously had no message or warning) (#356) (#357).
Expand Down
60 changes: 53 additions & 7 deletions R/make_groups.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,30 @@ make_groups <- function(data,
group,
v,
balance = c("groups", "observations", "prop"),
strata = NULL,
...) {
rlang::check_dots_used(call = rlang::caller_env())
balance <- rlang::arg_match(balance, error_call = rlang::caller_env())

data_ind <- tibble(..index = 1:nrow(data), ..group = group)
data_ind <- tibble(..index = 1:nrow(data), ..group = group, ..strata = strata)
data_ind$..group <- as.character(data_ind$..group)

res <- switch(
balance,
"groups" = balance_groups(data_ind = data_ind, v = v, ...),
"observations" = balance_observations(data_ind = data_ind, v = v, ...),
"prop" = balance_prop(data_ind = data_ind, v = v, ...)
)
res <- if (is.null(strata)) {
switch(
balance,
"groups" = balance_groups(data_ind = data_ind, v = v, ...),
"observations" = balance_observations(data_ind = data_ind, v = v, ...),
"prop" = balance_prop(data_ind = data_ind, v = v, ...)
)
} else {
data_ind$..strata <- as.character(data_ind$..strata)
switch(
balance,
"groups" = balance_groups_strata(data_ind = data_ind, v = v, ...),
"observations" = balance_observations_strata(data_ind = data_ind, v = v, ...),
"prop" = balance_prop_strata(data_ind = data_ind, v = v, ...)
)
}

data_ind <- res$data_ind
keys <- res$keys
Expand Down Expand Up @@ -72,6 +83,23 @@ balance_groups <- function(data_ind, v, ...) {
)
}

balance_groups_strata <- function(data_ind, v, ...) {
rlang::check_dots_empty()

keys <- vctrs::vec_unique(data_ind[c("..group", "..strata")])
keys <- split_unnamed(keys, keys$..strata)
keys <- purrr::map(keys, add_vfolds, v = v)
keys <- dplyr::bind_rows(keys)
keys <- data.frame(
..group = keys$..group,
..folds = keys$folds
)
list(
data_ind = data_ind,
keys = keys
)
}

balance_observations <- function(data_ind, v, ...) {
rlang::check_dots_empty()
n_obs <- nrow(data_ind)
Expand Down Expand Up @@ -105,6 +133,15 @@ balance_observations <- function(data_ind, v, ...) {

}

balance_observations_strata <- function(data_ind, v, ...) {
rlang::abort(
c(
"`balance = 'observations'` has not yet been implemented for grouped resampling with stratification.",
i = "Consider setting `balance = 'groups'`"
)
)
}

balance_prop <- function(prop, data_ind, v, replace = FALSE, ...) {
rlang::check_dots_empty()
acceptable_prop <- is.numeric(prop)
Expand Down Expand Up @@ -143,6 +180,15 @@ balance_prop <- function(prop, data_ind, v, replace = FALSE, ...) {

}

balance_prop_strata <- function(...) {
rlang::abort(
c(
"`balance = 'prop'` has not yet been implemented for grouped resampling with stratification, and should not be usable at all.",
i = "Please open an issue at https://github.com/tidymodels/rsample/issues with the code you ran that returned this error."
)
)
}

collapse_groups <- function(freq_table, data_ind, v) {
data_ind <- dplyr::left_join(data_ind, freq_table, by = c("..group" = "key"))
data_ind$..group <- data_ind$assignment
Expand Down
49 changes: 43 additions & 6 deletions R/vfold.R
Original file line number Diff line number Diff line change
Expand Up @@ -194,14 +194,32 @@ vfold_splits <- function(data, v = 10, strata = NULL, breaks = 4, pool = 0.1) {
#' # Leave-one-group-out CV
#' group_vfold_cv(ames, group = Neighborhood)
#'
#' library(dplyr)
#' data(Sacramento, package = "modeldata")
#'
#' city_strata <- Sacramento %>%
#' group_by(city) %>%
#' summarize(strata = mean(price)) %>%
#' summarize(city = city,
#' strata = cut(strata, quantile(strata), include.lowest = TRUE))
#'
#' sacramento_data <- Sacramento %>%
#' full_join(city_strata, by = "city")
#'
#' group_vfold_cv(sacramento_data, city, strata = strata)
#'
#' @export
group_vfold_cv <- function(data, group = NULL, v = NULL, repeats = 1, balance = c("groups", "observations"), ...) {
group_vfold_cv <- function(data, group = NULL, v = NULL, repeats = 1, balance = c("groups", "observations"), ..., strata = NULL, pool = 0.1) {

group <- validate_group({{ group }}, data)
balance <- rlang::arg_match(balance)

if (!missing(strata)) {
strata <- check_grouped_strata({{ group }}, {{ strata }}, pool, data)
}

if (repeats == 1) {
split_objs <- group_vfold_splits(data = data, group = group, v = v, balance = balance)
split_objs <- group_vfold_splits(data = data, group = group, v = v, balance = balance, strata = strata, pool = pool)
} else {
if (is.null(v)) {
rlang::abort(
Expand All @@ -214,7 +232,7 @@ group_vfold_cv <- function(data, group = NULL, v = NULL, repeats = 1, balance =
)
}
for (i in 1:repeats) {
tmp <- group_vfold_splits(data = data, group = group, v = v, balance = balance)
tmp <- group_vfold_splits(data = data, group = group, v = v, balance = balance, strata = strata, pool = pool)
tmp$id2 <- tmp$id
tmp$id <- names0(repeats, "Repeat")[i]
split_objs <- if (i == 1) {
Expand All @@ -237,7 +255,7 @@ group_vfold_cv <- function(data, group = NULL, v = NULL, repeats = 1, balance =

## Save some overall information

cv_att <- list(v = v, group = group, balance = balance, repeats = 1, strata = FALSE)
cv_att <- list(v = v, group = group, balance = balance, repeats = 1, strata = strata, pool = pool)

new_rset(
splits = split_objs$splits,
Expand All @@ -247,9 +265,14 @@ group_vfold_cv <- function(data, group = NULL, v = NULL, repeats = 1, balance =
)
}

group_vfold_splits <- function(data, group, v = NULL, balance) {
group_vfold_splits <- function(data, group, v = NULL, balance, strata = NULL, pool = 0.1) {

group <- getElement(data, group)
if (!is.null(strata)) {
strata <- getElement(data, strata)
strata <- as.character(strata)
strata <- make_strata(strata, pool = pool)
}
max_v <- length(unique(group))

if (is.null(v)) {
Expand All @@ -258,7 +281,7 @@ group_vfold_splits <- function(data, group, v = NULL, balance) {
check_v(v = v, max_v = max_v, rows = "groups", call = rlang::caller_env())
}

indices <- make_groups(data, group, v, balance)
indices <- make_groups(data, group, v, balance, strata)
indices <- lapply(indices, default_complement, n = nrow(data))
split_objs <-
purrr::map(indices,
Expand Down Expand Up @@ -286,3 +309,17 @@ check_v <- function(v, max_v, rows = "rows", call = rlang::caller_env()) {
)
}
}

check_grouped_strata <- function(group, strata, pool, data) {

strata <- tidyselect::vars_select(names(data), !!enquo(strata))
grouped_table <- tibble(
group = getElement(data, group),
strata = getElement(data, strata)
)

if (nrow(vctrs::vec_unique(grouped_table)) !=
nrow(vctrs::vec_unique(grouped_table["group"]))) {
rlang::abort("`strata` must be constant across all members of each `group`.")
}
}
27 changes: 26 additions & 1 deletion man/group_vfold_cv.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 12 additions & 1 deletion man/make_groups.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion man/rsample-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions tests/testthat/_snaps/vfold.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,20 @@
4 <split [2278/652]> Resample4
5 <split [2347/583]> Resample5

# grouping -- strata

Code
sizes4
Output
# A tibble: 5 x 5
analysis assessment n p id
<int> <int> <int> <int> <chr>
1 80087 19913 100000 3 Resample1
2 79903 20097 100000 3 Resample2
3 80163 19837 100000 3 Resample3
4 79867 20133 100000 3 Resample4
5 79980 20020 100000 3 Resample5

# grouping -- printing

Code
Expand Down
30 changes: 28 additions & 2 deletions tests/testthat/test-misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ test_that("reshuffle_rset is working", {
# with non-default arguments,
# is supported by reshuffled_resample

# Select any function in rset_subclasses with a strata argument:
# Select any non-grouped function in rset_subclasses with a strata argument:
supports_strata <- purrr::map_lgl(
names(supported_subclasses),
~ any(names(formals(.x)) == "strata")
~ any(names(formals(.x)) == "strata") && !any(names(formals(.x)) == "group")
)
supports_strata <- names(supported_subclasses)[supports_strata]

Expand All @@ -86,6 +86,32 @@ test_that("reshuffle_rset is working", {
expect_identical(resample, reshuffled_resample)
}

# Select any grouped function in rset_subclasses with a strata argument:
grouped_strata <- purrr::map_lgl(
names(supported_subclasses),
~ any(names(formals(.x)) == "strata") && any(names(formals(.x)) == "group")
)
grouped_strata <- names(supported_subclasses)[grouped_strata]

for (i in seq_along(grouped_strata)) {
# Fit those functions with non-default arguments:
set.seed(123)
resample <- do.call(
grouped_strata[i],
list(
data = cbind(test_data(), z = rep(1:5, each = 10)),
group = "y",
strata = "z",
breaks = 2,
pool = 0.2
)
)
# Reshuffle them under the same seed to ensure they're identical
set.seed(123)
reshuffled_resample <- reshuffle_rset(resample)
expect_identical(resample, reshuffled_resample)
}

for (i in seq_along(non_random_classes)) {
expect_snapshot(
reshuffle_rset(rset_subclasses[[non_random_classes[[i]]]])
Expand Down
34 changes: 34 additions & 0 deletions tests/testthat/test-vfold.R
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,40 @@ test_that("grouping -- other balance methods", {

})

test_that("grouping -- strata", {
set.seed(11)

group_table <- tibble(
group = 1:100,
outcome = sample(c(rep(0, 90), rep(1, 10)))
)
observation_table <- tibble(
group = sample(1:100, 1e5, replace = TRUE),
observation = 1:1e5
)
sample_data <- dplyr::full_join(group_table, observation_table, by = "group")
rs4 <- group_vfold_cv(sample_data, group, v = 5, strata = outcome)
sizes4 <- dim_rset(rs4)
expect_snapshot(sizes4)

rate <- purrr::map_dbl(
rs4$splits,
function(x) {
dat <- as.data.frame(x)$outcome
mean(dat == "1")
}
)
expect_equal(mean(unique(rate)), 0.1, tolerance = 1e-3)

good_holdout <- purrr::map_lgl(
rs4$splits,
function(x) {
length(intersect(x$in_ind, x$out_id)) == 0
}
)
expect_true(all(good_holdout))
})

test_that("grouping -- repeated", {
set.seed(11)
rs2 <- group_vfold_cv(dat1, c, v = 3, repeats = 4)
Expand Down

0 comments on commit b98e0ed

Please sign in to comment.