[R] Predictions with base margin are off by a constant number #9869

Closed
david-cortes opened this issue Dec 11, 2023 · 6 comments · Fixed by #9882

@david-cortes
Contributor

ref #9810

It seems that, in the R interface, adding a base margin to a DMatrix somehow results in the predictions being off by a constant (exactly 1.48 in the example below).

Example:

library(xgboost)
data("agaricus.train")
x <- agaricus.train$data[1:100, ]
y <- agaricus.train$label[1:100]
n_threads = 1
dm_labelled <- xgb.DMatrix(
    data = x,
    label = y,
    nthread = n_threads
)
dm_only_x <- xgb.DMatrix(
    data = x,
    nthread = n_threads
)
dm_w_base_margin <- xgb.DMatrix(
    data = x,
    nthread = n_threads,
    base_margin = rep(0, nrow(x))
)
params <- list(
    objective = "binary:logistic",
    tree_method = "hist",
    nthread = n_threads
)
model <- xgb.train(
    data = dm_labelled,
    params = params,
    nrounds = 5L
)
pred_wo_margin <- predict(model, dm_only_x, outputmargin = TRUE)
pred_w_margin <- predict(model, dm_w_base_margin, outputmargin = TRUE)
print(pred_wo_margin[1] - pred_w_margin[1])
[1] -1.48

In contrast, the parameter seems to work as expected in the Python interface:

import numpy as np
rng = np.random.default_rng(seed=123)
X = rng.standard_normal(size=(100,10))
y = rng.integers(2, size=100)
import xgboost as xgb
dmat_labelled = xgb.DMatrix(data=X, label=y)
dmat_only_x = xgb.DMatrix(data=X)
dmat_w_base_margin = xgb.DMatrix(data=X, base_margin=np.zeros(y.shape[0]))
model = xgb.train(
    dtrain=dmat_labelled,
    params={
        "objective" : "binary:logistic",
        "tree_method" : "hist",
    },
    num_boost_round = 5
)
pred_wo_margin = model.predict(dmat_only_x, output_margin=True)
pred_w_margin = model.predict(dmat_w_base_margin, output_margin=True)
print(pred_w_margin[0] - pred_wo_margin[0])
@trivialfis
Member

Hmm, there are some inconsistencies between the base margin handling and the base score handling; I will take a deeper look.

@trivialfis
Member

Basically, the question is whether we should apply the inverse link function to the user-provided base margin. At the moment we don't, but I'm wondering whether that is the right thing to do.
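
For reference, a minimal sketch of what the link functions involved look like for binary:logistic (the inverse link is the sigmoid; as I understand it, a user-supplied base_margin is currently expected to already be in margin, i.e. log-odds, space):

# Minimal sketch, assuming objective binary:logistic.
# link: probability -> margin (log-odds); inverse link: margin -> probability.
sigmoid <- function(z) 1 / (1 + exp(-z))   # inverse link
logit   <- function(p) log(p / (1 - p))    # link
p <- 0.7
margin <- logit(p)                 # what a base_margin value is expected to hold
stopifnot(abs(sigmoid(margin) - p) < 1e-12)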

@david-cortes
Contributor Author

Basically, the question is whether we should apply the inverse link function to the user-provided base margin. At the moment we don't, but I'm wondering whether that is the right thing to do.

I do not think this is related to transformations involved in the objective, since here it happens even with outputmargin=TRUE and doesn't happen under the same settings in Python.

Curiously, while the Python interface does give the expected result for objective binary:logistic, it seems to suffer from the same problem as the R interface for objective reg:squarederror.

From some testing, it seems there are two issues involved here:

  • When the DMatrix has base_margin, the model doesn't add its base_score to it (but note that this would not explain the issue with the binary objective mentioned earlier, since the base_score there is not equal to the difference in predictions).
  • The base_margin is subtracted from the output instead of being added.

Here's an example:

  • R
library(xgboost)
library(testthat)
data(mtcars)
y <- mtcars$mpg
x <- as.matrix(mtcars[, -1])
n_threads = 1
dm_labelled <- xgb.DMatrix(
    data = x,
    label = y,
    nthread = n_threads
)
dm_only_x <- xgb.DMatrix(
    data = x,
    nthread = n_threads
)
set.seed(123)
base_margin <- rnorm(nrow(x))
dm_w_base_margin <- xgb.DMatrix(
    data = x,
    nthread = n_threads,
    base_margin = base_margin
)
params <- list(
    objective = "reg:squarederror",
    tree_method = "hist",
    nthread = n_threads
)
model <- xgb.train(
    data = dm_labelled,
    params = params,
    nrounds = 5L
)
models_base_score <- as.numeric(
    jsonlite::fromJSON(xgb.config(model))$learner$learner_model_param$base_score
)
pred_wo_margin <- predict(model, dm_only_x)
pred_w_margin <- predict(model, dm_w_base_margin)
expect_equal(
    pred_wo_margin,
    pred_w_margin + models_base_score - base_margin,
    tolerance = 1e-5
)
  • Python
import numpy as np, xgboost as xgb
import json
mtcars = np.array([[21,6,160,110,3.9,2.62,16.46,0,1,4,4],
[21,6,160,110,3.9,2.875,17.02,0,1,4,4],
[22.8,4,108,93,3.85,2.32,18.61,1,1,4,1],
[21.4,6,258,110,3.08,3.215,19.44,1,0,3,1],
[18.7,8,360,175,3.15,3.44,17.02,0,0,3,2],
[18.1,6,225,105,2.76,3.46,20.22,1,0,3,1],
[14.3,8,360,245,3.21,3.57,15.84,0,0,3,4],
[24.4,4,146.7,62,3.69,3.19,20,1,0,4,2],
[22.8,4,140.8,95,3.92,3.15,22.9,1,0,4,2],
[19.2,6,167.6,123,3.92,3.44,18.3,1,0,4,4],
[17.8,6,167.6,123,3.92,3.44,18.9,1,0,4,4],
[16.4,8,275.8,180,3.07,4.07,17.4,0,0,3,3],
[17.3,8,275.8,180,3.07,3.73,17.6,0,0,3,3],
[15.2,8,275.8,180,3.07,3.78,18,0,0,3,3],
[10.4,8,472,205,2.93,5.25,17.98,0,0,3,4],
[10.4,8,460,215,3,5.424,17.82,0,0,3,4],
[14.7,8,440,230,3.23,5.345,17.42,0,0,3,4],
[32.4,4,78.7,66,4.08,2.2,19.47,1,1,4,1],
[30.4,4,75.7,52,4.93,1.615,18.52,1,1,4,2],
[33.9,4,71.1,65,4.22,1.835,19.9,1,1,4,1],
[21.5,4,120.1,97,3.7,2.465,20.01,1,0,3,1],
[15.5,8,318,150,2.76,3.52,16.87,0,0,3,2],
[15.2,8,304,150,3.15,3.435,17.3,0,0,3,2],
[13.3,8,350,245,3.73,3.84,15.41,0,0,3,4],
[19.2,8,400,175,3.08,3.845,17.05,0,0,3,2],
[27.3,4,79,66,4.08,1.935,18.9,1,1,4,1],
[26,4,120.3,91,4.43,2.14,16.7,0,1,5,2],
[30.4,4,95.1,113,3.77,1.513,16.9,1,1,5,2],
[15.8,8,351,264,4.22,3.17,14.5,0,1,5,4],
[19.7,6,145,175,3.62,2.77,15.5,0,1,5,6],
[15,8,301,335,3.54,3.57,14.6,0,1,5,8],
[21.4,4,121,109,4.11,2.78,18.6,1,1,4,2]])
y = mtcars[:, 0]
X = mtcars[:, 1:]

rng = np.random.default_rng(seed=123)
base_margin = rng.standard_normal(size=X.shape[0])

dm_labelled = xgb.DMatrix(data=X, label=y)
dm_only_x = xgb.DMatrix(data=X)
dm_w_base_margin = xgb.DMatrix(data=X, base_margin=base_margin)
params = {
    "objective" : "reg:squarederror",
    "tree_method" : "hist",
}
model = xgb.train(
    dtrain=dm_labelled,
    params=params,
    num_boost_round=5
)
pred_wo_margin = model.predict(dm_only_x)
pred_w_margin = model.predict(dm_w_base_margin)
model_json = json.loads(model.save_raw("json").decode())
model_base = float(model_json["learner"]["learner_model_param"]["base_score"])

np.testing.assert_almost_equal(
    pred_wo_margin,
    pred_w_margin + model_base - base_margin,
    decimal=5
)

@david-cortes
Contributor Author

Actually, the comment above is wrong: I mixed up the signs, and the calculation in XGBoost is correct (the supplied base_margin is added, not subtracted); it just doesn't include base_score.

@david-cortes
Contributor Author

After further experiments, it looks like this is not an R-specific issue. It just seems that the intercept is not added when a base_margin is supplied.

For large enough datasets, the intercept that the model will end up using seems to be easily calculable as the constant prediction that minimizes the objective, and adding this to the prediction with margin does seem to produce the expected numbers.

library(xgboost)
library(testthat)
data("agaricus.train")
x <- agaricus.train$data
y <- agaricus.train$label
n_threads <- 1
dm_labelled <- xgb.DMatrix(
    data = x,
    label = y,
    nthread = n_threads
)
dm_only_x <- xgb.DMatrix(
    data = x
)
zeros <- rep(0, nrow(x))
dm_w_base_margin <- xgb.DMatrix(
    data = x,
    base_margin = zeros
)
params <- list(
    objective = "binary:logistic"
)
model <- xgb.train(
    data = dm_labelled,
    params = params,
    nrounds = 1L
)
pred_wo_margin <- predict(model, dm_only_x, outputmargin = TRUE)
pred_w_margin <- predict(model, dm_w_base_margin, outputmargin = TRUE)
optimal_intercept <- log(sum(y == 1)) - log(sum(y == 0))

expect_equal(pred_wo_margin - pred_w_margin, rep(optimal_intercept, nrow(x)), tolerance=1e-4)
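
By the same reasoning, for reg:squarederror the constant prediction that minimizes the objective is mean(y), so (as a sketch under that assumption, using the mtcars data from the earlier comment) the gap between the two prediction paths should come out close to the label mean:

library(xgboost)
data(mtcars)
y <- mtcars$mpg
x <- as.matrix(mtcars[, -1])
dm_labelled <- xgb.DMatrix(data = x, label = y, nthread = 1)
dm_only_x <- xgb.DMatrix(data = x, nthread = 1)
dm_w_base_margin <- xgb.DMatrix(data = x, nthread = 1, base_margin = rep(0, nrow(x)))
model <- xgb.train(
    data = dm_labelled,
    params = list(objective = "reg:squarederror", nthread = 1),
    nrounds = 5L
)
pred_wo_margin <- predict(model, dm_only_x)
pred_w_margin <- predict(model, dm_w_base_margin)
# If the intercept is dropped whenever base_margin is present, the gap should
# be (approximately) the constant minimizer of squared error, i.e. mean(y).
print(max(abs((pred_wo_margin - pred_w_margin) - mean(y))))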

@trivialfis
Member

After further experiments, it looks like this is not an R-specific issue. It just seems that the intercept is not added when a base_margin is supplied.

Indeed, when a user-supplied base margin is available, the intercept is ignored; they are both global biases. I don't have a strong opinion on how they should be mixed. I will take a closer look today.
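
Until that mixing is settled, one possible workaround (a sketch of an assumption about usage, not an official recommendation) is to fold the model's base_score into the supplied margin manually; for an identity-link objective such as reg:squarederror the two prediction paths then agree:

library(xgboost)
data(mtcars)
y <- mtcars$mpg
x <- as.matrix(mtcars[, -1])
dtrain <- xgb.DMatrix(data = x, label = y, nthread = 1)
model <- xgb.train(
    data = dtrain,
    params = list(objective = "reg:squarederror", nthread = 1),
    nrounds = 5L
)
# Read the learned intercept out of the model configuration.
base_score <- as.numeric(
    jsonlite::fromJSON(xgb.config(model))$learner$learner_model_param$base_score
)
user_margin <- rep(0, nrow(x))        # whatever margin one actually wants to supply
dm_plain <- xgb.DMatrix(data = x, nthread = 1)
dm_folded <- xgb.DMatrix(
    data = x,
    nthread = 1,
    base_margin = user_margin + base_score  # add the otherwise-ignored intercept back
)
all.equal(predict(model, dm_plain), predict(model, dm_folded), tolerance = 1e-5)

With a zero user_margin this reproduces the plain predictions; with a non-zero margin it gives base_score + margin + trees, which matches the behaviour the examples above were expecting.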
