Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[R-package] fix segfaults caused by missing Booster and Dataset handles (fixes #4208) #4586

Merged
merged 12 commits into from
Sep 25, 2021
11 changes: 10 additions & 1 deletion R-package/R/lgb.Dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,18 @@ Dataset <- R6::R6Class(
ref_handle <- private$reference$.__enclos_env__$private$get_handle()
}

# Not subsetting
# not subsetting, constructing from raw data
if (is.null(private$used_indices)) {

if (is.null(private$raw_data)) {
stop(paste0(
"Attempting to create a Dataset without any raw data. "
, "This can happen if you have called Dataset$finalize() or if this Dataset was saved with saveRDS(). "
, "To avoid this error in the future, use lgb.Dataset.save() or "
, "Dataset$save_binary() to save lightgbm Datasets."
))
}

# Are we using a data file?
if (is.character(private$raw_data)) {

Expand Down
53 changes: 53 additions & 0 deletions R-package/src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,24 @@ void _DatasetFinalizer(SEXP handle) {
LGBM_DatasetFree_R(handle);
}

void _AssertBoosterHandleNotNull(SEXP handle) {
if (Rf_isNull(handle) || !R_ExternalPtrAddr(handle)) {
Rf_error(
"Attempting to use a Booster which no longer exists. "
"This can happen if you have called Booster$finalize() or if this Booster was saved with saveRDS(). "
"To avoid this error in the future, use saveRDS.lgb.Booster() or Booster$save_model() to save lightgbm Boosters.");
}
}

void _AssertDatasetHandleNotNull(SEXP handle) {
if (Rf_isNull(handle) || !R_ExternalPtrAddr(handle)) {
Rf_error(
"Attempting to use a Dataset which no longer exists. "
"This can happen if you have called Dataset$finalize() or if this Dataset was saved with saveRDS(). "
"To avoid this error in the future, use lgb.Dataset.save() or Dataset$save_binary() to save lightgbm Datasets.");
}
}

SEXP LGBM_DatasetCreateFromFile_R(SEXP filename,
SEXP parameters,
SEXP reference) {
Expand Down Expand Up @@ -172,6 +190,7 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle,
SEXP len_used_row_indices,
SEXP parameters) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
int32_t len = static_cast<int32_t>(Rf_asInteger(len_used_row_indices));
std::vector<int32_t> idxvec(len);
Expand All @@ -195,6 +214,7 @@ SEXP LGBM_DatasetGetSubset_R(SEXP handle,
SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle,
SEXP feature_names) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
auto vec_names = Split(CHAR(PROTECT(Rf_asChar(feature_names))), '\t');
std::vector<const char*> vec_sptr;
int len = static_cast<int>(vec_names.size());
Expand All @@ -211,6 +231,7 @@ SEXP LGBM_DatasetSetFeatureNames_R(SEXP handle,
SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
SEXP feature_names;
int len = 0;
CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &len));
Expand Down Expand Up @@ -258,6 +279,7 @@ SEXP LGBM_DatasetGetFeatureNames_R(SEXP handle) {
SEXP LGBM_DatasetSaveBinary_R(SEXP handle,
SEXP filename) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
CHECK_CALL(LGBM_DatasetSaveBinary(R_ExternalPtrAddr(handle),
filename_ptr));
Expand All @@ -281,6 +303,7 @@ SEXP LGBM_DatasetSetField_R(SEXP handle,
SEXP field_data,
SEXP num_element) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
int len = Rf_asInteger(num_element);
const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
if (!strcmp("group", name) || !strcmp("query", name)) {
Expand Down Expand Up @@ -309,6 +332,7 @@ SEXP LGBM_DatasetGetField_R(SEXP handle,
SEXP field_name,
SEXP field_data) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
int out_len = 0;
int out_type = 0;
Expand Down Expand Up @@ -343,6 +367,7 @@ SEXP LGBM_DatasetGetFieldSize_R(SEXP handle,
SEXP field_name,
SEXP out) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
const char* name = CHAR(PROTECT(Rf_asChar(field_name)));
int out_len = 0;
int out_type = 0;
Expand Down Expand Up @@ -370,6 +395,7 @@ SEXP LGBM_DatasetUpdateParamChecking_R(SEXP old_params,

SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
int nrow;
CHECK_CALL(LGBM_DatasetGetNumData(R_ExternalPtrAddr(handle), &nrow));
INTEGER(out)[0] = nrow;
Expand All @@ -380,6 +406,7 @@ SEXP LGBM_DatasetGetNumData_R(SEXP handle, SEXP out) {
SEXP LGBM_DatasetGetNumFeature_R(SEXP handle,
SEXP out) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(handle);
int nfeature;
CHECK_CALL(LGBM_DatasetGetNumFeature(R_ExternalPtrAddr(handle), &nfeature));
INTEGER(out)[0] = nfeature;
Expand All @@ -406,6 +433,7 @@ SEXP LGBM_BoosterFree_R(SEXP handle) {
SEXP LGBM_BoosterCreate_R(SEXP train_data,
SEXP parameters) {
R_API_BEGIN();
_AssertDatasetHandleNotNull(train_data);
SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
BoosterHandle handle = nullptr;
Expand Down Expand Up @@ -448,6 +476,8 @@ SEXP LGBM_BoosterLoadModelFromString_R(SEXP model_str) {
SEXP LGBM_BoosterMerge_R(SEXP handle,
SEXP other_handle) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
_AssertBoosterHandleNotNull(other_handle);
CHECK_CALL(LGBM_BoosterMerge(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(other_handle)));
return R_NilValue;
R_API_END();
Expand All @@ -456,6 +486,8 @@ SEXP LGBM_BoosterMerge_R(SEXP handle,
SEXP LGBM_BoosterAddValidData_R(SEXP handle,
SEXP valid_data) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
_AssertDatasetHandleNotNull(valid_data);
CHECK_CALL(LGBM_BoosterAddValidData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(valid_data)));
return R_NilValue;
R_API_END();
Expand All @@ -464,6 +496,8 @@ SEXP LGBM_BoosterAddValidData_R(SEXP handle,
SEXP LGBM_BoosterResetTrainingData_R(SEXP handle,
SEXP train_data) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
_AssertDatasetHandleNotNull(train_data);
CHECK_CALL(LGBM_BoosterResetTrainingData(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(train_data)));
return R_NilValue;
R_API_END();
Expand All @@ -472,6 +506,7 @@ SEXP LGBM_BoosterResetTrainingData_R(SEXP handle,
SEXP LGBM_BoosterResetParameter_R(SEXP handle,
SEXP parameters) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
const char* parameters_ptr = CHAR(PROTECT(Rf_asChar(parameters)));
CHECK_CALL(LGBM_BoosterResetParameter(R_ExternalPtrAddr(handle), parameters_ptr));
UNPROTECT(1);
Expand All @@ -482,6 +517,7 @@ SEXP LGBM_BoosterResetParameter_R(SEXP handle,
SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,
SEXP out) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int num_class;
CHECK_CALL(LGBM_BoosterGetNumClasses(R_ExternalPtrAddr(handle), &num_class));
INTEGER(out)[0] = num_class;
Expand All @@ -491,6 +527,7 @@ SEXP LGBM_BoosterGetNumClasses_R(SEXP handle,

SEXP LGBM_BoosterUpdateOneIter_R(SEXP handle) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int is_finished = 0;
CHECK_CALL(LGBM_BoosterUpdateOneIter(R_ExternalPtrAddr(handle), &is_finished));
return R_NilValue;
Expand All @@ -502,6 +539,7 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,
SEXP hess,
SEXP len) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int is_finished = 0;
int int_len = Rf_asInteger(len);
std::vector<float> tgrad(int_len), thess(int_len);
Expand All @@ -517,6 +555,7 @@ SEXP LGBM_BoosterUpdateOneIterCustom_R(SEXP handle,

SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
CHECK_CALL(LGBM_BoosterRollbackOneIter(R_ExternalPtrAddr(handle)));
return R_NilValue;
R_API_END();
Expand All @@ -525,6 +564,7 @@ SEXP LGBM_BoosterRollbackOneIter_R(SEXP handle) {
SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
SEXP out) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int out_iteration;
CHECK_CALL(LGBM_BoosterGetCurrentIteration(R_ExternalPtrAddr(handle), &out_iteration));
INTEGER(out)[0] = out_iteration;
Expand All @@ -535,6 +575,7 @@ SEXP LGBM_BoosterGetCurrentIteration_R(SEXP handle,
SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
SEXP out_result) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
double* ptr_ret = REAL(out_result);
CHECK_CALL(LGBM_BoosterGetUpperBoundValue(R_ExternalPtrAddr(handle), ptr_ret));
return R_NilValue;
Expand All @@ -544,6 +585,7 @@ SEXP LGBM_BoosterGetUpperBoundValue_R(SEXP handle,
SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle,
SEXP out_result) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
double* ptr_ret = REAL(out_result);
CHECK_CALL(LGBM_BoosterGetLowerBoundValue(R_ExternalPtrAddr(handle), ptr_ret));
return R_NilValue;
Expand All @@ -553,6 +595,7 @@ SEXP LGBM_BoosterGetLowerBoundValue_R(SEXP handle,
SEXP LGBM_BoosterGetEvalNames_R(SEXP handle) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
SEXP eval_names;
int len;
CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len));
Expand Down Expand Up @@ -602,6 +645,7 @@ SEXP LGBM_BoosterGetEval_R(SEXP handle,
SEXP data_idx,
SEXP out_result) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int len;
CHECK_CALL(LGBM_BoosterGetEvalCounts(R_ExternalPtrAddr(handle), &len));
double* ptr_ret = REAL(out_result);
Expand All @@ -616,6 +660,7 @@ SEXP LGBM_BoosterGetNumPredict_R(SEXP handle,
SEXP data_idx,
SEXP out) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int64_t len;
CHECK_CALL(LGBM_BoosterGetNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &len));
INTEGER(out)[0] = static_cast<int>(len);
Expand All @@ -627,6 +672,7 @@ SEXP LGBM_BoosterGetPredict_R(SEXP handle,
SEXP data_idx,
SEXP out_result) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
double* ptr_ret = REAL(out_result);
int64_t out_len;
CHECK_CALL(LGBM_BoosterGetPredict(R_ExternalPtrAddr(handle), Rf_asInteger(data_idx), &out_len, ptr_ret));
Expand Down Expand Up @@ -659,6 +705,7 @@ SEXP LGBM_BoosterPredictForFile_R(SEXP handle,
SEXP parameter,
SEXP result_filename) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
const char* data_filename_ptr = CHAR(PROTECT(Rf_asChar(data_filename)));
const char* parameter_ptr = CHAR(PROTECT(Rf_asChar(parameter)));
const char* result_filename_ptr = CHAR(PROTECT(Rf_asChar(result_filename)));
Expand All @@ -680,6 +727,7 @@ SEXP LGBM_BoosterCalcNumPredict_R(SEXP handle,
SEXP num_iteration,
SEXP out_len) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
int64_t len = 0;
CHECK_CALL(LGBM_BoosterCalcNumPredict(R_ExternalPtrAddr(handle), Rf_asInteger(num_row),
Expand All @@ -704,6 +752,7 @@ SEXP LGBM_BoosterPredictForCSC_R(SEXP handle,
SEXP parameter,
SEXP out_result) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
const int* p_indptr = INTEGER(indptr);
const int32_t* p_indices = reinterpret_cast<const int32_t*>(INTEGER(indices));
Expand Down Expand Up @@ -735,6 +784,7 @@ SEXP LGBM_BoosterPredictForMat_R(SEXP handle,
SEXP parameter,
SEXP out_result) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
int pred_type = GetPredictType(is_rawscore, is_leafidx, is_predcontrib);
int32_t nrow = static_cast<int32_t>(Rf_asInteger(num_row));
int32_t ncol = static_cast<int32_t>(Rf_asInteger(num_col));
Expand All @@ -755,6 +805,7 @@ SEXP LGBM_BoosterSaveModel_R(SEXP handle,
SEXP feature_importance_type,
SEXP filename) {
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
const char* filename_ptr = CHAR(PROTECT(Rf_asChar(filename)));
CHECK_CALL(LGBM_BoosterSaveModel(R_ExternalPtrAddr(handle), 0, Rf_asInteger(num_iteration), Rf_asInteger(feature_importance_type), filename_ptr));
UNPROTECT(1);
Expand All @@ -767,6 +818,7 @@ SEXP LGBM_BoosterSaveModelToString_R(SEXP handle,
SEXP feature_importance_type) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
SEXP model_str;
int64_t out_len = 0;
int64_t buf_len = 1024 * 1024;
Expand All @@ -791,6 +843,7 @@ SEXP LGBM_BoosterDumpModel_R(SEXP handle,
SEXP feature_importance_type) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
SEXP model_str;
int64_t out_len = 0;
int64_t buf_len = 1024 * 1024;
Expand Down
48 changes: 48 additions & 0 deletions R-package/tests/testthat/test_dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -496,3 +496,51 @@ test_that("lgb.Dataset: should be able to create a Dataset from a text file with
expect_identical(dtrain$get_params(), list(header = FALSE))
expect_identical(dtrain$dim(), c(100L, 2L))
})

test_that("Dataset: method calls on a Dataset with a null handle should raise an informative error and not segfault", {
data(agaricus.train, package = "lightgbm")
train <- agaricus.train
dtrain <- lgb.Dataset(train$data, label = train$label)
dtrain$construct()
dvalid <- dtrain$create_valid(
data = train$data[seq_len(100L), ]
, label = train$label[seq_len(100L)]
)
dvalid$construct()
tmp_file <- tempfile(fileext = ".rds")
saveRDS(dtrain, tmp_file)
rm(dtrain)
dtrain <- readRDS(tmp_file)
expect_error({
dtrain$construct()
}, regexp = "Attempting to create a Dataset without any raw data")
expect_error({
dtrain$dim()
}, regexp = "cannot get dimensions before dataset has been constructed")
expect_error({
dtrain$get_colnames()
}, regexp = "cannot get column names before dataset has been constructed")
expect_error({
dtrain$save_binary(fname = tempfile(fileext = ".bin"))
}, regexp = "Attempting to create a Dataset without any raw data")
expect_error({
dtrain$set_categorical_feature(categorical_feature = 1L)
}, regexp = "cannot set categorical feature after freeing raw data")
expect_error({
dtrain$set_reference(reference = dvalid)
}, regexp = "cannot set reference after freeing raw data")

tmp_valid_file <- tempfile(fileext = ".rds")
saveRDS(dvalid, tmp_valid_file)
rm(dvalid)
dvalid <- readRDS(tmp_valid_file)
dtrain <- lgb.Dataset(
train$data
, label = train$label
, free_raw_data = FALSE
)
dtrain$construct()
expect_error({
dtrain$set_reference(reference = dvalid)
}, regexp = "cannot get column names before dataset has been constructed")
})
Loading