diff --git a/R-package/R/xgb.DMatrix.R b/R-package/R/xgb.DMatrix.R index fead30413159..4b2bb0d2a8a1 100644 --- a/R-package/R/xgb.DMatrix.R +++ b/R-package/R/xgb.DMatrix.R @@ -16,6 +16,8 @@ #' only care about the relative ordering of data points within each group, #' so it doesn't make sense to assign weights to individual data points. #' @param base_margin Base margin used for boosting from existing model. +#' +#' In the case of multi-output models, one can also pass multi-dimensional base_margin. #' @param missing a float value to represents missing values in data (used only when input is a dense matrix). #' It is useful when a 0 or some other extreme value represents missing values in data. #' @param silent whether to suppress printing an informational message after loading from a file. @@ -439,9 +441,7 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) { return(TRUE) } if (name == "base_margin") { - # if (length(info)!=nrow(object)) - # stop("The length of base margin must equal to the number of rows in the input data") - .Call(XGDMatrixSetInfo_R, object, name, as.numeric(info)) + .Call(XGDMatrixSetInfo_R, object, name, info) return(TRUE) } if (name == "group") { diff --git a/R-package/man/xgb.DMatrix.Rd b/R-package/man/xgb.DMatrix.Rd index 95cc8d3cd34f..a1ef39f0b21f 100644 --- a/R-package/man/xgb.DMatrix.Rd +++ b/R-package/man/xgb.DMatrix.Rd @@ -36,7 +36,9 @@ is assigned to each group (not each data point). This is because we only care about the relative ordering of data points within each group, so it doesn't make sense to assign weights to individual data points.} -\item{base_margin}{Base margin used for boosting from existing model.} +\item{base_margin}{Base margin used for boosting from existing model. + + In the case of multi-output models, one can also pass multi-dimensional base_margin.} \item{missing}{a float value to represents missing values in data (used only when input is a dense matrix). It is useful when a 0 or some other extreme value represents missing values in data.} diff --git a/R-package/tests/testthat/test_dmatrix.R b/R-package/tests/testthat/test_dmatrix.R index 87a73d84b31a..55a6996874fb 100644 --- a/R-package/tests/testthat/test_dmatrix.R +++ b/R-package/tests/testthat/test_dmatrix.R @@ -349,3 +349,29 @@ test_that("xgb.DMatrix: data.frame", { m <- xgb.DMatrix(df, enable_categorical = TRUE) expect_equal(getinfo(m, "feature_type"), c("c", "c")) }) + +test_that("xgb.DMatrix: can take multi-dimensional 'base_margin'", { + set.seed(123) + x <- matrix(rnorm(100 * 10), nrow = 100) + y <- matrix(rnorm(100 * 2), nrow = 100) + b <- matrix(rnorm(100 * 2), nrow = 100) + model <- xgb.train( + data = xgb.DMatrix(data = x, label = y, nthread = n_threads), + params = list( + objective = "reg:squarederror", + tree_method = "hist", + multi_strategy = "multi_output_tree", + base_score = 0, + nthread = n_threads + ), + nround = 1 + ) + pred_only_x <- predict(model, x, nthread = n_threads, reshape = TRUE) + pred_w_base <- predict( + model, + xgb.DMatrix(data = x, base_margin = b, nthread = n_threads), + nthread = n_threads, + reshape = TRUE + ) + expect_equal(pred_only_x, pred_w_base - b, tolerance = 1e-5) +})