[train] RayTrainReportCallback should only save a checkpoint on rank 0 for xgboost/lightgbm #45083

Merged: 11 commits into ray-project:master from xgb_lgb_rank0_only_ckpt on May 9, 2024

Conversation

Contributor

@justinvyu justinvyu commented May 1, 2024

Why are these changes needed?

This PR adds a condition so that only the rank 0 worker saves and reports a checkpoint for xgboost and lightgbm. This avoids creating redundant checkpoints, since all data-parallel workers hold the same model state. Note: this also accounts for usage in Tune, where ray.train.get_context().get_world_rank() returns None.
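A minimal sketch of the rank check described above. The helper name `should_save_checkpoint` is hypothetical and not part of Ray's API; it only illustrates treating rank 0 and a `None` rank (the Tune case) the same way.

```python
from typing import Optional


def should_save_checkpoint(world_rank: Optional[int]) -> bool:
    """Hypothetical helper: decide whether this process saves a checkpoint."""
    # Rank 0 saves on behalf of all data-parallel workers.
    # A rank of None (as described for Tune usage) means there is no
    # worker group, so the single process saves its own checkpoint.
    return world_rank in (0, None)


# Decisions for a 3-worker group plus the Tune-style None rank.
ranks = [0, 1, 2, None]
decisions = {rank: should_save_checkpoint(rank) for rank in ranks}
```

Here `decisions` maps ranks 0 and `None` to `True` and every other rank to `False`.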

This also includes a drive-by fix for checkpoint_at_end in the xgboost callback: we no longer save a separate end-of-training checkpoint if the checkpoint frequency already lines up with the last iteration. For example, when saving every 5 iterations, the old behavior was: [iter] 0 1 2 3 4 (checkpoint) 5 6 7 8 9 (checkpoint) (checkpoint). After this fix, the "duplicate" checkpoint at the end is gone.
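The checkpoint_at_end behavior can be sketched as a small simulation. This is an illustration under assumptions, not the callback's actual code: the function `checkpoint_iterations` is hypothetical, and it assumes the fix works by remembering the last checkpointed iteration.

```python
from typing import List, Optional


def checkpoint_iterations(
    num_iters: int, frequency: int, checkpoint_at_end: bool
) -> List[int]:
    """Return the 0-based iterations at which a checkpoint would be saved."""
    saved: List[int] = []
    last_checkpoint_iteration: Optional[int] = None

    for it in range(num_iters):
        # Periodic checkpoint every `frequency` iterations (0 disables it).
        if frequency > 0 and (it + 1) % frequency == 0:
            saved.append(it)
            last_checkpoint_iteration = it

    # End-of-training checkpoint, skipped if the final iteration
    # was already checkpointed by the periodic schedule.
    final_iter = num_iters - 1
    if checkpoint_at_end and num_iters > 0 and last_checkpoint_iteration != final_iter:
        saved.append(final_iter)

    return saved
```

With 10 iterations and a frequency of 5, this yields checkpoints at iterations 4 and 9 only. With 12 iterations, the end-of-training checkpoint at iteration 11 is still added because the schedule does not cover it.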

Related issue number

Reported on Slack

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@justinvyu justinvyu changed the title from "[train] `RayTrain" to "[train] RayTrainReportCallback should only save a checkpoint on rank 0 for xgboost/lightgbm" May 1, 2024
Comment on lines 150 to 175
callback = RayTrainReportCallback(frequency=2, checkpoint_at_end=True)

booster = mock.MagicMock()

with mock.patch("ray.train.report") as mock_report, mock.patch(
    "ray.train.get_context"
) as mock_get_context:
    mock_context = mock.MagicMock()
    mock_context.get_world_rank.return_value = rank
    mock_get_context.return_value = mock_context

    booster.num_boosted_rounds.return_value = 2
    callback.after_iteration(booster, epoch=1, evals_log={})

    # Only rank 0 should report based on `frequency`
    reported_checkpoint = bool(mock_report.call_args.kwargs.get("checkpoint"))
    if rank == 0:
        assert reported_checkpoint
    else:
        assert not reported_checkpoint

    booster.num_boosted_rounds.return_value = 3
    callback.after_iteration(booster, epoch=2, evals_log={})
    # Nobody should report a checkpoint on iterations that are off the `frequency` schedule
    reported_checkpoint = bool(mock_report.call_args.kwargs.get("checkpoint"))
    assert not reported_checkpoint
Contributor Author

Is this test unnecessarily complicated? Any better ways to do this?

Contributor

Just went through some MagicMock docs. I think this unit test looks good to me :)
For simplicity, we could probably just check the attribute RayTrainReportCallback._last_checkpoint_iteration after callback.after_iteration() is called. That way, we can drop the mock of ray.train.report. Then the code from line 154 ~ 169 could be simplified as:

with mock.patch("ray.train.get_context") as mock_get_context:
    mock_context = mock.MagicMock()
    mock_context.get_world_rank.return_value = rank
    mock_get_context.return_value = mock_context

    booster.num_boosted_rounds.return_value = 2
    callback.after_iteration(booster, epoch=1, evals_log={})
    assert callback._last_checkpoint_iteration == 1  # same as epoch number

However, this test assumes that ray.train.report and RayTrainReportCallback._get_checkpoint() always work. It's also a little indirect. Correct me if I am wrong :)
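To make the trade-off between the two assertion styles concrete, here is a self-contained sketch. `ToyReportCallback` is a hypothetical stand-in written for illustration, not Ray's `RayTrainReportCallback`; the report and rank functions are injected so no patching of `ray.train` is needed.

```python
from unittest import mock


class ToyReportCallback:
    """Hypothetical stand-in for a reporting callback (not Ray's code)."""

    def __init__(self, frequency, report_fn, get_rank_fn):
        self._frequency = frequency
        self._report = report_fn
        self._get_rank = get_rank_fn
        self._last_checkpoint_iteration = None

    def after_iteration(self, epoch):
        on_schedule = (epoch + 1) % self._frequency == 0
        if on_schedule and self._get_rank() in (0, None):
            # Only rank 0 (or a rank-less process) attaches a checkpoint.
            self._last_checkpoint_iteration = epoch
            self._report({"epoch": epoch}, checkpoint="ckpt")
        else:
            # Every worker still reports metrics, just without a checkpoint.
            self._report({"epoch": epoch}, checkpoint=None)


def run_once(rank):
    report = mock.MagicMock()
    callback = ToyReportCallback(
        frequency=2, report_fn=report, get_rank_fn=lambda: rank
    )
    callback.after_iteration(epoch=1)

    # Style 1: inspect what was passed to the mocked report function.
    reported = bool(report.call_args.kwargs.get("checkpoint"))
    # Style 2: inspect the callback's internal bookkeeping instead.
    tracked = callback._last_checkpoint_iteration
    return reported, tracked
```

Style 1 checks the observable behavior (what gets reported), while style 2 needs less mocking but depends on a private attribute staying in sync with the reporting logic.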

Contributor Author

You're right, the test was a little too indirect and I was able to find a similar way to simplify it. PTAL

Contributor

Nice simplification! Learnt a lot from the code 👍

python/ray/train/tests/test_xgboost_trainer.py (outdated, resolved)
Contributor

@matthewdeng matthewdeng left a comment

Clean!


Contributor

@hongpeng-guo hongpeng-guo left a comment

LGTM!

@justinvyu justinvyu merged commit 112e859 into ray-project:master May 9, 2024
5 checks passed
@justinvyu justinvyu deleted the xgb_lgb_rank0_only_ckpt branch May 9, 2024 19:56
HenryZJY pushed a commit to HenryZJY/ray that referenced this pull request May 10, 2024
…k 0 for xgboost/lightgbm (ray-project#45083)

This PR adds a condition to only save and report a checkpoint on the
rank 0 worker for xgboost and lightgbm. This prevents unnecessary
checkpoints being created, since all data parallel workers have the same
model states. Note: this also accounts for usage in Tune, where
`ray.train.get_context().get_world_rank()` returns `None`.

Fix `checkpoint_at_end` for the xgboost callback to avoid duplicate checkpoints.
---------

Signed-off-by: Justin Yu <[email protected]>
Co-authored-by: Hongpeng Guo <[email protected]>
3 participants