Skip to content

Commit

Permalink
indexing changes for #443
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Aug 3, 2023
1 parent 46de181 commit 53b540d
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 5 deletions.
10 changes: 7 additions & 3 deletions R/initial_validation_split.R
Original file line number Diff line number Diff line change
Expand Up @@ -233,14 +233,16 @@ group_initial_validation_split <- function(data,
#' @rdname initial_validation_split
training.initial_validation_split <- function(x, ...) {
  # Return the training rows of a three-way (train/validation/test) split.
  # `x` holds the full data in `x$data` and the training row positions in
  # `x$train_id`; `...` must be empty.
  check_dots_empty()

  # Sorting the ids returns the rows in their original data order.
  ind <- sort(x$train_id)

  # Subset exactly once with dplyr::slice().
  # NOTE(review): slice() is presumably used instead of base `[` so that
  # special columns (e.g. `Surv` objects, issue #443) keep their class --
  # confirm against the issue.
  dplyr::slice(x$data, ind)
}

#' @export
#' @rdname initial_validation_split
testing.initial_validation_split <- function(x, ...) {
  # Return the test rows of a three-way split: every row of `x$data` that is
  # in neither `x$train_id` nor `x$val_id`. `...` must be empty.
  check_dots_empty()

  # Build the complement as a positive index. Negating the ids would break if
  # both id vectors were empty: `-integer(0)` is `integer(0)`, which selects
  # no rows instead of all of them.
  held_in <- c(x$train_id, x$val_id)
  ind <- setdiff(seq_len(nrow(x$data)), held_in)

  # `ind` is ascending, so rows come back in their original data order.
  dplyr::slice(x$data, ind)
}

#' @export
Expand All @@ -262,7 +264,9 @@ validation.default <- function(x, ...) {
#' @rdname initial_validation_split
validation.initial_validation_split <- function(x, ...) {
  # Return the validation rows of a three-way split. `x$val_id` holds the
  # validation row positions into `x$data`; `...` must be empty.
  check_dots_empty()

  # Sorting the ids returns the rows in their original data order.
  ind <- sort(x$val_id)

  # Subset exactly once with dplyr::slice().
  dplyr::slice(x$data, ind)
}


Expand Down
3 changes: 2 additions & 1 deletion R/rsplit.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ as.data.frame.rsplit <-
x$data[, x$col_id] <- permuted_col
return(x$data)
}
x$data[as.integer(x, data = data, ...), , drop = FALSE]
row_ind <- as.integer(x, data = data, ...)
dplyr::slice(x$data, row_ind)
}

#' @rdname as.data.frame.rsplit
Expand Down
2 changes: 1 addition & 1 deletion R/validation_set.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
validation_set <- function(split, ...) {
rlang::check_dots_empty()

train_and_val <- rbind(
train_and_val <- dplyr::bind_rows(
training(split),
validation(split)
)
Expand Down
66 changes: 66 additions & 0 deletions tests/testthat/test-validation_set.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,69 @@ test_that("accessor functions for `val_split`", {
expect_equal(validation(val_split), validation(initial_val_split))
expect_snapshot(error = TRUE, testing(val_split))
})


test_that("working with Surv objects - issue #443", {

  # TRUE when the `surv_obj` column still carries the "Surv" class.
  check_surv <- function(x) inherits(x$surv_obj, "Surv")

  # Small hand-built data set with a `Surv` matrix column, so the test does
  # not need to call a constructor from another package.
  srv <-
    list(
      age = c(74, 68, 56, 57, 60, 74, 76, 77, 39, 75, 66, 58),
      sex = c(1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2),
      surv_obj = structure(
        c(306, 455, 1010, 210, 883, 1022, 1, 1, 0, 1, 1, 0,
          116, 188, 191, 105, 174, 177, 1, 0, 0, 0, 0, 0),
        dim = c(12L, 2L),
        dimnames = list(NULL, c("time", "status")),
        type = "right",
        class = "Surv"))
  surv_df <-
    structure(
      srv,
      row.names = paste(1:12),
      class = "data.frame")

  surv_tbl <- dplyr::as_tibble(surv_df)

  # ----------------------------------------------------------------------------
  # data frame input

  set.seed(472)
  surv_split_df <- initial_validation_split(surv_df, prop = c(.3, .3))

  expect_true(check_surv(surv_split_df$data))

  expect_true(check_surv(training(surv_split_df)))
  expect_true(check_surv(testing(surv_split_df)))
  expect_true(check_surv(validation(surv_split_df)))

  surv_rs_df <- validation_set(surv_split_df)
  expect_true(check_surv(surv_rs_df$splits[[1]]$data))

  expect_true(check_surv(training(surv_rs_df$splits[[1]])))
  expect_true(check_surv(validation(surv_rs_df$splits[[1]])))
  expect_true(check_surv(analysis(surv_rs_df$splits[[1]])))
  # Same accessor coverage as the tibble section below.
  expect_true(check_surv(assessment(surv_rs_df$splits[[1]])))

  # ----------------------------------------------------------------------------
  # tibble input

  set.seed(472)
  surv_split_tbl <- initial_validation_split(surv_tbl, prop = c(.3, .3))

  expect_true(check_surv(surv_split_tbl$data))

  expect_true(check_surv(training(surv_split_tbl)))
  expect_true(check_surv(testing(surv_split_tbl)))
  expect_true(check_surv(validation(surv_split_tbl)))

  surv_rs_tbl <- validation_set(surv_split_tbl)
  expect_true(check_surv(surv_rs_tbl$splits[[1]]$data))

  expect_true(check_surv(training(surv_rs_tbl$splits[[1]])))
  expect_true(check_surv(validation(surv_rs_tbl$splits[[1]])))
  expect_true(check_surv(analysis(surv_rs_tbl$splits[[1]])))
  expect_true(check_surv(assessment(surv_rs_tbl$splits[[1]])))

})

0 comments on commit 53b540d

Please sign in to comment.