[doc] Brief introduction to base_score. #9882

Merged: 3 commits into dmlc:master from doc-link, Dec 17, 2023

Conversation

trivialfis
Member

Close #9869.

@trivialfis
Member Author

cc @david-cortes

@david-cortes
Contributor

Thanks for looking into this.

However, I still feel that the docs do not explain what base_score is or how it is obtained.

Suppose that I fit a model to the first 100 rows of the agaricus data:

library(xgboost)
data("agaricus.train")
x <- agaricus.train$data[1:100, ]
y <- agaricus.train$label[1:100]
dm_labelled <- xgb.DMatrix(data = x, label = y)
model <- xgb.train(
    data = dm_labelled,
    params = list(objective = "binary:logistic"),
    nrounds = 1L
)

In this case, 13% of the rows belong to the "positive" class:

print(mean(y))
[1] 0.13

The model's base_score, as returned by the JSON config, is as follows:

model_json <- jsonlite::fromJSON(xgb.config(model))
base_score <- as.numeric(model_json$learner$learner_model_param$base_score)
print(base_score)
[1] 0.1854274

It does seem to be the case that base_margin overrides the intercept, and gets added to the raw scores as described in this PR.

In this case however, the intercept that gets added seems to correspond to what would be obtained by applying the link function to base_score, not to base_score itself:

$$ \log(\frac{0.1854274}{1 - 0.1854274}) = -1.48 $$

dm_only_x <- xgb.DMatrix(data = x)  # no base_margin: the model's intercept is used
dm_w_base_margin <- xgb.DMatrix(data = x, base_margin = rep(0, nrow(x)))  # base_margin overrides the intercept
pred_wo_margin <- predict(model, dm_only_x, outputmargin = TRUE)
pred_w_margin <- predict(model, dm_w_base_margin, outputmargin = TRUE)
print(pred_wo_margin[1] - pred_w_margin[1])
[1] -1.48
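
As a quick cross-check (a small sketch that just re-derives the number from the base_score printed above; using Python here for brevity):

import numpy as np

base_score = 0.1854274  # value from the model JSON above
# binary:logistic uses the logit link, so the raw-score intercept is logit(base_score)
intercept = np.log(base_score / (1.0 - base_score))
print(intercept)  # ~ -1.48, matching the constant difference in the margins above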

Note also that the optimal intercept in this case would be a slightly different number than what the model estimated:

optimal_intercept <- log(sum(y == 1)) - log(sum(y == 0))
print(optimal_intercept)
[1] -1.900959

@david-cortes
Contributor

Thanks again for looking into it.

It would be ideal if you could also update the docs for base_score in the parameters section.

There are still a few other issues that could be worth mentioning in this doc:

  • What's the behavior if the user supplies base_score as a training parameter AND passes a DMatrix with base_margin to xgb.train? It seems it then keeps the base_score but doesn't apply it to that training data.
  • While base_score is restricted to a single number, can base_margin have multiple dimensions if e.g. one is using a multi-output model? It seems to work as expected in Python, so I'm thinking of updating the R interface to allow for it.

I am also thinking: if the current behavior around base_score is due to the old binary model serialization format, shouldn't it be possible to introduce a new field in the model, say "intercept", which would be auto-filled from "base_score" if not present in a serialized model, and otherwise be set to the actual intercept and allow per-output intercepts for e.g. multi-output and multinomial logistic?

@trivialfis
Member Author

trivialfis commented Dec 13, 2023

What's the behavior if the user supplies base_score as a training parameter AND passes a DMatrix with base_margin to xgb.train? It seems it then keeps the base_score but doesn't apply it to that training data.

The base margin overrides the base score, but this is not "by design"; it's more like an uncaught user error.

While base_score is restricted to a single number, can base_margin have multiple dimensions if e.g. one is using a multi-output model?

Yes, will update the doc.

if the current behavior around base_score is due to the old binary model serialization format

Yes, it's limited by the binary model serialization and we can't use a vector value.

which would be auto-filled from "base_score" if not present in a serialized model, and otherwise be set to the actual intercept and allow per-output intercepts for e.g. multi-output and multinomial logistic?

I thought about it, but the current serialization code is just messy and I don't want to modify the binary model format in any way, considering that we are deprecating it anyway.

@trivialfis
Member Author

Also, as a reminder to myself: I should add a specialization for binary classification that uses the class frequency instead of a Newton step.
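
To illustrate (a rough sketch, not the exact internal code; it assumes the 100-row agaricus subset above with 13 positive labels and a single undamped Newton step from a zero margin):

import numpy as np

n, n_pos = 100, 13  # from the agaricus subset above: mean(y) == 0.13
y = np.concatenate([np.ones(n_pos), np.zeros(n - n_pos)])

# One Newton step for binary:logistic starting from margin 0 (p = 0.5)
p0 = np.full(n, 0.5)
grad = p0 - y            # gradient of the logistic loss w.r.t. the raw margin
hess = p0 * (1.0 - p0)   # Hessian of the logistic loss
margin = -grad.sum() / hess.sum()
print(margin)                         # ~ -1.48
print(1.0 / (1.0 + np.exp(-margin)))  # ~ 0.1854, the base_score seen earlier

# A frequency-based specialization would instead store the class frequency directly:
print(y.mean())                       # 0.13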

@david-cortes
Contributor

Thanks for the info. I guess that clarifies the behavior of the intercept around the logistic objective.

I'm still curious about this estimation method and its uses, though.

Suppose now that the objective is count:poisson, and I create a small example where there's only an intercept:

import numpy as np
import xgboost as xgb
rng = np.random.default_rng(seed=123)
X = rng.standard_normal(size=(100,10))
y = 1 + rng.poisson(10, size=100)
model = xgb.train(
    dtrain=xgb.DMatrix(X, label=y),
    params={"objective" : "count:poisson", "gamma" : 1e4},
    num_boost_round=1,
)
pred = model.predict(xgb.DMatrix(X))
import json
base_score = json.loads(
    model.save_raw("json").decode()
)["learner"]["learner_model_param"]["base_score"]

In this case, we have:

np.mean(y)
10.74

I understand that here the function should work as

$$ \hat{y} = \exp(F) $$

Hence, base_score should be mean(y) and the intercept should be the logarithm of that (the inverse of the exp above), but:

float(base_score)
126.05785
np.log(float(base_score))
4.8367409285690615

Then, the predictions are constant, but they do not seem to correspond to the intercept, nor to be close to the actual values of y:

pred[0]
109.99812
model.predict(xgb.DMatrix(X), output_margin=True)[0]
4.7004633

What's happening in this case?

@trivialfis
Member Author

The Poisson case is a result of poor Newton optimization. We use 0 as the starting point, which has a significant error compared to the actual mean.

@david-cortes
Contributor

david-cortes commented Dec 14, 2023

But if that were the actual intercept, even if off by some amount, shouldn't the predictions be equal to base_score, since the model doesn't have anything further adding to it?

Also, even if one applies a single Newton step starting from zero, the result does not quite match the earlier base_score:

import numpy as np
rng = np.random.default_rng(seed=123)
X = rng.standard_normal(size=(100,10))
y = 1 + rng.poisson(10, size=100)

def grad_poisson(x):
    return np.exp(x) - y

def hess_poisson(x):
    return np.exp(x)

x0 = np.zeros_like(y)
x1_newton = -grad_poisson(x0).sum() / hess_poisson(x0).sum()
print(x1_newton)
9.74

(This is also far off from the optimum of 2.37 (= log(mean(y))), which takes ~12 Newton steps to reach here.)

By the way, if base_score is obtained by applying the inverse of the link function to the raw intercept, then the optimal base_score should be the mean of y whenever that link function is what GLM theory calls a "canonical" link function for the given likelihood.

That applies to: reg:squarederror, reg:gamma, binary:logistic, count:poisson. I'm not 100% sure but I think it also applies to tweedie.
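
As a quick numerical check of that claim for count:poisson (a minimal sketch; the deviance-minimizing constant raw score should equal log(mean(y))):

import numpy as np
from scipy.optimize import minimize
from sklearn.metrics import mean_poisson_deviance

rng = np.random.default_rng(seed=123)
y = 1 + rng.poisson(10, size=100)

def fx(c):
    # constant raw score c with a log link => constant prediction exp(c)
    yhat = np.exp(np.repeat(c, y.shape[0]))
    return mean_poisson_deviance(y, yhat)

res = minimize(fx, x0=0.)["x"][0]
print(res, np.log(y.mean()))  # both ~ 2.37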

@trivialfis
Member Author

Let me take another look tomorrow, my brain is fried today.

@trivialfis
Member Author

trivialfis commented Dec 15, 2023

But if that were the actual intercept, even if off by some amount, shouldn't the predictions be equal to base_score, since the model doesn't have anything further adding to it?

There's still an extra tree stump added after the base_score, since at least one iteration is fitted. Hopefully the script below helps illustrate:

import json

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(seed=123)
X = rng.standard_normal(size=(100, 10))
y = 1 + rng.poisson(10, size=100)
Xy = xgb.DMatrix(X, label=y)

model = xgb.train(
    dtrain=Xy,
    params={"objective": "count:poisson", "min_split_loss": 1e5},  # no split
    num_boost_round=1,  # but still has one tree stump
    evals=[(Xy, "train")],
)

config = json.loads(model.save_config())["learner"]["learner_model_param"]
intercept = float(config["base_score"])
print("intercept:", intercept)

jmodel = json.loads(str(model.save_raw(raw_format="json"), "ascii"))
jtrees = jmodel["learner"]["gradient_booster"]["model"]["trees"]
assert len(jtrees) == 1
# split_condition actually means weight for leaf, and split threshold for internal split
# nodes
assert len(jtrees[0]["split_conditions"]) == 1
leaf_weight = jtrees[0]["split_conditions"][0]
print(leaf_weight)  # raw prediction

output_margin = np.log(intercept) + leaf_weight
predt_0 = np.exp(output_margin)

print(np.log(intercept))
print(predt_0)

predt_1 = model.predict(xgb.DMatrix(X))
print(predt_1[0])
np.testing.assert_allclose(predt_0, predt_1)

Also, even if one applies a single Newton step starting from zero, the result does not quite match the earlier base_score

That's caused by an old parameter that controls the step size, which I want to remove but did not get approval for: #7267

@RAMitchell would be great if we could revisit that PR.

By the way, if base_score is obtained by applying the inverse of the link function to the raw intercept, then the optimal base_score should be the mean of y whenever that link function is what GLM theory calls a "canonical" link function for the given likelihood.

Thank you for sharing! I will learn more about it. On the other hand, I think the canonical link for gamma is the inverse link, even though the log link is often used in practice. Also, unit deviance is used to derive the objective, instead of the log-likelihood. Not sure if the mean value still applies.

@david-cortes
Contributor

Thanks for explaining the situation here.

It still leaves me wondering about one last detail though: in the case above, is it that the tree had no split conditions and it created a terminal node to conform to a tree structure, or is it that trees can have their own intercept regardless of whether they have subnodes or not?


About the gamma objective: you are right, my mistake - it's not using the canonical link.

That being said, I am fairly certain that the optimal intercept is still equal to the logarithm of the mean.

I haven't checked any proofs but here's a small example to support the claim:

import numpy as np
from sklearn.metrics import mean_gamma_deviance
from scipy.optimize import minimize

rng = np.random.default_rng(seed=123)
y = rng.standard_gamma(5, size=100)

def fx(c):
    yhat = np.exp(np.repeat(c, y.shape[0]))
    return mean_gamma_deviance(y, yhat)

x0 = 0.
res = minimize(fx, x0)["x"]
np.testing.assert_almost_equal(res, np.log(y.mean()))

@trivialfis
Member Author

no split conditions and it created a terminal node to conform to a tree structure

Yes, a root node as a leaf.
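
For instance, dumping the single tree from the count:poisson script above makes this visible (a small sketch; trees_to_dataframe requires pandas):

# "model" is the Booster trained above with min_split_loss=1e5
df = model.trees_to_dataframe()
print(df)
# With no splits, the dump contains a single row whose Feature column is "Leaf":
# the root node itself is the only leaf of the tree.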

That being said, I am fairly certain that the optimal intercept is still equal to the logarithm of the mean.

I think you are right, I will look into it.

@trivialfis
Member Author

Opened an issue to keep track of the status: #9899

@trivialfis
Member Author

Is there anything I need to modify from the documentation's perspective?

@david-cortes
Contributor

Is there anything I need to modify from the documentation's perspective?

I cannot think of anything that would be missing after the latest changes. Thanks for this addition. It would be ideal to merge now.

@trivialfis trivialfis merged commit 0edd600 into dmlc:master Dec 17, 2023
24 of 28 checks passed
@trivialfis trivialfis deleted the doc-link branch December 17, 2023 05:34
Successfully merging this pull request may close these issues.

[R] Predictions with base margin are off by constant number