From 016643ebdb275792b8a015476ae5a3bfa5b39803 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 2 Sep 2021 01:31:38 -0500 Subject: [PATCH 1/7] [R-package] fix segfaults caused by missing Booster and Dataset handles (fixes #4208) --- R-package/R/lgb.Dataset.R | 9 +++ R-package/src/lightgbm_R.cpp | 53 +++++++++++++ R-package/tests/testthat/test_dataset.R | 42 ++++++++++ R-package/tests/testthat/test_lgb.Booster.R | 87 +++++++++++++++++++++ 4 files changed, 191 insertions(+) diff --git a/R-package/R/lgb.Dataset.R b/R-package/R/lgb.Dataset.R index 3212c3e625bc..7ca02b4afb52 100644 --- a/R-package/R/lgb.Dataset.R +++ b/R-package/R/lgb.Dataset.R @@ -167,6 +167,15 @@ Dataset <- R6::R6Class( return(invisible(self)) } + if (is.null(private$raw_data)) { + stop(paste0( + "Attempting to create a Dataset without any raw data. " + , "This can happen if you have called Dataset$finalize() or if this Dataset was saved with saveRDS(). " + , "To avoid this error in the future, use lgb.Dataset.save() or " + , "Dataset$save_binary() to save lightgbm Datasets." + )) + } + # Get feature names cnames <- NULL if (is.matrix(private$raw_data) || methods::is(private$raw_data, "dgCMatrix")) { diff --git a/R-package/src/lightgbm_R.cpp b/R-package/src/lightgbm_R.cpp index a19a22a56a9e..618ac91fb833 100644 --- a/R-package/src/lightgbm_R.cpp +++ b/R-package/src/lightgbm_R.cpp @@ -48,6 +48,26 @@ void _DatasetFinalizer(SEXP handle) { LGBM_DatasetFree_R(handle); } +void _AssertBoosterHandleNotNull(SEXP handle) { + if (Rf_isNull(handle) || !R_ExternalPtrAddr(handle)) { + Rf_error( + "Attempting to use a Booster which no longer exists. " + "This can happen if you have called Booster$finalize() or if this Booster was saved with saveRDS(). " + "To avoid this error in the future, use saveRDS.lgb.Booster() or Booster$save_model() to save lightgbm Boosters." + ); + } +} + +void _AssertDatasetHandleNotNull(SEXP handle) { + if (Rf_isNull(handle) || !R_ExternalPtrAddr(handle)) { + Rf_error( + "Attempting to use a Dataset which no longer exists. " + "This can happen if you have called Dataset$finalize() or if this Dataset was saved with saveRDS(). " + "To avoid this error in the future, use lgb.Dataset.save() or Dataset$save_binary() to save lightgbm Datasets." + ); + } +} + SEXP LGBM_DatasetCreateFromFile_R(SEXP filename, SEXP parameters, SEXP reference) { @@ -129,6 +149,7 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle, SEXP used_row_indices, SEXP len_used_row_indices, SEXP parameters) { + _AssertDatasetHandleNotNull(handle); SEXP ret; int32_t len = static_cast(Rf_asInteger(len_used_row_indices)); std::vector idxvec(len); @@ -152,6 +173,7 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle, SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle, SEXP feature_names) { + _AssertDatasetHandleNotNull(handle); auto vec_names = Split(CHAR(PROTECT(Rf_asChar(feature_names))), '\t'); std::vector vec_sptr; int len = static_cast(vec_names.size()); @@ -167,6 +189,7 @@ SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle, } SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) { + _AssertDatasetHandleNotNull(handle); SEXP feature_names; int len = 0; R_API_BEGIN(); @@ -218,6 +241,7 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) { SEXP LGBM_DatasetSaveBinary_R(SEXP handle, SEXP filename) { + _AssertDatasetHandleNotNull(handle); const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename))); R_API_BEGIN(); CHECK_CALL(LGBM_DatasetSaveBinary(R_ExternalPtrAddr(handle), @@ -241,6 +265,7 @@ SEXP LGBM_DatasetSetField_R(SEXP handle, SEXP field_name, SEXP field_data, SEXP num_element) { + _AssertDatasetHandleNotNull(handle); int len = Rf_asInteger(num_element); const char* name = CHAR(PROTECT(Rf_asChar(field_name))); R_API_BEGIN(); @@ -269,6 +294,7 @@ SEXP LGBM_DatasetSetField_R(SEXP handle, SEXP LGBM_DatasetGetField_R(SEXP handle, SEXP field_name, SEXP field_data) { + _AssertDatasetHandleNotNull(handle); const char* name = CHAR(PROTECT(Rf_asChar(field_name))); int out_len = 0; int out_type = 0; @@ -303,6 +329,7 @@ SEXP LGBM_DatasetGetField_R(SEXP handle, SEXP LGBM_DatasetGetFieldSize_R(SEXP handle, SEXP field_name, SEXP out) { + _AssertDatasetHandleNotNull(handle); const char* name = CHAR(PROTECT(Rf_asChar(field_name))); int out_len = 0; int out_type = 0; @@ -330,6 +357,7 @@ SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params, } SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) { + _AssertDatasetHandleNotNull(handle); int nrow; R_API_BEGIN(); CHECK_CALL(LGBM_DatasetGetNumData(R_ExternalPtrAddr(handle), &nrow)); @@ -340,6 +368,7 @@ SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) { SEXP LGBM_DatasetGetNumFeature_R(SEXP handle, SEXP out) { + _AssertDatasetHandleNotNull(handle); int nfeature; R_API_BEGIN(); CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &nfeature)); @@ -366,6 +395,7 @@ SEXP LGBM_BoosterFree_R(SEXP handle) { SEXP LGBM_BoosterCreate_R(SEXP train_data, SEXP parameters) { + _AssertDatasetHandleNotNull(train_data); SEXP ret; const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters))); BoosterHandle handle = nullptr; @@ -408,6 +438,8 @@ SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) { SEXP LGBM_BoosterMerge_R(SEXP handle, SEXP other_handle) { + _AssertBoosterHandleNotNull(handle); + _AssertBoosterHandleNotNull(other_handle); R_API_BEGIN(); CHECK_CALL(LGBM_BoosterMerge(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(other_handle))); R_API_END(); @@ -416,6 +448,7 @@ SEXP LGBM_BoosterMerge_R(SEXP handle, SEXP LGBM_BoosterAddValidData_R(SEXP handle, SEXP valid_data) { + _AssertBoosterHandleNotNull(handle); R_API_BEGIN(); CHECK_CALL(LGBM_BoosterAddValidData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(valid_data))); R_API_END(); @@ -424,6 +457,7 @@ SEXP LGBM_BoosterAddValidData_R(SEXP handle, SEXP LGBM_BoosterResetTrainingData_R(SEXP handle, SEXP train_data) { + _AssertBoosterHandleNotNull(handle); R_API_BEGIN(); CHECK_CALL(LGBM_BoosterResetTrainingData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(train_data))); R_API_END(); @@ -432,6 +466,7 @@ SEXP LGBM_BoosterResetTrainingData_R(SEXP handle, SEXP LGBM_BoosterResetParameter_R(SEXP handle, SEXP parameters) { + _AssertBoosterHandleNotNull(handle); const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters))); R_API_BEGIN(); CHECK_CALL(LGBM_BoosterResetParameter(R_ExternalPtrAddr(handle), parameters_ptr)); @@ -442,6 +477,7 @@ SEXP LGBM_BoosterResetParameter_R(SEXP handle, SEXP LGBM_BoosterGetNumClasses_R(SEXP handle, SEXP out) { + _AssertBoosterHandleNotNull(handle); int num_class; R_API_BEGIN(); CHECK_CALL(LGBM_BoosterGetNumClasses(R_ExternalPtrAddr(handle), &num_class)); @@ -451,6 +487,7 @@ SEXP LGBM_BoosterGetNumClasses_R(SEXP handle, } SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) { + _AssertBoosterHandleNotNull(handle); int is_finished = 0; R_API_BEGIN(); CHECK_CALL(LGBM_BoosterUpdateOneIter(R_ExternalPtrAddr(handle), &is_finished)); @@ -462,6 +499,7 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle, SEXP grad, SEXP hess, SEXP len) { + _AssertBoosterHandleNotNull(handle); int is_finished = 0; R_API_BEGIN(); int int_len = Rf_asInteger(len); @@ -477,6 +515,7 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle, } SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) { + _AssertBoosterHandleNotNull(handle); R_API_BEGIN(); CHECK_CALL(LGBM_BoosterRollbackOneIter(R_ExternalPtrAddr(handle))); R_API_END(); @@ -485,6 +524,7 @@ SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) { SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle, SEXP out) { + _AssertBoosterHandleNotNull(handle); int out_iteration; R_API_BEGIN(); CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_ExternalPtrAddr(handle), &out_iteration)); @@ -495,6 +535,7 @@ SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle, SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle, SEXP out_result) { + _AssertBoosterHandleNotNull(handle); R_API_BEGIN(); double* ptr_ret = REAL(out_result); CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_ExternalPtrAddr(handle), ptr_ret)); @@ -504,6 +545,7 @@ SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle, SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle, SEXP out_result) { + _AssertBoosterHandleNotNull(handle); R_API_BEGIN(); double* ptr_ret = REAL(out_result); CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_ExternalPtrAddr(handle), ptr_ret)); @@ -512,6 +554,7 @@ SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle, } SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) { + _AssertBoosterHandleNotNull(handle); SEXP eval_names; int len; R_API_BEGIN(); @@ -565,6 +608,7 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) { SEXP LGBM_BoosterGetEval_R(SEXP handle, SEXP data_idx, SEXP out_result) { + _AssertBoosterHandleNotNull(handle); R_API_BEGIN(); int len; CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len)); @@ -579,6 +623,7 @@ SEXP LGBM_BoosterGetEval_R(SEXP handle, SEXP LGBM_BoosterGetNumPredict_R(SEXP handle, SEXP data_idx, SEXP out) { + _AssertBoosterHandleNotNull(handle); R_API_BEGIN(); int64_t len; CHECK_CALL(LGBM_BoosterGetNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &len)); @@ -590,6 +635,7 @@ SEXP LGBM_BoosterGetNumPredict_R(SEXP handle, SEXP LGBM_BoosterGetPredict_R(SEXP handle, SEXP data_idx, SEXP out_result) { + _AssertBoosterHandleNotNull(handle); R_API_BEGIN(); double* ptr_ret = REAL(out_result); int64_t out_len; @@ -622,6 +668,7 @@ SEXP LGBM_BoosterPredictForFile_R(SEXP handle, SEXP num_iteration, SEXP parameter, SEXP result_filename) { + _AssertBoosterHandleNotNull(handle); const char* data_filename_ptr = CHAR(PROTECT(Rf_asChar(data_filename))); const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter))); const char* result_filename_ptr = CHAR(PROTECT(Rf_asChar(result_filename))); @@ -643,6 +690,7 @@ SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle, SEXP start_iteration, SEXP num_iteration, SEXP out_len) { + _AssertBoosterHandleNotNull(handle); R_API_BEGIN(); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); int64_t len = 0; @@ -667,6 +715,7 @@ SEXP LGBM_BoosterPredictForCSC_R(SEXP handle, SEXP num_iteration, SEXP parameter, SEXP out_result) { + _AssertBoosterHandleNotNull(handle); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); const int* p_indptr = INTEGER(indptr); const int32_t* p_indices = reinterpret_cast(INTEGER(indices)); @@ -698,6 +747,7 @@ SEXP LGBM_BoosterPredictForMat_R(SEXP handle, SEXP num_iteration, SEXP parameter, SEXP out_result) { + _AssertBoosterHandleNotNull(handle); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); int32_t nrow = static_cast(Rf_asInteger(num_row)); int32_t ncol = static_cast(Rf_asInteger(num_col)); @@ -718,6 +768,7 @@ SEXP LGBM_BoosterSaveModel_R(SEXP handle, SEXP num_iteration, SEXP feature_importance_type, SEXP filename) { + _AssertBoosterHandleNotNull(handle); const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename))); R_API_BEGIN(); CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), filename_ptr)); @@ -729,6 +780,7 @@ SEXP LGBM_BoosterSaveModel_R(SEXP handle, SEXP LGBM_BoosterSaveModelToString_R(SEXP handle, SEXP num_iteration, SEXP feature_importance_type) { + _AssertBoosterHandleNotNull(handle); SEXP model_str; int64_t out_len = 0; int64_t buf_len = 1024 * 1024; @@ -754,6 +806,7 @@ SEXP LGBM_BoosterSaveModelToString_R(SEXP handle, SEXP LGBM_BoosterDumpModel_R(SEXP handle, SEXP num_iteration, SEXP feature_importance_type) { + _AssertBoosterHandleNotNull(handle); SEXP model_str; int64_t out_len = 0; int64_t buf_len = 1024 * 1024; diff --git a/R-package/tests/testthat/test_dataset.R b/R-package/tests/testthat/test_dataset.R index ec2250fbdbc6..2d3753a7cce0 100644 --- a/R-package/tests/testthat/test_dataset.R +++ b/R-package/tests/testthat/test_dataset.R @@ -352,3 +352,45 @@ test_that("lgb.Dataset: should be able to create a Dataset from a text file with expect_identical(dtrain$get_params(), list(header = FALSE)) expect_identical(dtrain$dim(), c(100L, 2L)) }) + +test_that("Dataset: methods called on a Dataset with a null handle should raise an informative error and not segfault", { + data(agaricus.train, package = "lightgbm") + train <- agaricus.train + dtrain <- lgb.Dataset(train$data, label = train$label) + dtrain$construct() + dvalid <- dtrain$create_valid( + data = train$data[seq_len(100L), ] + , label = train$label[seq_len(100L)] + ) + dvalid$construct() + tmp_file <- tempfile(fileext = ".rds") + saveRDS(dtrain, tmp_file) + rm(dtrain) + dtrain <- readRDS(tmp_file) + expect_error({ + dtrain$construct() + }, regexp = "Attempting to create a Dataset without any raw data") + expect_error({ + dtrain$dim() + }, regexp = "cannot get dimensions before dataset has been constructed") + expect_error({ + dtrain$get_colnames() + }, regexp = "cannot get dimensions before dataset has been constructed") + expect_error({ + dtrain$save_binary(fname = tempfile(fileext = ".bin")) + }, regexp = "Attempting to create a Dataset without any raw data") + expect_error({ + dtrain$set_categorical_feature(categorical_feature = 1L) + }, regexp = "cannot set categorical feature after freeing raw data") + expect_error({ + dtrain$set_reference(reference = dvalid) + }, regexp = "cannot set reference after freeing raw data") + + tmp_valid_file <- tempfile(fileext = ".rds") + saveRDS(dvalid, tmp_file) + rm(dvalid) + dvalid <- readRDS(tmp_file) + expect_error({ + dtrain$set_reference(reference = dvalid) + }, regexp = "please call lgb.Dataset.construct explicitly") +}) diff --git a/R-package/tests/testthat/test_lgb.Booster.R b/R-package/tests/testthat/test_lgb.Booster.R index 76d3a41c9b5b..23c457866587 100644 --- a/R-package/tests/testthat/test_lgb.Booster.R +++ b/R-package/tests/testthat/test_lgb.Booster.R @@ -820,6 +820,93 @@ test_that("early_stopping, num_iterations are stored correctly in model string e }) +test_that("Booster: methods called on a Booster with a null handle should raise an informative error and not segfault", { + data(agaricus.train, package = "lightgbm") + train <- agaricus.train + dtrain <- lgb.Dataset(train$data, label = train$label) + bst <- lgb.train( + params = list( + objective = "regression" + , metric = "l2" + , num_leaves = 8L + ) + , data = dtrain + , verbose = -1L + , nrounds = 5L + , valids = list( + train = dtrain + ) + ) + tmp_file <- tempfile(fileext = ".rds") + saveRDS(bst, tmp_file) + rm(bst) + bst <- readRDS(tmp_file) + .expect_booster_error <- function(object){ + error_regexp <- "Attempting to use a Booster which no longer exists" + expect_error(object, regexp = error_regexp) + } + .expect_booster_error({ + bst$current_iter() + }) + .expect_booster_error({ + bst$dump_model() + }) + .expect_booster_error({ + bst$eval(data = dtrain, name = "valid") + }) + .expect_booster_error({ + bst$eval_train() + }) + .expect_booster_error({ + bst$lower_bound() + }) + .expect_booster_error({ + bst$predict(data = train$data[seq_len(5L), ]) + }) + .expect_booster_error({ + bst$reset_parameter(params = list(learning_rate = 0.123)) + }) + .expect_booster_error({ + bst$rollback_one_iter() + }) + .expect_booster_error({ + bst$save() + }) + .expect_booster_error({ + bst$save_model(filename = tempfile(fileext = ".model")) + }) + .expect_booster_error({ + bst$save_model_to_string() + }) + .expect_booster_error({ + bst$update() + }) + .expect_booster_error({ + bst$upper_bound() + }) + predictor <- bst$to_predictor() + .expect_booster_error({ + predictor$current_iter() + }) + .expect_booster_error({ + predictor$predict(data = train$data[seq_len(5L), ]) + }) +}) + +test_that("Booster: attempting to create a Booster from a Dataset will a null handle should raise an informative error and not segfault", { + data(agaricus.train, package = "lightgbm") + train <- agaricus.train + dtrain <- lgb.Dataset(train$data, label = train$label) + dtrain$construct() + tmp_file <- tempfile(fileext = ".bin") + saveRDS(dtrain, tmp_file) + rm(dtrain) + dtrain <- readRDS(tmp_file) + expect_error({ + bst <- Booster$new(train_set = dtrain) + }, regexp = "Attempting to create a Dataset without any raw data") +}) + # this is almost identical to the test above it, but for lgb.cv(). A lot of code # is duplicated between lgb.train() and lgb.cv(), and this will catch cases where # one is updated and the other isn't From 564a73d2e7c79743ddd100858b791348e448b852 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 2 Sep 2021 16:54:47 -0500 Subject: [PATCH 2/7] fix test errors --- R-package/R/lgb.Dataset.R | 20 ++++++++++---------- R-package/tests/testthat/test_dataset.R | 6 +++--- R-package/tests/testthat/test_lgb.Booster.R | 8 ++++---- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/R-package/R/lgb.Dataset.R b/R-package/R/lgb.Dataset.R index 03a6483e3f37..4bdb10d82a7a 100644 --- a/R-package/R/lgb.Dataset.R +++ b/R-package/R/lgb.Dataset.R @@ -167,15 +167,6 @@ Dataset <- R6::R6Class( return(invisible(self)) } - if (is.null(private$raw_data)) { - stop(paste0( - "Attempting to create a Dataset without any raw data. " - , "This can happen if you have called Dataset$finalize() or if this Dataset was saved with saveRDS(). " - , "To avoid this error in the future, use lgb.Dataset.save() or " - , "Dataset$save_binary() to save lightgbm Datasets." - )) - } - # Get feature names cnames <- NULL if (is.matrix(private$raw_data) || methods::is(private$raw_data, "dgCMatrix")) { @@ -235,9 +226,18 @@ Dataset <- R6::R6Class( ref_handle <- private$reference$.__enclos_env__$private$get_handle() } - # Not subsetting + # not subsetting, constructing from raw data if (is.null(private$used_indices)) { + if (is.null(private$raw_data)) { + stop(paste0( + "Attempting to create a Dataset without any raw data. " + , "This can happen if you have called Dataset$finalize() or if this Dataset was saved with saveRDS(). " + , "To avoid this error in the future, use lgb.Dataset.save() or " + , "Dataset$save_binary() to save lightgbm Datasets." + )) + } + # Are we using a data file? if (is.character(private$raw_data)) { diff --git a/R-package/tests/testthat/test_dataset.R b/R-package/tests/testthat/test_dataset.R index d79e2d16db0f..b11f03c59971 100644 --- a/R-package/tests/testthat/test_dataset.R +++ b/R-package/tests/testthat/test_dataset.R @@ -381,7 +381,7 @@ test_that("lgb.Dataset: should be able to create a Dataset from a text file with expect_identical(dtrain$dim(), c(100L, 2L)) }) -test_that("Dataset: methods called on a Dataset with a null handle should raise an informative error and not segfault", { +test_that("Dataset: method calls on a Dataset with a null handle should raise an informative error and not segfault", { data(agaricus.train, package = "lightgbm") train <- agaricus.train dtrain <- lgb.Dataset(train$data, label = train$label) @@ -403,7 +403,7 @@ test_that("Dataset: methods called on a Dataset with a null handle should raise }, regexp = "cannot get dimensions before dataset has been constructed") expect_error({ dtrain$get_colnames() - }, regexp = "cannot get dimensions before dataset has been constructed") + }, regexp = "cannot get column names before dataset has been constructed") expect_error({ dtrain$save_binary(fname = tempfile(fileext = ".bin")) }, regexp = "Attempting to create a Dataset without any raw data") @@ -420,5 +420,5 @@ test_that("Dataset: methods called on a Dataset with a null handle should raise dvalid <- readRDS(tmp_file) expect_error({ dtrain$set_reference(reference = dvalid) - }, regexp = "please call lgb.Dataset.construct explicitly") + }, regexp = "cannot get column names before dataset has been constructed") }) diff --git a/R-package/tests/testthat/test_lgb.Booster.R b/R-package/tests/testthat/test_lgb.Booster.R index 23c457866587..b94c500d4d1b 100644 --- a/R-package/tests/testthat/test_lgb.Booster.R +++ b/R-package/tests/testthat/test_lgb.Booster.R @@ -820,7 +820,7 @@ test_that("early_stopping, num_iterations are stored correctly in model string e }) -test_that("Booster: methods called on a Booster with a null handle should raise an informative error and not segfault", { +test_that("Booster: method calls Booster with a null handle should raise an informative error and not segfault", { data(agaricus.train, package = "lightgbm") train <- agaricus.train dtrain <- lgb.Dataset(train$data, label = train$label) @@ -841,7 +841,7 @@ test_that("Booster: methods called on a Booster with a null handle should raise saveRDS(bst, tmp_file) rm(bst) bst <- readRDS(tmp_file) - .expect_booster_error <- function(object){ + .expect_booster_error <- function(object) { error_regexp <- "Attempting to use a Booster which no longer exists" expect_error(object, regexp = error_regexp) } @@ -893,7 +893,7 @@ test_that("Booster: methods called on a Booster with a null handle should raise }) }) -test_that("Booster: attempting to create a Booster from a Dataset will a null handle should raise an informative error and not segfault", { +test_that("Booster$new() using a Dataset with a null handle should raise an informative error and not segfault", { data(agaricus.train, package = "lightgbm") train <- agaricus.train dtrain <- lgb.Dataset(train$data, label = train$label) @@ -904,7 +904,7 @@ test_that("Booster: attempting to create a Booster from a Dataset will a null ha dtrain <- readRDS(tmp_file) expect_error({ bst <- Booster$new(train_set = dtrain) - }, regexp = "Attempting to create a Dataset without any raw data") + }, regexp = "lgb.Booster: cannot create Booster handle") }) # this is almost identical to the test above it, but for lgb.cv(). A lot of code From 2ae39054781d7984cec7b3351a88774e3869ac58 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 2 Sep 2021 17:13:18 -0500 Subject: [PATCH 3/7] fixes for cpplint --- R-package/src/lightgbm_R.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/R-package/src/lightgbm_R.cpp b/R-package/src/lightgbm_R.cpp index 618ac91fb833..3ee5fa957ee4 100644 --- a/R-package/src/lightgbm_R.cpp +++ b/R-package/src/lightgbm_R.cpp @@ -53,8 +53,7 @@ void _AssertBoosterHandleNotNull(SEXP handle) { Rf_error( "Attempting to use a Booster which no longer exists. " "This can happen if you have called Booster$finalize() or if this Booster was saved with saveRDS(). " - "To avoid this error in the future, use saveRDS.lgb.Booster() or Booster$save_model() to save lightgbm Boosters." - ); + "To avoid this error in the future, use saveRDS.lgb.Booster() or Booster$save_model() to save lightgbm Boosters."); } } @@ -63,8 +62,7 @@ void _AssertDatasetHandleNotNull(SEXP handle) { Rf_error( "Attempting to use a Dataset which no longer exists. " "This can happen if you have called Dataset$finalize() or if this Dataset was saved with saveRDS(). " - "To avoid this error in the future, use lgb.Dataset.save() or Dataset$save_binary() to save lightgbm Datasets." - ); + "To avoid this error in the future, use lgb.Dataset.save() or Dataset$save_binary() to save lightgbm Datasets."); } } From fea61a8dc83214ca94ad2fd2cc1685bbb8428608 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 16 Sep 2021 21:39:08 -0400 Subject: [PATCH 4/7] Update R-package/tests/testthat/test_dataset.R Co-authored-by: Nikita Titov --- R-package/tests/testthat/test_dataset.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R-package/tests/testthat/test_dataset.R b/R-package/tests/testthat/test_dataset.R index 58d5d86d0bb8..b1b46390ca7f 100644 --- a/R-package/tests/testthat/test_dataset.R +++ b/R-package/tests/testthat/test_dataset.R @@ -531,9 +531,9 @@ test_that("Dataset: method calls on a Dataset with a null handle should raise an }, regexp = "cannot set reference after freeing raw data") tmp_valid_file <- tempfile(fileext = ".rds") - saveRDS(dvalid, tmp_file) + saveRDS(dvalid, tmp_valid_file) rm(dvalid) - dvalid <- readRDS(tmp_file) + dvalid <- readRDS(tmp_valid_file) expect_error({ dtrain$set_reference(reference = dvalid) }, regexp = "cannot get column names before dataset has been constructed") From 713c59d39f8d2ac4758e874ea7acc63fcd9cc38c Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 23 Sep 2021 21:27:01 -0500 Subject: [PATCH 5/7] fix tests --- R-package/tests/testthat/test_dataset.R | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/R-package/tests/testthat/test_dataset.R b/R-package/tests/testthat/test_dataset.R index b1b46390ca7f..d8eb0b06ce1e 100644 --- a/R-package/tests/testthat/test_dataset.R +++ b/R-package/tests/testthat/test_dataset.R @@ -534,6 +534,12 @@ test_that("Dataset: method calls on a Dataset with a null handle should raise an saveRDS(dvalid, tmp_valid_file) rm(dvalid) dvalid <- readRDS(tmp_valid_file) + dtrain <- lgb.Dataset( + train$data + , label = train$label + , free_raw_data = FALSE + ) + dtrain$construct() expect_error({ dtrain$set_reference(reference = dvalid) }, regexp = "cannot get column names before dataset has been constructed") From 9dd6651319adbf8dcb6e991fdfef3f2debbd807f Mon Sep 17 00:00:00 2001 From: James Lamb Date: Fri, 24 Sep 2021 10:05:59 -0400 Subject: [PATCH 6/7] Apply suggestions from code review Co-authored-by: Nikita Titov --- R-package/src/lightgbm_R.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/R-package/src/lightgbm_R.cpp b/R-package/src/lightgbm_R.cpp index caf6b95edad6..e4051113abf6 100644 --- a/R-package/src/lightgbm_R.cpp +++ b/R-package/src/lightgbm_R.cpp @@ -486,6 +486,7 @@ SEXP LGBM_BoosterMerge_R(SEXP handle, SEXP LGBM_BoosterAddValidData_R(SEXP handle, SEXP valid_data) { _AssertBoosterHandleNotNull(handle); + _AssertDatasetHandleNotNull(valid_data); R_API_BEGIN(); CHECK_CALL(LGBM_BoosterAddValidData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(valid_data))); return R_NilValue; @@ -495,6 +496,7 @@ SEXP LGBM_BoosterAddValidData_R(SEXP handle, SEXP LGBM_BoosterResetTrainingData_R(SEXP handle, SEXP train_data) { _AssertBoosterHandleNotNull(handle); + _AssertDatasetHandleNotNull(train_data); R_API_BEGIN(); CHECK_CALL(LGBM_BoosterResetTrainingData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(train_data))); return R_NilValue; From caa04b3f4bb5e8eb51f7d40240cfc452d8fd1cf3 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Fri, 24 Sep 2021 09:09:12 -0500 Subject: [PATCH 7/7] move asserts inside try-catch --- R-package/src/lightgbm_R.cpp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/R-package/src/lightgbm_R.cpp b/R-package/src/lightgbm_R.cpp index e4051113abf6..99dd666dbf9e 100644 --- a/R-package/src/lightgbm_R.cpp +++ b/R-package/src/lightgbm_R.cpp @@ -475,9 +475,9 @@ SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) { SEXP LGBM_BoosterMerge_R(SEXP handle, SEXP other_handle) { + R_API_BEGIN(); _AssertBoosterHandleNotNull(handle); _AssertBoosterHandleNotNull(other_handle); - R_API_BEGIN(); CHECK_CALL(LGBM_BoosterMerge(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(other_handle))); return R_NilValue; R_API_END(); @@ -485,9 +485,9 @@ SEXP LGBM_BoosterMerge_R(SEXP handle, SEXP LGBM_BoosterAddValidData_R(SEXP handle, SEXP valid_data) { + R_API_BEGIN(); _AssertBoosterHandleNotNull(handle); _AssertDatasetHandleNotNull(valid_data); - R_API_BEGIN(); CHECK_CALL(LGBM_BoosterAddValidData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(valid_data))); return R_NilValue; R_API_END(); @@ -495,9 +495,9 @@ SEXP LGBM_BoosterAddValidData_R(SEXP handle, SEXP LGBM_BoosterResetTrainingData_R(SEXP handle, SEXP train_data) { + R_API_BEGIN(); _AssertBoosterHandleNotNull(handle); _AssertDatasetHandleNotNull(train_data); - R_API_BEGIN(); CHECK_CALL(LGBM_BoosterResetTrainingData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(train_data))); return R_NilValue; R_API_END(); @@ -554,8 +554,8 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle, } SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) { - _AssertBoosterHandleNotNull(handle); R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); CHECK_CALL(LGBM_BoosterRollbackOneIter(R_ExternalPtrAddr(handle))); return R_NilValue; R_API_END(); @@ -574,8 +574,8 @@ SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle, SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle, SEXP out_result) { - _AssertBoosterHandleNotNull(handle); R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); double* ptr_ret = REAL(out_result); CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_ExternalPtrAddr(handle), ptr_ret)); return R_NilValue; @@ -584,8 +584,8 @@ SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle, SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle, SEXP out_result) { - _AssertBoosterHandleNotNull(handle); R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); double* ptr_ret = REAL(out_result); CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_ExternalPtrAddr(handle), ptr_ret)); return R_NilValue; @@ -644,8 +644,8 @@ SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) { SEXP LGBM_BoosterGetEval_R(SEXP handle, SEXP data_idx, SEXP out_result) { - _AssertBoosterHandleNotNull(handle); R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); int len; CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len)); double* ptr_ret = REAL(out_result); @@ -659,8 +659,8 @@ SEXP LGBM_BoosterGetEval_R(SEXP handle, SEXP LGBM_BoosterGetNumPredict_R(SEXP handle, SEXP data_idx, SEXP out) { - _AssertBoosterHandleNotNull(handle); R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); int64_t len; CHECK_CALL(LGBM_BoosterGetNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &len)); INTEGER(out)[0] = static_cast(len); @@ -671,8 +671,8 @@ SEXP LGBM_BoosterGetNumPredict_R(SEXP handle, SEXP LGBM_BoosterGetPredict_R(SEXP handle, SEXP data_idx, SEXP out_result) { - _AssertBoosterHandleNotNull(handle); R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); double* ptr_ret = REAL(out_result); int64_t out_len; CHECK_CALL(LGBM_BoosterGetPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &out_len, ptr_ret)); @@ -726,8 +726,8 @@ SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle, SEXP start_iteration, SEXP num_iteration, SEXP out_len) { - _AssertBoosterHandleNotNull(handle); R_API_BEGIN(); + _AssertBoosterHandleNotNull(handle); int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib); int64_t len = 0; CHECK_CALL(LGBM_BoosterCalcNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(num_row),