Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Standardize classing for grouped resampling #325

Merged
merged 3 commits into from
Jul 1, 2022
Merged

Standardize classing for grouped resampling #325

merged 3 commits into from
Jul 1, 2022

Conversation

mikemahoney218
Copy link
Member

@mikemahoney218 mikemahoney218 commented Jun 30, 2022

This PR fixes #318 by standardizing how we subclass when performed grouped resampling. Grouped resampling now:

  • Always indicates it's been grouped when printed
  • Always includes all the classes of the non-grouped variant
  • Always adds a group_rset, and group_x class to the rset object
  • Always adds a group_x class to the rsplit object
library(rsample)
data(ames, package = "modeldata")

vf <- vfold_cv(ames)
gvf <- group_vfold_cv(ames, Neighborhood)
vf
#> #  10-fold cross-validation 
#> # A tibble: 10 × 2
#>    splits             id    
#>    <list>             <chr> 
#>  1 <split [2637/293]> Fold01
#>  2 <split [2637/293]> Fold02
#>  3 <split [2637/293]> Fold03
#>  4 <split [2637/293]> Fold04
#>  5 <split [2637/293]> Fold05
#>  6 <split [2637/293]> Fold06
#>  7 <split [2637/293]> Fold07
#>  8 <split [2637/293]> Fold08
#>  9 <split [2637/293]> Fold09
#> 10 <split [2637/293]> Fold10
gvf
#> # Group 28-fold cross-validation 
#> # A tibble: 28 × 2
#>    splits             id        
#>    <list>             <chr>     
#>  1 <split [2907/23]>  Resample01
#>  2 <split [2663/267]> Resample02
#>  3 <split [2929/1]>   Resample03
#>  4 <split [2920/10]>  Resample04
#>  5 <split [2736/194]> Resample05
#>  6 <split [2816/114]> Resample06
#>  7 <split [2837/93]>  Resample07
#>  8 <split [2827/103]> Resample08
#>  9 <split [2922/8]>   Resample09
#> 10 <split [2691/239]> Resample10
#> # … with 18 more rows
class(gvf)
#> [1] "group_vfold_cv" "vfold_cv"       "group_rset"     "rset"          
#> [5] "tbl_df"         "tbl"            "data.frame"
class(gvf$splits[[1]])
#> [1] "group_vfold_split" "vfold_split"       "rsplit"
all(class(vf) %in% class(gvf))
#> [1] TRUE
all(class(vf$splits[[1]]) %in% class(gvf$splits[[1]]))
#> [1] TRUE

mc <- mc_cv(ames)
gmc <- group_mc_cv(ames, Neighborhood)
mc
#> # Monte Carlo cross-validation (0.75/0.25) with 25 resamples  
#> # A tibble: 25 × 2
#>    splits             id        
#>    <list>             <chr>     
#>  1 <split [2197/733]> Resample01
#>  2 <split [2197/733]> Resample02
#>  3 <split [2197/733]> Resample03
#>  4 <split [2197/733]> Resample04
#>  5 <split [2197/733]> Resample05
#>  6 <split [2197/733]> Resample06
#>  7 <split [2197/733]> Resample07
#>  8 <split [2197/733]> Resample08
#>  9 <split [2197/733]> Resample09
#> 10 <split [2197/733]> Resample10
#> # … with 15 more rows
gmc
#> # Grouped Monte Carlo cross-validation (0.75/0.25) with 25 resamples  
#> # A tibble: 25 × 2
#>    splits             id        
#>    <list>             <chr>     
#>  1 <split [2205/725]> Resample01
#>  2 <split [2394/536]> Resample02
#>  3 <split [2292/638]> Resample03
#>  4 <split [2129/801]> Resample04
#>  5 <split [2206/724]> Resample05
#>  6 <split [2230/700]> Resample06
#>  7 <split [2222/708]> Resample07
#>  8 <split [2249/681]> Resample08
#>  9 <split [2125/805]> Resample09
#> 10 <split [2303/627]> Resample10
#> # … with 15 more rows
class(gmc)
#> [1] "group_mc_cv" "mc_cv"       "group_rset"  "rset"        "tbl_df"     
#> [6] "tbl"         "data.frame"
class(gmc$splits[[1]])
#> [1] "grouped_mc_split" "mc_split"         "rsplit"
all(class(mc) %in% class(gmc))
#> [1] TRUE
all(class(mc$splits[[1]]) %in% class(gmc$splits[[1]]))
#> [1] TRUE

bs <- bootstraps(ames)
gbs <- group_bootstraps(ames, Neighborhood)
bs
#> # Bootstrap sampling 
#> # A tibble: 25 × 2
#>    splits              id         
#>    <list>              <chr>      
#>  1 <split [2930/1071]> Bootstrap01
#>  2 <split [2930/1072]> Bootstrap02
#>  3 <split [2930/1100]> Bootstrap03
#>  4 <split [2930/1071]> Bootstrap04
#>  5 <split [2930/1078]> Bootstrap05
#>  6 <split [2930/1082]> Bootstrap06
#>  7 <split [2930/1054]> Bootstrap07
#>  8 <split [2930/1104]> Bootstrap08
#>  9 <split [2930/1077]> Bootstrap09
#> 10 <split [2930/1064]> Bootstrap10
#> # … with 15 more rows
gbs
#> # Group bootstrap sampling 
#> # A tibble: 25 × 2
#>    splits              id         
#>    <list>              <chr>      
#>  1 <split [2901/872]>  Bootstrap01
#>  2 <split [2948/1053]> Bootstrap02
#>  3 <split [2933/885]>  Bootstrap03
#>  4 <split [2947/1206]> Bootstrap04
#>  5 <split [3035/1174]> Bootstrap05
#>  6 <split [2958/1112]> Bootstrap06
#>  7 <split [2914/1066]> Bootstrap07
#>  8 <split [2903/839]>  Bootstrap08
#>  9 <split [2920/1321]> Bootstrap09
#> 10 <split [2823/1532]> Bootstrap10
#> # … with 15 more rows
class(gbs)
#> [1] "group_bootstraps" "bootstraps"       "group_rset"       "rset"            
#> [5] "tbl_df"           "tbl"              "data.frame"
class(gbs$splits[[1]])
#> [1] "group_boot_split" "boot_split"       "rsplit"
all(class(bs) %in% class(gbs))
#> [1] TRUE
all(class(bs$splits[[1]]) %in% class(gbs$splits[[1]]))
#> [1] TRUE

is <- initial_split(ames)
gis <- group_initial_split(ames, Neighborhood)
is
#> <Training/Testing/Total>
#> <2197/733/2930>
gis
#> <Training/Testing/Total>
#> <2187/743/2930>
class(gis)
#> [1] "group_initial_split" "initial_split"       "grouped_mc_split"   
#> [4] "mc_split"            "rsplit"
class(is)
#> [1] "initial_split" "mc_split"      "rsplit"
all(class(is) %in% class(gis))
#> [1] TRUE

vs <- validation_split(ames)
gvs <- group_validation_split(ames, Neighborhood)
vs
#> # Validation Set Split (0.75/0.25)  
#> # A tibble: 1 × 2
#>   splits             id        
#>   <list>             <chr>     
#> 1 <split [2197/733]> validation
gvs
#> # Grouped Validation Set Split (0.75/0.25)  
#> # A tibble: 1 × 2
#>   splits             id        
#>   <list>             <chr>     
#> 1 <split [2177/753]> validation
class(gvs)
#> [1] "group_validation_split" "validation_split"       "group_rset"            
#> [4] "rset"                   "tbl_df"                 "tbl"                   
#> [7] "data.frame"
class(gvs$splits[[1]])
#> [1] "group_val_split" "val_split"       "rsplit"
all(class(vs) %in% class(gvs))
#> [1] TRUE
all(class(vs$splits[[1]]) %in% class(gvs$splits[[1]]))
#> [1] TRUE

Created on 2022-06-30 by the reprex package (v2.0.1)

@mikemahoney218 mikemahoney218 marked this pull request as ready for review June 30, 2022 19:15
@@ -195,13 +195,13 @@ group_vfold_cv <- function(data, group = NULL, v = NULL, balance = c("groups", "

## Save some overall information

cv_att <- list(v = v, group = group, balance = balance)
cv_att <- list(v = v, group = group, balance = balance, repeats = 1, strata = FALSE)
Copy link
Member Author

@mikemahoney218 mikemahoney218 Jun 30, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit extraneous for this PR, but as part of a future PR (likely to address #79) I want to try and make sure that group_* variants always have all the same attributes as their non-grouped superclass, so that methods dispatching on the non-grouped variant can always handle the grouped variant as well.

Copy link
Member

@juliasilge juliasilge left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've got one question on printing, and then I was wondering if we could write a test to loop through all the relevant rset types and do this, like you posted:

all(class(mc$splits[[1]]) %in% class(gmc$splits[[1]]))

R/printing.R Outdated
pretty.group_validation_split <- function(x, ...) {
details <- attributes(x)
res <- paste0(
"Grouped Validation Set Split (",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we should print "Grouped" or "Group"? We have both right now.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm -- I'm maybe 60/40 in favor of "Group", because the output isn't "Grouped" in a dplyr sense.

Do we want to print either, though? Is it ever important to know in interactive use that these splits were made using groups? Programmatic use will look for the group_* classes.

Between the three, I think my favorite is to always prefix with "Group", but it's an extremely weak opinion.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, let's keep the special printing for now and go with "Group". 👍

@mikemahoney218
Copy link
Member Author

Alright, added tests and changed printing in 43c778e 😄

Copy link
Member

@juliasilge juliasilge left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect! Thank you so much 🙌

@juliasilge juliasilge merged commit b3ba57c into main Jul 1, 2022
@juliasilge juliasilge deleted the mike/318 branch July 1, 2022 16:22
@github-actions
Copy link

This pull request has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Jul 16, 2022
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Grouped resamples are inconsistently of a group_* class
2 participants