-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[train] RayTrainReportCallback
should only save a checkpoint on rank 0 for xgboost/lightgbm
#45083
[train] RayTrainReportCallback
should only save a checkpoint on rank 0 for xgboost/lightgbm
#45083
Conversation
Signed-off-by: Justin Yu <[email protected]>
Signed-off-by: Justin Yu <[email protected]>
Signed-off-by: Justin Yu <[email protected]>
…lgb_rank0_only_ckpt
Signed-off-by: Justin Yu <[email protected]>
RayTrainReportCallback
should only save a checkpoint on rank 0 for xgboost/lightgbm
callback = RayTrainReportCallback(frequency=2, checkpoint_at_end=True) | ||
|
||
booster = mock.MagicMock() | ||
|
||
with mock.patch("ray.train.report") as mock_report, mock.patch( | ||
"ray.train.get_context" | ||
) as mock_get_context: | ||
mock_context = mock.MagicMock() | ||
mock_context.get_world_rank.return_value = rank | ||
mock_get_context.return_value = mock_context | ||
|
||
booster.num_boosted_rounds.return_value = 2 | ||
callback.after_iteration(booster, epoch=1, evals_log={}) | ||
|
||
# Only rank 0 should report based on `frequency` | ||
reported_checkpoint = bool(mock_report.call_args.kwargs.get("checkpoint")) | ||
if rank == 0: | ||
assert reported_checkpoint | ||
else: | ||
assert not reported_checkpoint | ||
|
||
booster.num_boosted_rounds.return_value = 3 | ||
callback.after_iteration(booster, epoch=2, evals_log={}) | ||
# Nobody should report a checkpoint on iterations | ||
reported_checkpoint = bool(mock_report.call_args.kwargs.get("checkpoint")) | ||
assert not reported_checkpoint |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this test unnecessarily complicated? Any better ways to do this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just went through some MagicMock docs. I think this unit-testing function looks good too me :)
For simplicity, probably we can just test the class variable RayTrainReportCallback._last_checkpoint_iteration
after the function callback.after_iteration()
is called. In this way, we can reduce the mock of ray.train.report
. Then the code from line 154 ~ 169 could be simplified as:
with mock.patch("ray.train.get_context") as mock_get_context:
mock_context = mock.MagicMock()
mock_context.get_world_rank.return_value = rank
mock_get_context.return_value = mock_context
booster.num_boosted_rounds.return_value = 2
callback.after_iteration(booster, epoch=1, evals_log={})
assertEqual(callback._last_checkpoint_iteration, 1) # same as epoch number
However, this test assumes ray.train.report
and RayTrainReportCallback._get_checkpoint()
always works. It's also a little indirect. Correct me if I am wrong :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're right, the test was a little too indirect and I was able to find a similar way to simplify it. PTAL
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice simplification! Learnt a lot from the code 👍
callback = RayTrainReportCallback(frequency=2, checkpoint_at_end=True) | ||
|
||
booster = mock.MagicMock() | ||
|
||
with mock.patch("ray.train.report") as mock_report, mock.patch( | ||
"ray.train.get_context" | ||
) as mock_get_context: | ||
mock_context = mock.MagicMock() | ||
mock_context.get_world_rank.return_value = rank | ||
mock_get_context.return_value = mock_context | ||
|
||
booster.num_boosted_rounds.return_value = 2 | ||
callback.after_iteration(booster, epoch=1, evals_log={}) | ||
|
||
# Only rank 0 should report based on `frequency` | ||
reported_checkpoint = bool(mock_report.call_args.kwargs.get("checkpoint")) | ||
if rank == 0: | ||
assert reported_checkpoint | ||
else: | ||
assert not reported_checkpoint | ||
|
||
booster.num_boosted_rounds.return_value = 3 | ||
callback.after_iteration(booster, epoch=2, evals_log={}) | ||
# Nobody should report a checkpoint on iterations | ||
reported_checkpoint = bool(mock_report.call_args.kwargs.get("checkpoint")) | ||
assert not reported_checkpoint |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just went through some MagicMock docs. I think this unit-testing function looks good too me :)
For simplicity, probably we can just test the class variable RayTrainReportCallback._last_checkpoint_iteration
after the function callback.after_iteration()
is called. In this way, we can reduce the mock of ray.train.report
. Then the code from line 154 ~ 169 could be simplified as:
with mock.patch("ray.train.get_context") as mock_get_context:
mock_context = mock.MagicMock()
mock_context.get_world_rank.return_value = rank
mock_get_context.return_value = mock_context
booster.num_boosted_rounds.return_value = 2
callback.after_iteration(booster, epoch=1, evals_log={})
assertEqual(callback._last_checkpoint_iteration, 1) # same as epoch number
However, this test assumes ray.train.report
and RayTrainReportCallback._get_checkpoint()
always works. It's also a little indirect. Correct me if I am wrong :)
…lgb_rank0_only_ckpt
Signed-off-by: Justin Yu <[email protected]>
Signed-off-by: Justin Yu <[email protected]>
Signed-off-by: Justin Yu <[email protected]>
Signed-off-by: Justin Yu <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
…k 0 for xgboost/lightgbm (ray-project#45083) This PR adds a condition to only save and report a checkpoint on the rank 0 worker for xgboost and lightgbm. This prevents unnecessary checkpoints being created, since all data parallel workers have the same model states. Note: this also accounts for usage in Tune, where `ray.train.get_context().get_world_rank()` returns `None`. Fix `checkpoint_at_end` for the xgboost callback to avoid duplicate checkpoints. --------- Signed-off-by: Justin Yu <[email protected]> Co-authored-by: Hongpeng Guo <[email protected]>
Why are these changes needed?
This PR adds a condition to only save and report a checkpoint on the rank 0 worker for xgboost and lightgbm. This prevents unnecessary checkpoints being created, since all data parallel workers have the same model states. Note: this also accounts for usage in Tune, where
ray.train.get_context().get_world_rank()
returnsNone
.This also includes a drive-by fix for
checkpoint_at_end
in the xgboost callback. Now, we no longer do a separatecheckpoint_at_end
if the checkpoint frequency happens to line up with the last iteration. For example, if saving every 5 iterations:[iter] 0 1 2 3 4 (checkpoint) 5 6 7 8 9 (checkpoint) (checkpoint)
, we no longer have this "duplicate" checkpoint at the end after this fix.Related issue number
Reported on slack
Checks
git commit -s
) in this PR.scripts/format.sh
to lint the changes in this PR.method in Tune, I've added it in
doc/source/tune/api/
under thecorresponding
.rst
file.