Skip to content

Commit

Permalink
[R] enable multi-dimensional base_margin (#9885)
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cortes authored Dec 14, 2023
1 parent 936b22f commit cd473c9
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
6 changes: 3 additions & 3 deletions R-package/R/xgb.DMatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#' only care about the relative ordering of data points within each group,
#' so it doesn't make sense to assign weights to individual data points.
#' @param base_margin Base margin used for boosting from existing model.
#'
#' In the case of multi-output models, one can also pass multi-dimensional base_margin.
#' @param missing a float value to represents missing values in data (used only when input is a dense matrix).
#' It is useful when a 0 or some other extreme value represents missing values in data.
#' @param silent whether to suppress printing an informational message after loading from a file.
Expand Down Expand Up @@ -439,9 +441,7 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
return(TRUE)
}
if (name == "base_margin") {
# if (length(info)!=nrow(object))
# stop("The length of base margin must equal to the number of rows in the input data")
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
.Call(XGDMatrixSetInfo_R, object, name, info)
return(TRUE)
}
if (name == "group") {
Expand Down
4 changes: 3 additions & 1 deletion R-package/man/xgb.DMatrix.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 26 additions & 0 deletions R-package/tests/testthat/test_dmatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,29 @@ test_that("xgb.DMatrix: data.frame", {
m <- xgb.DMatrix(df, enable_categorical = TRUE)
expect_equal(getinfo(m, "feature_type"), c("c", "c"))
})

test_that("xgb.DMatrix: can take multi-dimensional 'base_margin'", {
set.seed(123)
x <- matrix(rnorm(100 * 10), nrow = 100)
y <- matrix(rnorm(100 * 2), nrow = 100)
b <- matrix(rnorm(100 * 2), nrow = 100)
model <- xgb.train(
data = xgb.DMatrix(data = x, label = y, nthread = n_threads),
params = list(
objective = "reg:squarederror",
tree_method = "hist",
multi_strategy = "multi_output_tree",
base_score = 0,
nthread = n_threads
),
nround = 1
)
pred_only_x <- predict(model, x, nthread = n_threads, reshape = TRUE)
pred_w_base <- predict(
model,
xgb.DMatrix(data = x, base_margin = b, nthread = n_threads),
nthread = n_threads,
reshape = TRUE
)
expect_equal(pred_only_x, pred_w_base - b, tolerance = 1e-5)
})

0 comments on commit cd473c9

Please sign in to comment.