diff --git a/R-package/R/lgb.Dataset.R b/R-package/R/lgb.Dataset.R index 429eb1f91275..53724a7144a2 100644 --- a/R-package/R/lgb.Dataset.R +++ b/R-package/R/lgb.Dataset.R @@ -226,9 +226,18 @@ Dataset <- R6::R6Class( ref_handle <- private$reference$.__enclos_env__$private$get_handle() } - # Not subsetting + # not subsetting, constructing from raw data if (is.null(private$used_indices)) { + if (is.null(private$raw_data)) { + stop(paste0( + "Attempting to create a Dataset without any raw data. " + , "This can happen if you have called Dataset$finalize() or if this Dataset was saved with saveRDS(). " + , "To avoid this error in the future, use lgb.Dataset.save() or " + , "Dataset$save_binary() to save lightgbm Datasets." + )) + } + # Are we using a data file? if (is.character(private$raw_data)) { diff --git a/R-package/src/lightgbm_R.cpp b/R-package/src/lightgbm_R.cpp index b393011186f1..99dd666dbf9e 100644 --- a/R-package/src/lightgbm_R.cpp +++ b/R-package/src/lightgbm_R.cpp @@ -90,6 +90,24 @@ void _DatasetFinalizer(SEXP handle) { LGBM_DatasetFree_R(handle); } +void _AssertBoosterHandleNotNull(SEXP handle) { + if (Rf_isNull(handle) || !R_ExternalPtrAddr(handle)) { + Rf_error( + "Attempting to use a Booster which no longer exists. " + "This can happen if you have called Booster$finalize() or if this Booster was saved with saveRDS(). " + "To avoid this error in the future, use saveRDS.lgb.Booster() or Booster$save_model() to save lightgbm Boosters."); + } +} + +void _AssertDatasetHandleNotNull(SEXP handle) { + if (Rf_isNull(handle) || !R_ExternalPtrAddr(handle)) { + Rf_error( + "Attempting to use a Dataset which no longer exists. " + "This can happen if you have called Dataset$finalize() or if this Dataset was saved with saveRDS(). " + "To avoid this error in the future, use lgb.Dataset.save() or Dataset$save_binary() to save lightgbm Datasets."); + } +} + SEXP LGBM_DatasetCreateFromFile_R(SEXP filename, SEXP parameters, SEXP reference) { @@ -172,6 +190,7 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle, SEXP len_used_row_indices, SEXP parameters) { R_API_BEGIN(); + _AssertDatasetHandleNotNull(handle); SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue)); int32_t len = static_cast(Rf_asInteger(len_used_row_indices)); std::vector idxvec(len); @@ -195,6 +214,7 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle, SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle, SEXP feature_names) { R_API_BEGIN(); + _AssertDatasetHandleNotNull(handle); auto vec_names = Split(CHAR(PROTECT(Rf_asChar(feature_names))), '\t'); std::vector vec_sptr; int len = static_cast(vec_names.size()); @@ -211,6 +231,7 @@ SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle, SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) { SEXP cont_token = PROTECT(R_MakeUnwindCont()); R_API_BEGIN(); + _AssertDatasetHandleNotNull(handle); SEXP feature_names; int len = 0; CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &len)); @@ -258,6 +279,7 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) { SEXP LGBM_DatasetSaveBinary_R(SEXP handle, SEXP filename) { R_API_BEGIN(); + _AssertDatasetHandleNotNull(handle); const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename))); CHECK_CALL(LGBM_DatasetSaveBinary(R_ExternalPtrAddr(handle), filename_ptr)); @@ -281,6 +303,7 @@ SEXP LGBM_DatasetSetField_R(SEXP handle, SEXP field_data, SEXP num_element) { R_API_BEGIN(); + _AssertDatasetHandleNotNull(handle); int len = Rf_asInteger(num_element); const char* name = CHAR(PROTECT(Rf_asChar(field_name))); if (!strcmp("group", name) || !strcmp("query", name)) { @@ -309,6 +332,7 @@ SEXP LGBM_DatasetGetField_R(SEXP handle, SEXP field_name, SEXP field_data) { R_API_BEGIN(); + _AssertDatasetHandleNotNull(handle); const char* name = CHAR(PROTECT(Rf_asChar(field_name))); int out_len = 0; int out_type = 0; @@ -343,6 +367,7 @@ SEXP LGBM_DatasetGetFieldSize_R(SEXP handle, SEXP field_name, SEXP out) { R_API_BEGIN(); + _AssertDatasetHandleNotNull(handle); const char* name = CHAR(PROTECT(Rf_asChar(field_name))); int out_len = 0; int out_type = 0; @@ -370,6 +395,7 @@ SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params, SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) { R_API_BEGIN(); + _AssertDatasetHandleNotNull(handle); int nrow; CHECK_CALL(LGBM_DatasetGetNumData(R_ExternalPtrAddr(handle), &nrow)); INTEGER(out)[0] = nrow; @@ -380,6 +406,7 @@ SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) { SEXP LGBM_DatasetGetNumFeature_R(SEXP handle, SEXP out) { R_API_BEGIN(); + _AssertDatasetHandleNotNull(handle); int nfeature; CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &nfeature)); INTEGER(out)[0] = nfeature; @@ -406,6 +433,7 @@ SEXP LGBM_BoosterFree_R(SEXP handle) { SEXP LGBM_BoosterCreate_R(SEXP train_data, SEXP parameters) { R_API_BEGIN(); + _AssertDatasetHandleNotNull(train_data); SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue)); const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters))); BoosterHandle handle = nullptr; @@ -448,6 +476,8 @@ SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) { SEXP LGBM_BoosterMerge_R(SEXP handle, SEXP other_handle) { R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); + _AssertBoosterHandleNotNull(other_handle); CHECK_CALL(LGBM_BoosterMerge(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(other_handle))); return R_NilValue; R_API_END(); @@ -456,6 +486,8 @@ SEXP LGBM_BoosterMerge_R(SEXP handle, SEXP LGBM_BoosterAddValidData_R(SEXP handle, SEXP valid_data) { R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); + _AssertDatasetHandleNotNull(valid_data); CHECK_CALL(LGBM_BoosterAddValidData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(valid_data))); return R_NilValue; R_API_END(); @@ -464,6 +496,8 @@ SEXP LGBM_BoosterAddValidData_R(SEXP handle, SEXP LGBM_BoosterResetTrainingData_R(SEXP handle, SEXP train_data) { R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); + _AssertDatasetHandleNotNull(train_data); CHECK_CALL(LGBM_BoosterResetTrainingData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(train_data))); return R_NilValue; R_API_END(); @@ -472,6 +506,7 @@ SEXP LGBM_BoosterResetTrainingData_R(SEXP handle, SEXP LGBM_BoosterResetParameter_R(SEXP handle, SEXP parameters) { R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters))); CHECK_CALL(LGBM_BoosterResetParameter(R_ExternalPtrAddr(handle), parameters_ptr)); UNPROTECT(1); @@ -482,6 +517,7 @@ SEXP LGBM_BoosterResetParameter_R(SEXP handle, SEXP LGBM_BoosterGetNumClasses_R(SEXP handle, SEXP out) { R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); int num_class; CHECK_CALL(LGBM_BoosterGetNumClasses(R_ExternalPtrAddr(handle), &num_class)); INTEGER(out)[0] = num_class; @@ -491,6 +527,7 @@ SEXP LGBM_BoosterGetNumClasses_R(SEXP handle, SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) { R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); int is_finished = 0; CHECK_CALL(LGBM_BoosterUpdateOneIter(R_ExternalPtrAddr(handle), &is_finished)); return R_NilValue; @@ -502,6 +539,7 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle, SEXP hess, SEXP len) { R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); int is_finished = 0; int int_len = Rf_asInteger(len); std::vector tgrad(int_len), thess(int_len); @@ -517,6 +555,7 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle, SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) { R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); CHECK_CALL(LGBM_BoosterRollbackOneIter(R_ExternalPtrAddr(handle))); return R_NilValue; R_API_END(); @@ -525,6 +564,7 @@ SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) { SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle, SEXP out) { R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); int out_iteration; CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_ExternalPtrAddr(handle), &out_iteration)); INTEGER(out)[0] = out_iteration; @@ -535,6 +575,7 @@ SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle, SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle, SEXP out_result) { R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); double* ptr_ret = REAL(out_result); CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_ExternalPtrAddr(handle), ptr_ret)); return R_NilValue; @@ -544,6 +585,7 @@ SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle, SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle, SEXP out_result) { R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); double* ptr_ret = REAL(out_result); CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_ExternalPtrAddr(handle), ptr_ret)); return R_NilValue; @@ -553,6 +595,7 @@ SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle, SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) { SEXP cont_token = PROTECT(R_MakeUnwindCont()); R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); SEXP eval_names; int len; CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len)); @@ -602,6 +645,7 @@ SEXP LGBM_BoosterGetEval_R(SEXP handle, SEXP data_idx, SEXP out_result) { R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); int len; CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len)); double* ptr_ret = REAL(out_result); @@ -616,6 +660,7 @@ SEXP LGBM_BoosterGetNumPredict_R(SEXP handle, SEXP data_idx, SEXP out) { R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); int64_t len; CHECK_CALL(LGBM_BoosterGetNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &len)); INTEGER(out)[0] = static_cast(len); @@ -627,6 +672,7 @@ SEXP LGBM_BoosterGetPredict_R(SEXP handle, SEXP data_idx, SEXP out_result) { R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); double* ptr_ret = REAL(out_result); int64_t out_len; CHECK_CALL(LGBM_BoosterGetPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &out_len, ptr_ret)); @@ -659,6 +705,7 @@ SEXP LGBM_BoosterPredictForFile_R(SEXP handle, SEXP parameter, SEXP result_filename) { R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); const char* data_filename_ptr = CHAR(PROTECT(Rf_asChar(data_filename))); const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter))); const char* result_filename_ptr = CHAR(PROTECT(Rf_asChar(result_filename))); @@ -680,6 +727,7 @@ SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle, SEXP num_iteration, SEXP out_len) { R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); int64_t len = 0; CHECK_CALL(LGBM_BoosterCalcNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(num_row), @@ -704,6 +752,7 @@ SEXP LGBM_BoosterPredictForCSC_R(SEXP handle, SEXP parameter, SEXP out_result) { R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); const int* p_indptr = INTEGER(indptr); const int32_t* p_indices = reinterpret_cast(INTEGER(indices)); @@ -735,6 +784,7 @@ SEXP LGBM_BoosterPredictForMat_R(SEXP handle, SEXP parameter, SEXP out_result) { R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); int32_t nrow = static_cast(Rf_asInteger(num_row)); int32_t ncol = static_cast(Rf_asInteger(num_col)); @@ -755,6 +805,7 @@ SEXP LGBM_BoosterSaveModel_R(SEXP handle, SEXP feature_importance_type, SEXP filename) { R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename))); CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), filename_ptr)); UNPROTECT(1); @@ -767,6 +818,7 @@ SEXP LGBM_BoosterSaveModelToString_R(SEXP handle, SEXP feature_importance_type) { SEXP cont_token = PROTECT(R_MakeUnwindCont()); R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); SEXP model_str; int64_t out_len = 0; int64_t buf_len = 1024 * 1024; @@ -791,6 +843,7 @@ SEXP LGBM_BoosterDumpModel_R(SEXP handle, SEXP feature_importance_type) { SEXP cont_token = PROTECT(R_MakeUnwindCont()); R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); SEXP model_str; int64_t out_len = 0; int64_t buf_len = 1024 * 1024; diff --git a/R-package/tests/testthat/test_dataset.R b/R-package/tests/testthat/test_dataset.R index 04ca2e3d1cd3..d8eb0b06ce1e 100644 --- a/R-package/tests/testthat/test_dataset.R +++ b/R-package/tests/testthat/test_dataset.R @@ -496,3 +496,51 @@ test_that("lgb.Dataset: should be able to create a Dataset from a text file with expect_identical(dtrain$get_params(), list(header = FALSE)) expect_identical(dtrain$dim(), c(100L, 2L)) }) + +test_that("Dataset: method calls on a Dataset with a null handle should raise an informative error and not segfault", { + data(agaricus.train, package = "lightgbm") + train <- agaricus.train + dtrain <- lgb.Dataset(train$data, label = train$label) + dtrain$construct() + dvalid <- dtrain$create_valid( + data = train$data[seq_len(100L), ] + , label = train$label[seq_len(100L)] + ) + dvalid$construct() + tmp_file <- tempfile(fileext = ".rds") + saveRDS(dtrain, tmp_file) + rm(dtrain) + dtrain <- readRDS(tmp_file) + expect_error({ + dtrain$construct() + }, regexp = "Attempting to create a Dataset without any raw data") + expect_error({ + dtrain$dim() + }, regexp = "cannot get dimensions before dataset has been constructed") + expect_error({ + dtrain$get_colnames() + }, regexp = "cannot get column names before dataset has been constructed") + expect_error({ + dtrain$save_binary(fname = tempfile(fileext = ".bin")) + }, regexp = "Attempting to create a Dataset without any raw data") + expect_error({ + dtrain$set_categorical_feature(categorical_feature = 1L) + }, regexp = "cannot set categorical feature after freeing raw data") + expect_error({ + dtrain$set_reference(reference = dvalid) + }, regexp = "cannot set reference after freeing raw data") + + tmp_valid_file <- tempfile(fileext = ".rds") + saveRDS(dvalid, tmp_valid_file) + rm(dvalid) + dvalid <- readRDS(tmp_valid_file) + dtrain <- lgb.Dataset( + train$data + , label = train$label + , free_raw_data = FALSE + ) + dtrain$construct() + expect_error({ + dtrain$set_reference(reference = dvalid) + }, regexp = "cannot get column names before dataset has been constructed") +}) diff --git a/R-package/tests/testthat/test_lgb.Booster.R b/R-package/tests/testthat/test_lgb.Booster.R index 7ec7ff86285b..af1ecb3c7504 100644 --- a/R-package/tests/testthat/test_lgb.Booster.R +++ b/R-package/tests/testthat/test_lgb.Booster.R @@ -819,6 +819,93 @@ test_that("early_stopping, num_iterations are stored correctly in model string e }) +test_that("Booster: method calls Booster with a null handle should raise an informative error and not segfault", { + data(agaricus.train, package = "lightgbm") + train <- agaricus.train + dtrain <- lgb.Dataset(train$data, label = train$label) + bst <- lgb.train( + params = list( + objective = "regression" + , metric = "l2" + , num_leaves = 8L + ) + , data = dtrain + , verbose = -1L + , nrounds = 5L + , valids = list( + train = dtrain + ) + ) + tmp_file <- tempfile(fileext = ".rds") + saveRDS(bst, tmp_file) + rm(bst) + bst <- readRDS(tmp_file) + .expect_booster_error <- function(object) { + error_regexp <- "Attempting to use a Booster which no longer exists" + expect_error(object, regexp = error_regexp) + } + .expect_booster_error({ + bst$current_iter() + }) + .expect_booster_error({ + bst$dump_model() + }) + .expect_booster_error({ + bst$eval(data = dtrain, name = "valid") + }) + .expect_booster_error({ + bst$eval_train() + }) + .expect_booster_error({ + bst$lower_bound() + }) + .expect_booster_error({ + bst$predict(data = train$data[seq_len(5L), ]) + }) + .expect_booster_error({ + bst$reset_parameter(params = list(learning_rate = 0.123)) + }) + .expect_booster_error({ + bst$rollback_one_iter() + }) + .expect_booster_error({ + bst$save() + }) + .expect_booster_error({ + bst$save_model(filename = tempfile(fileext = ".model")) + }) + .expect_booster_error({ + bst$save_model_to_string() + }) + .expect_booster_error({ + bst$update() + }) + .expect_booster_error({ + bst$upper_bound() + }) + predictor <- bst$to_predictor() + .expect_booster_error({ + predictor$current_iter() + }) + .expect_booster_error({ + predictor$predict(data = train$data[seq_len(5L), ]) + }) +}) + +test_that("Booster$new() using a Dataset with a null handle should raise an informative error and not segfault", { + data(agaricus.train, package = "lightgbm") + train <- agaricus.train + dtrain <- lgb.Dataset(train$data, label = train$label) + dtrain$construct() + tmp_file <- tempfile(fileext = ".bin") + saveRDS(dtrain, tmp_file) + rm(dtrain) + dtrain <- readRDS(tmp_file) + expect_error({ + bst <- Booster$new(train_set = dtrain) + }, regexp = "lgb.Booster: cannot create Booster handle") +}) + # this is almost identical to the test above it, but for lgb.cv(). A lot of code # is duplicated between lgb.train() and lgb.cv(), and this will catch cases where # one is updated and the other isn't