Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make balance = observations work with strata #364

Merged
merged 5 commits into from
Sep 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# rsample (development version)

* `group_vfold_cv()` now supports stratification, so long as `balance = "groups"` (the default). Strata must be constant within each group (#317).
* `group_vfold_cv()` now supports stratification. Strata must be constant within each group (#317, #360, #363, #364).

* `group_bootstraps()` now warns if resampling returns any empty assessment sets (previously had been an error) (#356) (#357).

Expand Down
71 changes: 52 additions & 19 deletions R/make_groups.R
Original file line number Diff line number Diff line change
Expand Up @@ -168,42 +168,75 @@ balance_observations <- function(data_ind, v, ...) {
n_obs <- nrow(data_ind)
target_per_fold <- 1 / v

freq_table <- vec_count(data_ind$..group, sort = "location")
freq_table <- balance_observations_helper(data_ind, v, target_per_fold)

collapse_groups(freq_table, data_ind, v)

}

balance_observations_strata <- function(data_ind, v, ...) {
rlang::check_dots_empty()

target_per_fold <- 1 / v

data_splits <- split_unnamed(data_ind, data_ind[["..strata"]])
hfrick marked this conversation as resolved.
Show resolved Hide resolved
freq_table <- purrr::map_dfr(
data_splits,
balance_observations_helper,
v = v,
target_per_fold = target_per_fold
)

collapse_groups(freq_table, data_ind, v)
}

balance_observations_helper <- function(data_split, v, target_per_fold) {

n_obs <- nrow(data_split)
# Create a frequency table counting how many of each group are in the data:
freq_table <- vec_count(data_split$..group, sort = "location")
# Randomly shuffle that table, then assign the first few rows to folds
# (to ensure that each fold gets at least one group assigned):
freq_table <- freq_table[sample.int(nrow(freq_table)), ]
freq_table$assignment <- NA
# Assign the first `v` rows to folds, so that each fold has _some_ data:
freq_table$assignment[seq_len(v)] <- seq_len(v)

# Each run of this loop assigns one "NA" assignment to a fold,
# so we won't get caught in an endless loop here
while (any(is.na(freq_table$assignment))) {
# Get the index of the next row to be assigned, and its count:
next_row <- which(is.na(freq_table$assignment))[[1]]
next_size <- freq_table[next_row, ]$count

# Calculate which fold to assign this new row into:
group_breakdown <- freq_table %>%
# The only NA column in freq_table should be assignment
# So this should only drop un-assigned groups:
stats::na.omit() %>%
# Group by fold assignments and count data in each fold:
dplyr::group_by(.data$assignment) %>%
dplyr::summarise(count = sum(.data$count), .groups = "drop") %>%
dplyr::mutate(prop = .data$count / n_obs,
pre_error = abs(.data$prop - target_per_fold),
if_added_count = .data$count + next_size,
if_added_prop = .data$if_added_count / n_obs,
post_error = abs(.data$if_added_prop - target_per_fold),
improvement = .data$post_error - .data$pre_error)
# Calculate...:
hfrick marked this conversation as resolved.
Show resolved Hide resolved
dplyr::mutate(
# The proportion of data in each fold so far,
prop = .data$count / n_obs,
# The amount off from the target proportion so far,
pre_error = abs(.data$prop - target_per_fold),
# The amount off from the target proportion if we add this new group,
if_added_count = .data$count + next_size,
if_added_prop = .data$if_added_count / n_obs,
post_error = abs(.data$if_added_prop - target_per_fold),
# And how much better or worse adding this new group would make things
improvement = .data$post_error - .data$pre_error
)

# Assign the group in question to the best fold and move on to the next one:
most_improved <- which.min(group_breakdown$improvement)
freq_table[next_row, ]$assignment <-
group_breakdown[most_improved, ]$assignment
}

collapse_groups(freq_table, data_ind, v)

}

balance_observations_strata <- function(data_ind, v, ...) {
rlang::abort(
c(
"`balance = 'observations'` has not yet been implemented for grouped resampling with stratification.",
i = "Consider setting `balance = 'groups'`"
)
)
freq_table
}

balance_prop <- function(prop, data_ind, v, replace = FALSE, ...) {
Expand Down
35 changes: 32 additions & 3 deletions R/vfold.R
Original file line number Diff line number Diff line change
Expand Up @@ -268,18 +268,47 @@ group_vfold_cv <- function(data, group = NULL, v = NULL, repeats = 1, balance =
group_vfold_splits <- function(data, group, v = NULL, balance, strata = NULL, pool = 0.1) {

group <- getElement(data, group)
max_v <- length(unique(group))
if (!is.null(strata)) {
strata <- getElement(data, strata)
strata <- as.character(strata)
strata <- make_strata(strata, pool = pool)

if (is.null(v)) {
hfrick marked this conversation as resolved.
Show resolved Hide resolved
# Set max_v to be the lowest number of groups in a single stratum
# to ensure that all folds get each stratum
max_v <- min(
vec_count(
vec_unique(
data.frame(group, strata)
)$strata
)$count
)
message <- c(
"Leaving `v = NULL` while using stratification will set `v` to the number of groups present in the least common stratum."
)

if (max_v < 5) {
rlang::abort(c(
message,
x = glue::glue("The least common stratum only had {max_v} groups, which may not be enough for cross-validation."),
i = "Set `v` explicitly to override this error."
),
call = rlang::caller_env())
}

rlang::warn(c(
message,
i = "Set `v` explicitly to override this warning."
),
call = rlang::caller_env())
}
}
max_v <- length(unique(group))

if (is.null(v)) {
v <- max_v
} else {
check_v(v = v, max_v = max_v, rows = "groups", call = rlang::caller_env())
}
check_v(v = v, max_v = max_v, rows = "groups", call = rlang::caller_env())

indices <- make_groups(data, group, v, balance, strata)
indices <- lapply(indices, default_complement, n = nrow(data))
Expand Down
5 changes: 5 additions & 0 deletions tests/testthat/_snaps/vfold.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@
4 80131 19869 100000 3 Resample4
5 80103 19897 100000 3 Resample5

---

Leaving `v = NULL` while using stratification will set `v` to the number of groups present in the least common stratum.
i Set `v` explicitly to override this warning.

# grouping -- printing

Code
Expand Down
3 changes: 2 additions & 1 deletion tests/testthat/test-misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ test_that("reshuffle_rset is working", {
group = "y",
strata = "z",
breaks = 2,
pool = 0.2
pool = 0.2,
v = 2
)
)
# Reshuffle them under the same seed to ensure they're identical
Expand Down
63 changes: 59 additions & 4 deletions tests/testthat/test-vfold.R
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,12 @@ test_that("grouping -- other balance methods", {
test_that("grouping -- strata", {
set.seed(11)

n_common_class <- 70
n_rare_class <- 30

group_table <- tibble(
group = 1:100,
outcome = sample(c(rep(0, 70), rep(1, 30)))
outcome = sample(c(rep(0, n_common_class), rep(1, n_rare_class)))
)
observation_table <- tibble(
group = sample(1:100, 1e5, replace = TRUE),
Expand All @@ -247,7 +250,7 @@ test_that("grouping -- strata", {
mean(dat == "1")
}
)
expect_equal(mean(unique(rate)), 0.3, tolerance = 1e-2)
expect_equal(mean(rate), 0.3, tolerance = 1e-2)

good_holdout <- purrr::map_lgl(
rs4$splits,
Expand All @@ -257,11 +260,63 @@ test_that("grouping -- strata", {
)
expect_true(all(good_holdout))

expect_snapshot_warning(
group_vfold_cv(sample_data, group, strata = outcome)
)

expect_equal(
nrow(group_vfold_cv(sample_data, group, strata = outcome)),
length(unique(sample_data$group))
nrow(
suppressWarnings(
group_vfold_cv(sample_data, group, strata = outcome)
)
),
n_rare_class
)

rs5 <- group_vfold_cv(
sample_data,
group,
v = 5,
strata = outcome,
balance = "observations"
)
sizes5 <- dim_rset(rs5)
expect_snapshot(sizes5)

rate <- purrr::map_dbl(
rs5$splits,
function(x) {
dat <- as.data.frame(x)$outcome
mean(dat == "1")
}
)
expect_equal(mean(rate), 0.3, tolerance = 1e-2)

good_holdout <- purrr::map_lgl(
rs5$splits,
function(x) {
length(intersect(x$in_ind, x$out_id)) == 0
}
)
expect_true(all(good_holdout))

expect_snapshot_warning(
group_vfold_cv(sample_data, group, strata = outcome)
)

expect_equal(
nrow(
suppressWarnings(
group_vfold_cv(
sample_data,
group,
strata = outcome,
balance = "observations"
)
)
),
n_rare_class
)
})

test_that("grouping -- repeated", {
Expand Down