Feature idea - provide custom validation sets for early stopping #48

Open
dfsnow opened this issue Aug 25, 2022 · 4 comments
Labels
feature a feature request or enhancement

Comments

@dfsnow

dfsnow commented Aug 25, 2022

Thanks for creating this excellent package. I created a similar fork of treesnip but am planning to replace it with {bonsai} in all our production models.

One feature that I think would be incredibly useful in {bonsai} is the ability to provide custom validation sets during early stopping (instead of using a random split of the training data). This would have a few potential benefits:

  1. More training data. In many cases, you already have a validation set from a classic train/validate/test split. Currently, {bonsai} further splits the training data into a training subset and a validation subset that is used only for early stopping. It would be ideal to pass the existing validation set directly instead, so that all of the training data is used for training.
  2. Ability to do more complex cross-validation. Certain cross-validation techniques (rolling origin, spatial, etc.) don't rely on a random sample of the training data and instead use some sort of partitioning (time or geographic). Allowing custom validation data would let users use the "correct" validation set for early stopping when using these more complex methods.
  3. Better integration with tidymodels. Tidymodels supports k-fold and other types of cross-validation. Using the validation set created for each fold rather than splitting a separate validation set specifically for early stopping would be much simpler.
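The fold-specific validation sets in points 2 and 3 can be illustrated with plain row indices. The sketch below assumes the rows are already sorted by time. rolling_origin_indices() is a hypothetical base-R helper and is not part of {bonsai}; {rsample}'s rolling_origin() is the tidymodels function that does this properly on data frames. Under the proposal, each fold's assessment rows would serve as the early-stopping set, replacing a random subsample of the analysis rows.

```r
# Hypothetical helper (not part of bonsai or rsample): rolling-origin
# resampling over n time-ordered rows. Each fold fits on rows 1..end and
# holds out the next `assess` rows as that fold's validation set.
rolling_origin_indices <- function(n, initial, assess) {
  stopifnot(initial >= 1, assess >= 1, initial + assess <= n)
  ends <- seq(initial, n - assess)
  lapply(ends, function(end) {
    list(
      analysis   = seq_len(end),               # rows used for fitting
      assessment = seq(end + 1, end + assess)  # later rows, used for validation
    )
  })
}

folds <- rolling_origin_indices(n = 10, initial = 5, assess = 2)
length(folds)            # 4
folds[[1]]$assessment    # 6 7
folds[[4]]$assessment    # 9 10
```

In every fold, validation uses rows that come after the fold's training rows. A random split of the training data cannot guarantee this.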

Let me know if this is out-of-scope for this project. If not, I'm happy to contribute if needed.

@simonpcouch
Contributor

Thanks for the issue! I'm on board. :)

Related to tidymodels/parsnip#760, and tidymodels/parsnip#765.

My response to the analogous parsnip issues reflects where my thinking is with this in bonsai as well.

This is an interesting idea and one that we ought to consider. xgboost and lightgbm's interfaces for validation sets allow for a lot of user control, but we'd need to think carefully about what a tidymodels-esque interface might feel like here.

This won't be on the top of our to-do list for now, but will leave this open as a possible future extension. :)

@dfsnow
Author

dfsnow commented Aug 25, 2022

Great! Thanks for the quick response. It looks like there's already a PR in {parsnip} for exactly this: tidymodels/parsnip#771. I'll wait for that to be merged, and then I'm happy to assist with any further work needed to integrate it into {bonsai}.

@jameslamb
Contributor

Whenever you or others here pick this up, @simonpcouch, @ me if you need any help with how to do this in {lightgbm}.

There is a LightGBM-y way to create validation sets that is slightly different from "just subset rows". See https://lightgbm.readthedocs.io/en/latest/R/reference/lgb.Dataset.create.valid.html.
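Below is a minimal sketch of passing an explicit validation set to lgb.train() for early stopping. It is adapted from the usage pattern in the {lightgbm} R documentation and uses the agaricus example data bundled with that package. The nrounds and early_stopping_rounds values are arbitrary illustration choices. This is not how {bonsai} builds its validation set internally.

```r
library(lightgbm)

# Example data shipped with the {lightgbm} R package.
data(agaricus.train, package = "lightgbm")
data(agaricus.test, package = "lightgbm")

# Training Dataset built from the training rows only.
dtrain <- lgb.Dataset(agaricus.train$data, label = agaricus.train$label)

# create.valid() builds the validation Dataset with dtrain as its reference,
# so it reuses the training data's feature binning rather than computing its own.
dvalid <- lgb.Dataset.create.valid(
  dtrain, agaricus.test$data, label = agaricus.test$label
)

# The user-supplied validation set drives early stopping.
model <- lgb.train(
  params = list(objective = "binary", metric = "binary_logloss"),
  data = dtrain,
  nrounds = 20L,
  valids = list(valid = dvalid),
  early_stopping_rounds = 3L
)

model$best_iter  # iteration with the best validation score
```

The shared reference is what makes this different from "just subset rows". Two independently constructed Datasets could end up with different bin boundaries.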

@diegoperoni

diegoperoni commented Jun 25, 2024

Hi,
I wrote a simple fix that adds an alternative way to specify a custom validation set through the "validation" parameter.
With this code and bonsai v0.3.0, the user can provide either of the following:

Example:

validation = 0.3 # random sample of roughly 30% of the training rows (current behavior)

validation = c(0.7, 0.9) # new option: a contiguous block of rows starting at 70% and ending at 90% of the training set

Here is the code that replaces bonsai's internal process_data() function. Run it after bonsai 0.3.0 has been loaded.

Hope it is useful.

Regards

utils::assignInNamespace(
  x  = "process_data",
  ns = "bonsai",
  value = function(args, x, y, weights, validation, missing_validation) {
    
    #                                           trn_index       | val_index
    #                                         ----------------------------------
    #  needs_validation &  missing_validation | 1:n               1:n
    #  needs_validation & !missing_validation | sample(1:n, m)    setdiff(1:n, trn_index)
    # !needs_validation &  missing_validation | 1:n               NULL
    # !needs_validation & !missing_validation | sample(1:n, m)    setdiff(1:n, trn_index)
    #
    # When length(validation) == 2, val_index is instead the contiguous block of
    # rows between the two percentage bounds and trn_index is setdiff(1:n, val_index).
  
    n <- nrow(x)
    needs_validation <- !is.null(args$params$early_stopping_round)
    if (!needs_validation) {
      # If early_stopping_round isn't set, clear it from arguments actually
      # passed to LightGBM.
      args$params$early_stopping_round <- NULL
    }
    
    if (missing_validation) {
      trn_index <- 1:n
      if (needs_validation) {
        val_index <- trn_index
      } else {
        val_index <- NULL
      }
    } else {
      if (length(validation) == 2) {
        # validation given as percentage bounds c(lower, upper):
        # hold out the contiguous block of rows between them
        l <- floor(n * validation[1]) + 1
        h <- floor(n * validation[2])
        val_index <- l:h
        trn_index <- setdiff(1:n, val_index)
      } else {
        # validation given as a single proportion (default): random holdout
        m <- min(floor(n * (1 - validation)) + 1, n - 1)
        trn_index <- sample(1:n, size = max(m, 2))
        val_index <- setdiff(1:n, trn_index)
      }
    }
    
    data_args <-
      c(
        list(
          data = bonsai:::prepare_df_lgbm(x[trn_index, , drop = FALSE]),
          label = y[trn_index],
          categorical_feature = bonsai:::categorical_columns(x[trn_index, , drop = FALSE]),
          params = c(list(feature_pre_filter = FALSE), args$params),
          weight = weights[trn_index]
        ),
        args$main_args_dataset
      )
    
    args$main_args_train$data <-
      rlang::eval_bare(
        rlang::call2("lgb.Dataset", !!!data_args, .ns = "lightgbm")
      )
    
    if (!is.null(val_index)) {
      valids_args <-
        c(
          list(
            data = bonsai:::prepare_df_lgbm(x[val_index, , drop = FALSE]),
            label = y[val_index],
            categorical_feature = bonsai:::categorical_columns(x[val_index, , drop = FALSE]),
            # c() keeps params a flat named list, matching the training Dataset
            params = c(list(feature_pre_filter = FALSE), args$params),
            weight = weights[val_index]
          ),
          args$main_args_dataset
        )
      
      args$main_args_train$valids <-
        list(
          validation =
            rlang::eval_bare(
              rlang::call2("lgb.Dataset", !!!valids_args, .ns = "lightgbm")
            )
        )
    }
    
    args
})
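The index arithmetic in the patch can be checked without LightGBM. split_indices() below is a hypothetical helper that copies only the two branches of the patch's else block: a scalar proportion, or c(lower, upper) bounds. It returns the row indices and does nothing else. It is a sketch for checking the logic, not part of {bonsai}.

```r
# Hypothetical helper mirroring only the index logic of the patched
# process_data(); it does not build any LightGBM objects.
split_indices <- function(n, validation) {
  if (length(validation) == 2) {
    # Contiguous block between the two percentage bounds.
    l <- floor(n * validation[1]) + 1
    h <- floor(n * validation[2])
    val_index <- l:h
    trn_index <- setdiff(seq_len(n), val_index)
  } else {
    # Random holdout, as in the default (scalar) branch.
    m <- min(floor(n * (1 - validation)) + 1, n - 1)
    trn_index <- sample(seq_len(n), size = max(m, 2))
    val_index <- setdiff(seq_len(n), trn_index)
  }
  list(trn_index = trn_index, val_index = val_index)
}

range_split <- split_indices(20, c(0.5, 0.75))
range_split$val_index            # 11 12 13 14 15
range_split$trn_index            # 1 to 10, then 16 to 20

set.seed(1)
random_split <- split_indices(20, 0.25)
length(random_split$trn_index)   # 16
length(random_split$val_index)   # 4
```

Tracing the code turns up a few behaviors. Because of the "+ 1", a scalar 0.25 with n = 20 holds out 4 rows rather than 5. The range form is only meaningful if the rows are already ordered, for example by time. Neither the patch nor this helper checks that lower < upper. With n = 20, c(0.75, 0.5) gives l:h = 16:10, which silently selects rows 10 through 16 instead of raising an error.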


4 participants