
[R] Best iteration index from early stopping is discarded when model is saved to disk #5209

Closed
DavorJ opened this issue Jan 15, 2020 · 33 comments · Fixed by #5573

@DavorJ commented Jan 15, 2020

These values are predicted after xgboost::xgb.train:
247367.2 258693.3 149572.2 201675.8 250493.9 292349.2 414828.0 296503.2 260851.9 190413.3

These values are predicted after xgboost::xgb.save and xgboost::xgb.load of the previous model:
247508.8 258658.2 149252.1 201692.6 250458.1 292313.4 414787.2 296462.5 260879.0 190430.1

They are close, but not the same. The differences between the two sets of predictions range from -1317.094 to 1088.859 on a set of 25k samples. When compared against the true labels, the MAE/RMSE of the two predictions do not differ much.

So I suspect this has to do with rounding errors during load/save, since the MAE/RMSE are barely affected. Still, I find this strange, because storing the model in binary form should not introduce rounding errors.

Does anyone have a clue?

PS: Uploading and documenting the training process does not seem important here. I can provide details if necessary, or build a simulation with dummy data to demonstrate the point.

@trivialfis (Member):

There shouldn't be any rounding error for either binary or JSON. Are you using dart?

@DavorJ (Author) commented Jan 16, 2020

No, I am not:

params <- list(objective = 'reg:squarederror',
               max_depth = 10, eta = 0.02, subsample = 0.5,
               base_score = median(xgboost::getinfo(xgb.train, 'label'))
)

xgboost::xgb.train(
  params = params, data = xgb.train,
  watchlist = list('train' = xgb.train, 'test' = xgb.test),
  nrounds = 10000, verbose = TRUE, print_every_n = 25,
  eval_metric = 'mae',
  early_stopping_rounds = 3, maximize = FALSE)

@hcho3 (Collaborator) commented Jan 16, 2020

Can you provide us with dummy data where this phenomenon occurs?

@DavorJ (Author) commented Jan 16, 2020

Here you go (Quick & Dirty):

N <- 100000
set.seed(2020)
X <- data.frame('X1' = rnorm(N), 'X2' = runif(N), 'X3' = rpois(N, lambda = 1))
Y <- with(X, X1 + X2 - X3 + X1*X2^2 - ifelse(X1 > 0, 2, X3))

params <- list(objective = 'reg:squarederror',
               max_depth = 5, eta = 0.02, subsample = 0.5,
               base_score = median(Y)
)

dtrain <- xgboost::xgb.DMatrix(data = data.matrix(X), label = Y)

fit <- xgboost::xgb.train(
  params = params, data = dtrain,
  watchlist = list('train' = dtrain),
  nrounds = 10000, verbose = TRUE, print_every_n = 25,
  eval_metric = 'mae',
  early_stopping_rounds = 3, maximize = FALSE
)

pred <- stats::predict(fit, newdata = dtrain)

xgboost::xgb.save(fit, 'booster.raw')
fit.loaded <- xgboost::xgb.load('booster.raw')

pred.loaded <- stats::predict(fit.loaded, newdata = dtrain)

identical(pred, pred.loaded)
pred[1:10]
pred.loaded[1:10]

sqrt(mean((Y - pred)^2))
sqrt(mean((Y - pred.loaded)^2))

On my machine, identical(pred, pred.loaded) is FALSE (it should be TRUE). Here is the output of the last commands:

> identical(pred, pred.loaded)
[1] FALSE
> pred[1:10]
 [1] -4.7971768 -2.5070562 -0.8889422 -4.9199696 -4.4374819 -0.2739395 -0.9825708  0.4579227  1.3667605 -4.3333349
> pred.loaded[1:10]
 [1] -4.7971768 -2.5070562 -0.8889424 -4.9199696 -4.4373770 -0.2739397 -0.9825710  0.4579227  1.3667605 -4.3333349
> 
> sqrt(mean((Y - pred)^2))
[1] 0.02890702
> sqrt(mean((Y - pred.loaded)^2))
[1] 0.02890565

You see that the predictions sometimes differ slightly. Can you rerun the example code on your machine and see whether it has the same problem?

Some extra info about R and xgboost:

> sessionInfo()
R version 3.6.1 (2019-07-05)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows >= 8 x64 (build 9200)

Matrix products: default

locale:
[1] LC_COLLATE=English_United States.1252  LC_CTYPE=English_United States.1252    LC_MONETARY=English_United States.1252 LC_NUMERIC=C                          
[5] LC_TIME=English_United States.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

loaded via a namespace (and not attached):
 [1] compiler_3.6.1    magrittr_1.5      Matrix_1.2-17     tools_3.6.1       yaml_2.2.0        xgboost_0.90.0.2  stringi_1.4.3     grid_3.6.1       
 [9] data.table_1.12.4 lattice_0.20-38 

Also note that:

> identical(fit$raw, fit.loaded$raw)
[1] TRUE

@trivialfis (Member) commented Jan 16, 2020

Thanks for the script. As an update, I ran it saving the model to both JSON and binary, e.g. for JSON:

xgboost::xgb.save(fit, 'booster.json')
fit.loaded <- xgboost::xgb.load('booster.json')

xgboost::xgb.save(fit.loaded, 'booster-1.json')

The hash values (via sha256sum ./booster.json) of booster.json and booster-1.json are exactly the same, so my guess is that there's a discrepancy somewhere caused by floating-point arithmetic.
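
A similar round-trip check can be done entirely inside R (a sketch, assuming xgboost::xgb.save.raw is available in the installed version): re-serialize both boosters and compare the raw bytes.

# Re-serialize the original and the reloaded booster and compare the bytes;
# this should agree with the sha256sum check on the saved files.
raw.orig   <- xgboost::xgb.save.raw(fit)
raw.loaded <- xgboost::xgb.save.raw(fit.loaded)
identical(raw.orig, raw.loaded)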

@DavorJ (Author) commented Feb 5, 2020

Why close the issue without knowing the cause?

@hcho3 reopened this Feb 5, 2020
@hcho3 (Collaborator) commented Feb 5, 2020

@trivialfis Did you get TRUE for identical(pred, pred.loaded)? The OP is asking why the predictions don't match, even though the two models have the same binary signature.

@hcho3 self-assigned this Feb 5, 2020
@hcho3 (Collaborator) commented Feb 5, 2020

I’ll try to reproduce it myself.

@trivialfis (Member) commented Feb 5, 2020

Oh, sorry. The cause I found is the prediction cache. After loading the model, prediction values come from an actual prediction pass instead of the cached values:

so my guess is that there's a discrepancy somewhere caused by floating-point arithmetic.

@hcho3 (Collaborator) commented Feb 5, 2020

So prediction cache interacts with floating-point arithmetic in a destructive way?

@trivialfis (Member) commented Feb 5, 2020

@hcho3 It's a problem I found while implementing the new pickling method. I believe it plays a major role here. First, reduce the number of trees to 1000 (which is still pretty large and should be enough for a demo).

Reconstruct the DMatrix before prediction to get the cache out of the way:

dtrain_2 <- xgboost::xgb.DMatrix(data = data.matrix(X), label = Y)

pred <- stats::predict(fit, newdata = dtrain_2)

With this, the identical test passes; otherwise it fails.

With more trees, the identical test still shows small differences (1e-7 for 2000 trees). But do we need to produce bit-by-bit identical results even in a multi-threaded environment?

@trivialfis (Member) commented Feb 5, 2020

Since floating-point summation is not associative, we could make it a to-do item to provide a strong guarantee on the order of computation, if that's desired.

@trivialfis (Member):

Actually, making a strong guarantee on ordering won't work (it will help a lot, but there will still be discrepancies). A floating-point value in a CPU FPU register can have higher precision than when it is stored back to memory (hardware implementations can use higher precision for intermediate values, see https://en.wikipedia.org/wiki/Extended_precision). My point is that when the result for 1000 trees is exactly reproducible within 32-bit float, it's unlikely to be a programming bug.
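
To illustrate the non-associativity point with a toy R example (not from the thread, just a quick sanity check):

# Floating-point addition is not associative: changing the order of
# summation changes the last bits of the result.
a <- 0.1; b <- 0.2; c <- 0.3
(a + b) + c == a + (b + c)      # FALSE with IEEE-754 doubles
(a + b) + c - (a + (b + c))     # tiny but non-zero difference (~1e-16)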

@hcho3 (Collaborator) commented Feb 5, 2020

I agree that floating-point summation is not associative. I will run the script myself and see if the difference is small enough to attribute to floating-point arithmetic.

In general, I usually use np.testing.assert_almost_equal with decimal=5 to test whether two float arrays are almost equal to each other.
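
A rough R analogue for the reproductions in this thread (a sketch; note that all.equal() uses a relative tolerance rather than a number of decimal places):

# Compare the two prediction vectors up to a small relative tolerance
# instead of requiring bit-wise identity.
isTRUE(all.equal(pred, pred.loaded, tolerance = 1e-5))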

@trivialfis (Member):

Yup. Apologies for closing without detailed notes.

@trivialfis (Member):

@hcho3 Any update?

@hcho3 (Collaborator) commented Apr 15, 2020

I haven't gotten around to it yet. Let me take a look this week.

@hcho3 (Collaborator) commented Apr 20, 2020

@trivialfis I managed to reproduce the bug. I ran the provided script and got FALSE for identical(pred, pred.loaded). I tried creating a new DMatrix dtrain_2 as you suggested and still got FALSE for the test.

Output from @DavorJ's script:

[1] FALSE     # identical(pred, pred.loaded)
 [1] -4.7760534 -2.5083885 -0.8860036 -4.9163256 -4.4455137 -0.2548684
 [7] -0.9745615  0.4646015  1.3602829 -4.3288369     # pred[1:10]
 [1] -4.7760534 -2.5083888 -0.8860038 -4.9163256 -4.4454765 -0.2548686
 [7] -0.9745617  0.4646015  1.3602829 -4.3288369     # pred.loaded[1:10]
[1] 0.02456085   # RMSE on pred
[1] 0.02455945   # RMSE on pred.loaded

Output from modified script, with dtrain_2 <- xgboost::xgb.DMatrix(data = data.matrix(X), label = Y):

[1] FALSE     # identical(pred, pred.loaded)
 [1] -4.7760534 -2.5083885 -0.8860036 -4.9163256 -4.4455137 -0.2548684
 [7] -0.9745615  0.4646015  1.3602829 -4.3288369     # pred[1:10]
 [1] -4.7760534 -2.5083888 -0.8860038 -4.9163256 -4.4454765 -0.2548686
 [7] -0.9745617  0.4646015  1.3602829 -4.3288369     # pred.loaded[1:10]
[1] 0.02456085   # RMSE on pred
[1] 0.02455945   # RMSE on pred.loaded

So something else must be going on.

I also tried running a round-trip test:

xgboost::xgb.save(fit, 'booster.raw')
fit.loaded <- xgboost::xgb.load('booster.raw')
xgboost::xgb.save(fit.loaded, 'booster.raw.roundtrip')

and the two binary files booster.raw and booster.raw.roundtrip were identical.

@hcho3 (Collaborator) commented Apr 20, 2020

Max diff between pred and pred.loaded is 0.0008370876.

@hcho3 (Collaborator) commented Apr 20, 2020

A smaller example that runs faster:

library(xgboost)

N <- 5000
set.seed(2020)
X <- data.frame('X1' = rnorm(N), 'X2' = runif(N), 'X3' = rpois(N, lambda = 1))
Y <- with(X, X1 + X2 - X3 + X1*X2^2 - ifelse(X1 > 0, 2, X3))

params <- list(objective = 'reg:squarederror',
               max_depth = 5, eta = 0.02, subsample = 0.5,
               base_score = median(Y)
)

dtrain <- xgboost::xgb.DMatrix(data = data.matrix(X), label = Y)

fit <- xgboost::xgb.train(
  params = params, data = dtrain,
  watchlist = list('train' = dtrain),
  nrounds = 10000, verbose = TRUE, print_every_n = 25,
  eval_metric = 'mae',
  early_stopping_rounds = 3, maximize = FALSE
)

pred <- stats::predict(fit, newdata = dtrain)

invisible(xgboost::xgb.save(fit, 'booster.raw'))
fit.loaded <- xgboost::xgb.load('booster.raw')
invisible(xgboost::xgb.save(fit.loaded, 'booster.raw.roundtrip'))

pred.loaded <- stats::predict(fit.loaded, newdata = dtrain)

identical(pred, pred.loaded)
pred[1:10]
pred.loaded[1:10]
max(abs(pred - pred.loaded))

sqrt(mean((Y - pred)^2))
sqrt(mean((Y - pred.loaded)^2))

Output:

[1] FALSE
 [1] -2.4875379 -0.9452241 -6.9658904 -2.9985323 -4.2192593 -0.8505422
 [7] -0.3928839 -1.6886091 -1.3611379 -3.1278882
 [1] -2.4875379 -0.9452239 -6.9658904 -2.9985323 -4.2192593 -0.8505420
 [7] -0.3928837 -1.6886090 -1.3611377 -3.1278882
[1] 0.0001592636
[1] 0.01370754
[1] 0.01370706

@hcho3 (Collaborator) commented Apr 20, 2020

Just tried doing one extra round-trip, and now predictions do not change any more.

library(xgboost)

N <- 5000
set.seed(2020)
X <- data.frame('X1' = rnorm(N), 'X2' = runif(N), 'X3' = rpois(N, lambda = 1))
Y <- with(X, X1 + X2 - X3 + X1*X2^2 - ifelse(X1 > 0, 2, X3))

params <- list(objective = 'reg:squarederror',
               max_depth = 5, eta = 0.02, subsample = 0.5,
               base_score = median(Y)
)

dtrain <- xgboost::xgb.DMatrix(data = data.matrix(X), label = Y)

fit <- xgboost::xgb.train(
  params = params, data = dtrain,
  watchlist = list('train' = dtrain),
  nrounds = 10000, verbose = TRUE, print_every_n = 25,
  eval_metric = 'mae',
  early_stopping_rounds = 3, maximize = FALSE
)

pred <- stats::predict(fit, newdata = dtrain)

invisible(xgboost::xgb.save(fit, 'booster.raw'))
fit.loaded <- xgboost::xgb.load('booster.raw')
invisible(xgboost::xgb.save(fit.loaded, 'booster.raw.roundtrip'))
fit.loaded2 <- xgboost::xgb.load('booster.raw.roundtrip')

pred.loaded <- stats::predict(fit.loaded, newdata = dtrain)
pred.loaded2 <- stats::predict(fit.loaded2, newdata = dtrain)

identical(pred, pred.loaded)
identical(pred.loaded, pred.loaded2)
pred[1:10]
pred.loaded[1:10]
pred.loaded2[1:10]
max(abs(pred - pred.loaded))
max(abs(pred.loaded - pred.loaded2))

sqrt(mean((Y - pred)^2))
sqrt(mean((Y - pred.loaded)^2))
sqrt(mean((Y - pred.loaded2)^2))

Result:

[1] FALSE
[1] TRUE
 [1] -2.4875379 -0.9452241 -6.9658904 -2.9985323 -4.2192593 -0.8505422
 [7] -0.3928839 -1.6886091 -1.3611379 -3.1278882
 [1] -2.4875379 -0.9452239 -6.9658904 -2.9985323 -4.2192593 -0.8505420
 [7] -0.3928837 -1.6886090 -1.3611377 -3.1278882
 [1] -2.4875379 -0.9452239 -6.9658904 -2.9985323 -4.2192593 -0.8505420
 [7] -0.3928837 -1.6886090 -1.3611377 -3.1278882
[1] 0.0001592636
[1] 0
[1] 0.01370754
[1] 0.01370706
[1] 0.01370706

So maybe prediction cache is indeed a problem.

@hcho3 (Collaborator) commented Apr 20, 2020

I re-ran the script with prediction caching disabled:

diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc
index ebc15128..c40309bc 100644
--- a/src/predictor/cpu_predictor.cc
+++ b/src/predictor/cpu_predictor.cc
@@ -259,7 +259,7 @@ class CPUPredictor : public Predictor {
     // delta means {size of forest} * {number of newly accumulated layers}
     uint32_t delta = end_version - beg_version;
     CHECK_LE(delta, model.trees.size());
-    predts->Update(delta);
+    //predts->Update(delta);

     CHECK(out_preds->Size() == output_groups * dmat->Info().num_row_ ||
           out_preds->Size() == dmat->Info().num_row_);

(Disabling prediction caching results in very slow training.)

Output:

[1] FALSE
[1] TRUE
 [1] -2.4908853 -0.9507379 -6.9615889 -2.9935317 -4.2165089 -0.8543566
 [7] -0.3940181 -1.6930715 -1.3572118 -3.1403396
 [1] -2.4908853 -0.9507380 -6.9615889 -2.9935317 -4.2165089 -0.8543567
 [7] -0.3940183 -1.6930716 -1.3572119 -3.1403399
 [1] -2.4908853 -0.9507380 -6.9615889 -2.9935317 -4.2165089 -0.8543567
 [7] -0.3940183 -1.6930716 -1.3572119 -3.1403399
[1] 0.0001471043
[1] 0
[1] 0.01284297
[1] 0.01284252
[1] 0.01284252

So prediction cache is definitely NOT the cause of this bug.

@hcho3 (Collaborator) commented Apr 20, 2020

Leaf predictions diverge too:

invisible(xgboost::xgb.save(fit, 'booster.raw'))
fit.loaded <- xgboost::xgb.load('booster.raw')
invisible(xgboost::xgb.save(fit.loaded, 'booster.raw.roundtrip'))
fit.loaded2 <- xgboost::xgb.load('booster.raw.roundtrip')

x <- predict(fit, newdata = dtrain2, predleaf = TRUE)
x2 <- predict(fit.loaded, newdata = dtrain2, predleaf = TRUE)
x3 <- predict(fit.loaded2, newdata = dtrain2, predleaf = TRUE)

identical(x, x2)
identical(x2, x3)

Output:

[1] FALSE
[1] TRUE

@hcho3 (Collaborator) commented Apr 20, 2020

Mystery solved. I identified the true cause. When the model is saved to disk, information about early stopping is discarded. In the example, XGBoost runs 6381 boosting rounds and finds the best model at 6378 rounds. The model object in memory contains 6381 trees, not 6378 trees, since no tree is removed. There is an extra field best_iteration that remembers which iteration was the best:

> fit$best_iteration
[1] 6378

This extra field is silently discarded when we save the model to disk. So predict() with the original model uses 6378 trees, whereas predict() with the recovered model uses 6381 trees.

> x <- predict(fit, newdata = dtrain2, predleaf = TRUE)
> x2 <- predict(fit.loaded, newdata = dtrain2, predleaf = TRUE)
> dim(x)
[1] 5000 6378
> dim(x2)
[1] 5000 6381
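
Given this diagnosis, a possible workaround until a proper fix lands (a sketch, assuming fit$best_iteration is available from the early-stopping callback, as shown above) is to cap the number of trees explicitly when predicting with the reloaded model:

# Workaround sketch: the reloaded booster has no memory of early stopping,
# so limit prediction to the best iteration explicitly.
pred.loaded.capped <- stats::predict(fit.loaded, newdata = dtrain,
                                     ntreelimit = fit$best_iteration)
# With the same number of trees on both sides, the predictions are
# expected to match the original model's output.
max(abs(pred - pred.loaded.capped))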

@hcho3 changed the title from "Information loss after xgboost::xgb.save and xgboost::xgb.load" to "[R] Best iteration index from early stopping is discarded when model is saved to disk" on Apr 20, 2020
@hcho3 mentioned this issue Apr 20, 2020
@hcho3 (Collaborator) commented Apr 20, 2020

@trivialfis I am inclined to physically remove trees. If training stopped at 6381 rounds and the best iteration was at 6378 rounds, users will expect the final model to have 6378 trees.

@trivialfis (Member):

@hcho3 I think it's a similar issue to #4052.

The best_iteration should be saved into Learner::attributes_, which can be accessed through xgboost::xgb.attr.
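
As an interim workaround on the R side (a sketch, assuming booster attributes set via xgboost::xgb.attr are serialized together with the model; the attribute name 'best_iteration' here is just an illustration), the best iteration could be stashed as an attribute before saving and read back after loading:

# Sketch: persist the best iteration as a booster attribute so it survives
# xgb.save / xgb.load, then use it to cap prediction on the reloaded model.
xgboost::xgb.attr(fit, 'best_iteration') <- fit$best_iteration
xgboost::xgb.save(fit, 'booster.raw')

fit.loaded <- xgboost::xgb.load('booster.raw')
best <- as.integer(xgboost::xgb.attr(fit.loaded, 'best_iteration'))
pred.loaded <- stats::predict(fit.loaded, newdata = dtrain, ntreelimit = best)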

@DavorJ (Author) commented Apr 20, 2020

@hcho3, nice find!

Note also the documentation of xgboost:::predict.xgb.Booster():

[screenshot of the predict.xgb.Booster documentation]

If I understand correctly, the documentation is not entirely correct? Based on the documentation, I was expecting that prediction already used all the trees. Unfortunately, I had not verified this.

@hcho3 (Collaborator) commented Apr 20, 2020

@DavorJ When early stopping is activated, predict() will use the best_iteration field to obtain the prediction.

@hcho3 (Collaborator) commented Apr 20, 2020

@trivialfis The situation is worse on the Python side, as bst.predict() will not use information from early stopping at all:

import xgboost as xgb
import numpy as np
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split

X, y = load_boston(return_X_y=True)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)

params = {'objective': 'reg:squarederror'}

bst = xgb.train(params, dtrain, 100, [(dtrain, 'train'), (dtest, 'test')],
                early_stopping_rounds=5)

x = bst.predict(dtrain, pred_leaf=True)
x2 = bst.predict(dtrain, pred_leaf=True, ntree_limit=bst.best_iteration)
print(x.shape)
print(x2.shape)

pred = bst.predict(dtrain)
pred2 = bst.predict(dtrain, ntree_limit=bst.best_iteration)

print(np.max(np.abs(pred - pred2)))

Output:

Will train until test-rmse hasn't improved in 5 rounds.
[1]     train-rmse:12.50316     test-rmse:11.92709
...
[25]    train-rmse:0.56720      test-rmse:2.56874
[26]    train-rmse:0.54151      test-rmse:2.56722
[27]    train-rmse:0.51842      test-rmse:2.56124
[28]    train-rmse:0.47489      test-rmse:2.56640
[29]    train-rmse:0.45489      test-rmse:2.58780
[30]    train-rmse:0.43093      test-rmse:2.59385
[31]    train-rmse:0.41865      test-rmse:2.59364
[32]    train-rmse:0.40823      test-rmse:2.59465
Stopping. Best iteration:
[27]    train-rmse:0.51842      test-rmse:2.56124
(404, 33)
(404, 27)
0.81269073

Users will have to remember to fetch bst.best_iteration and pass it as the ntree_limit argument when calling predict(). This is error-prone and makes for an unpleasant surprise.

We have two options for a fix:

  1. Physically delete trees that are past best_iteration.
  2. Retain best_iteration information when serializing model, and have predict() function use it.

@trivialfis (Member) commented Apr 20, 2020

@hcho3 I have a half-baked idea about this, which is also related to our process_type = update option and to forests.

Background

As a short recap of the issues we have with update: if the num_boost_round used with update is smaller than the number of already existing trees, the trees that are not updated will be removed.

As a short introduction to the issues with forests: best_iteration doesn't apply to a forest, since the predict function requires a specific number of trees rather than an iteration, so on the Python side there's something called best_ntree_limit, which is very confusing to me. I explicitly replaced ntree_limit in inplace_predict with iteration_range to avoid this attribute.

Idea

I want to add slice and concat methods to the booster, which extract trees into two models and concatenate the trees from two models into one. If we have these two methods:

  • base_margin_ is no longer needed, and I believe it's more intuitive for other users.
  • ntree_limit in prediction is no longer needed; we just slice the model and run prediction on the slice.
  • the update process is self-contained: just update the trees in a slice in one go, with no num_boost_round.

Further

I also believe this is somehow connected to multi-target trees. If we can support multi-class multi-target trees in the future, there will be multiple ways to arrange trees, like using output_groups for each class or for each target, paired with forests and vector leaves. ntree_limit is not going to be sufficient.

See also #5531.

@trivialfis (Member) commented Apr 20, 2020

But the idea is at quite an early stage, so I didn't have the confidence to share it. Now that we are on this issue, maybe I can get some input on it.

@JohnZed (Contributor) commented Apr 21, 2020

Given the 1.1 timeline, can we expand the documentation to clarify how users need to manually capture and use this best iteration in prediction?
And add it to known issues in the release notes?

@RAMitchell (Member):

@trivialfis sounds interesting, so long as we are not further complicating configuration issues by doing this.

Deleting the extra trees from the model, as suggested by @hcho3, is appealing, as we don't have to deal with any inconsistency between an actual model length and a theoretical model length at the same time.
