
[train] RayTrainReportCallback should only save a checkpoint on rank 0 for xgboost/lightgbm #45083

Merged (11 commits) on May 9, 2024
python/ray/train/lightgbm/_lightgbm_utils.py (30 changes: 17 additions & 13 deletions)

@@ -6,7 +6,7 @@
from lightgbm.basic import Booster
from lightgbm.callback import CallbackEnv

-from ray import train
+import ray.train
from ray.train import Checkpoint
from ray.tune.utils import flatten_dict
from ray.util.annotations import PublicAPI
@@ -142,25 +142,29 @@ def _get_eval_result(self, env: CallbackEnv) -> dict:

    @contextmanager
    def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
-        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
-            model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
-            yield Checkpoint.from_directory(temp_checkpoint_dir)
+        if ray.train.get_context().get_world_rank() in (0, None):
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
+                yield Checkpoint.from_directory(temp_checkpoint_dir)
+        else:
+            yield None

    def __call__(self, env: CallbackEnv) -> None:
        eval_result = self._get_eval_result(env)
        report_dict = self._get_report_dict(eval_result)

-        # Ex: if frequency=2, checkpoint_at_end=True and num_boost_rounds=11,
-        # you will checkpoint at iterations 1, 3, 5, ..., 9, and 10 (checkpoint_at_end)
-        # (iterations count from 0)
        on_last_iter = env.iteration == env.end_iteration - 1
-        checkpointing_disabled = self._frequency == 0
-        should_checkpoint = (
-            not checkpointing_disabled and (env.iteration + 1) % self._frequency == 0
-        ) or (on_last_iter and self._checkpoint_at_end)
+        # Ex: if frequency=2, checkpoint_at_end=True and num_boost_rounds=10,
+        # you will checkpoint at iterations 1, 3, 5, ..., and 9 (checkpoint_at_end)
+        # (counting from 0)
+        should_checkpoint_at_end = on_last_iter and self._checkpoint_at_end
+        should_checkpoint_with_frequency = (
+            self._frequency != 0 and (env.iteration + 1) % self._frequency == 0
+        )
+        should_checkpoint = should_checkpoint_at_end or should_checkpoint_with_frequency

        if should_checkpoint:
            with self._get_checkpoint(model=env.model) as checkpoint:
-                train.report(report_dict, checkpoint=checkpoint)
+                ray.train.report(report_dict, checkpoint=checkpoint)
        else:
-            train.report(report_dict)
+            ray.train.report(report_dict)
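
The new cadence logic is easy to sanity-check in isolation. Below is a standalone sketch (not part of the PR; the helper name checkpoint_iterations is made up for illustration) that mirrors the should_checkpoint computation and reproduces both the updated comment and the test expectations that follow:

def checkpoint_iterations(num_boost_round: int, frequency: int, at_end: bool) -> list:
    # Mirrors the callback's should_checkpoint logic, iteration by iteration.
    iters = []
    for iteration in range(num_boost_round):  # iterations count from 0
        on_last_iter = iteration == num_boost_round - 1
        should_checkpoint_at_end = on_last_iter and at_end
        should_checkpoint_with_frequency = (
            frequency != 0 and (iteration + 1) % frequency == 0
        )
        if should_checkpoint_at_end or should_checkpoint_with_frequency:
            iters.append(iteration)
    return iters

# Matches the corrected comment: frequency=2, checkpoint_at_end=True, 10 rounds.
assert checkpoint_iterations(10, 2, True) == [1, 3, 5, 7, 9]
# Matches the updated test expectations below: frequency=4, 25 rounds.
assert checkpoint_iterations(25, 4, True) == [3, 7, 11, 15, 19, 23, 24]
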
python/ray/train/tests/test_lightgbm_trainer.py (32 changes: 27 additions & 5 deletions)

@@ -1,4 +1,5 @@
import math
+from unittest import mock

import lightgbm as lgbm
import pandas as pd
@@ -10,7 +11,7 @@
from ray import tune
from ray.train import ScalingConfig
from ray.train.constants import TRAIN_DATASET_KEY
-from ray.train.lightgbm import LightGBMTrainer
+from ray.train.lightgbm import LightGBMTrainer, RayTrainReportCallback


@pytest.fixture
@@ -101,10 +102,11 @@ def test_resume_from_checkpoint(ray_start_6_cpus, tmpdir):
@pytest.mark.parametrize(
    "freq_end_expected",
    [
-        (4, True, 7),  # 4, 8, 12, 16, 20, 24, 25
-        (4, False, 6),  # 4, 8, 12, 16, 20, 24
-        (5, True, 5),  # 5, 10, 15, 20, 25
-        (0, True, 1),
+        # With num_boost_round=25 and 0-indexed iterations, checkpoints land at:
+        (4, True, 7),  # 3, 7, 11, 15, 19, 23, 24 (end)
+        (4, False, 6),  # 3, 7, 11, 15, 19, 23
+        (5, True, 5),  # 4, 9, 14, 19, 24
+        (0, True, 1),  # 24 (end)
        (0, False, 0),
    ],
)
@@ -166,6 +168,26 @@ def test_validation(ray_start_6_cpus):
    )


+@pytest.mark.parametrize("rank", [None, 0, 1])
+def test_checkpoint_only_on_rank0(rank):
+    """Tests that the callback only reports checkpoints on rank 0,
+    or if the rank is not available (Tune usage)."""
+    callback = RayTrainReportCallback(frequency=2, checkpoint_at_end=True)
+
+    booster = mock.MagicMock()
+
+    with mock.patch("ray.train.get_context") as mock_get_context:
+        mock_context = mock.MagicMock()
+        mock_context.get_world_rank.return_value = rank
+        mock_get_context.return_value = mock_context
+
+        with callback._get_checkpoint(booster) as checkpoint:
+            if rank in (0, None):
+                assert checkpoint
+            else:
+                assert not checkpoint
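
For context, the mocked _get_checkpoint above is exercised in real training through LightGBM's native callback hook. A minimal sketch of that usage (data wiring and parameter values are illustrative, not taken from this PR):

import lightgbm as lgbm
from ray.train.lightgbm import RayTrainReportCallback

def train_fn_per_worker(config: dict):
    # Illustrative data wiring; real code would build the Dataset from shards.
    train_set = lgbm.Dataset(config["X"], label=config["y"])
    lgbm.train(
        params={"objective": "regression"},
        train_set=train_set,
        num_boost_round=25,
        # Every worker reports metrics each iteration, but after this PR only
        # the rank-0 worker (or rank None under Tune) attaches a checkpoint.
        callbacks=[RayTrainReportCallback(frequency=4, checkpoint_at_end=True)],
    )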


if __name__ == "__main__":
    import sys

python/ray/train/tests/test_xgboost_trainer.py (38 changes: 26 additions & 12 deletions)

@@ -1,4 +1,4 @@
-import json
+from unittest import mock

import pandas as pd
import pytest
@@ -43,11 +43,6 @@ def ray_start_8_cpus():
}


-def get_num_trees(booster: xgb.Booster) -> int:
-    data = [json.loads(d) for d in booster.get_dump(dump_format="json")]
-    return len(data)


def test_fit(ray_start_4_cpus):
    train_dataset = ray.data.from_pandas(train_df)
    valid_dataset = ray.data.from_pandas(test_df)
@@ -114,12 +109,11 @@ def test_resume_from_checkpoint(ray_start_4_cpus, tmpdir):
@pytest.mark.parametrize(
    "freq_end_expected",
    [
-        (4, True, 7),  # 4, 8, 12, 16, 20, 24, 25
-        (4, False, 6),  # 4, 8, 12, 16, 20, 24
-        # TODO(justinvyu): [simplify_xgb]
-        # Fix this checkpoint_at_end/checkpoint_frequency overlap behavior.
-        # (5, True, 5),  # 5, 10, 15, 20, 25
-        (0, True, 1),  # end
+        # With num_boost_round=25 and 0-indexed iterations, checkpoints land at:
+        (4, True, 7),  # 3, 7, 11, 15, 19, 23, 24 (end)
+        (4, False, 6),  # 3, 7, 11, 15, 19, 23
+        (5, True, 5),  # 4, 9, 14, 19, 24
+        (0, True, 1),  # 24 (end)
        (0, False, 0),
    ],
)
@@ -152,6 +146,26 @@ def test_checkpoint_freq(ray_start_4_cpus, freq_end_expected):
    assert cp_paths == sorted(cp_paths), str(cp_paths)


+@pytest.mark.parametrize("rank", [None, 0, 1])
+def test_checkpoint_only_on_rank0(rank):
+    """Tests that the callback only reports checkpoints on rank 0,
+    or if the rank is not available (Tune usage)."""
+    callback = RayTrainReportCallback(frequency=2, checkpoint_at_end=True)
+
+    booster = mock.MagicMock()
+
+    with mock.patch("ray.train.get_context") as mock_get_context:
+        mock_context = mock.MagicMock()
+        mock_context.get_world_rank.return_value = rank
+        mock_get_context.return_value = mock_context
+
+        with callback._get_checkpoint(booster) as checkpoint:
+            if rank in (0, None):
+                assert checkpoint
+            else:
+                assert not checkpoint
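
The XGBoost flow is symmetric. A sketch with illustrative values (it also assumes the get_model classmethod that appears as context in the next file's hunk) showing both sides: attaching the callback in a worker function, and recovering the rank-0 Booster from the resulting checkpoint:

import xgboost as xgb
from ray.train import Result
from ray.train.xgboost import RayTrainReportCallback

def train_fn_per_worker(config: dict):
    # Illustrative data wiring.
    dtrain = xgb.DMatrix(config["X"], label=config["y"])
    xgb.train(
        {"objective": "reg:squarederror"},
        dtrain,
        num_boost_round=25,
        # After this PR, only rank 0 attaches checkpoint files to the report.
        callbacks=[RayTrainReportCallback(frequency=4, checkpoint_at_end=True)],
    )

def load_model(result: Result) -> xgb.Booster:
    # Recover the Booster that the rank-0 worker saved into the checkpoint.
    return RayTrainReportCallback.get_model(result.checkpoint)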


def test_tune(ray_start_8_cpus):
    train_dataset = ray.data.from_pandas(train_df)
    valid_dataset = ray.data.from_pandas(test_df)
python/ray/train/xgboost/_xgboost_utils.py (32 changes: 24 additions & 8 deletions)

@@ -6,7 +6,7 @@

from xgboost.core import Booster

-from ray import train
+import ray.train
from ray.train import Checkpoint
from ray.tune.utils import flatten_dict
from ray.util.annotations import PublicAPI
@@ -118,6 +118,9 @@ def __init__(
        # so that the latest metrics can be reported with the checkpoint
        # at the end of training.
        self._evals_log = None
+        # Keep track of the last checkpoint iteration to avoid double-checkpointing
+        # when using `checkpoint_at_end=True`.
+        self._last_checkpoint_iteration = None

    @classmethod
    def get_model(
@@ -163,9 +166,13 @@ def _get_report_dict(self, evals_log):

    @contextmanager
    def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
-        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
-            model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
-            yield Checkpoint(temp_checkpoint_dir)
+        # NOTE: `get_world_rank()` returns None for Tune usage without Train.
+        if ray.train.get_context().get_world_rank() in (0, None):
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
+                yield Checkpoint(temp_checkpoint_dir)
+        else:
+            yield None

    def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
        self._evals_log = evals_log

@@ -178,17 +185,26 @@ def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):

        report_dict = self._get_report_dict(evals_log)
        if should_checkpoint:
+            self._last_checkpoint_iteration = epoch
            with self._get_checkpoint(model=model) as checkpoint:
-                train.report(report_dict, checkpoint=checkpoint)
+                ray.train.report(report_dict, checkpoint=checkpoint)
        else:
-            train.report(report_dict)
+            ray.train.report(report_dict)

-    def after_training(self, model: Booster):
+    def after_training(self, model: Booster) -> Booster:
        if not self._checkpoint_at_end:
            return model

+        if (
+            self._last_checkpoint_iteration is not None
+            and model.num_boosted_rounds() - 1 == self._last_checkpoint_iteration
+        ):
+            # Avoids a duplicate checkpoint if the checkpoint frequency happens
+            # to align with the last iteration.
+            return model
+
        report_dict = self._get_report_dict(self._evals_log) if self._evals_log else {}
        with self._get_checkpoint(model=model) as checkpoint:
-            train.report(report_dict, checkpoint=checkpoint)
+            ray.train.report(report_dict, checkpoint=checkpoint)

        return model
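
The interaction between the new _last_checkpoint_iteration guard and checkpoint_at_end can be checked numerically. A standalone sketch (not part of the PR) of why frequency=5 with 25 boosting rounds now yields exactly 5 checkpoints, matching the re-enabled (5, True, 5) case in test_xgboost_trainer.py:

# Standalone arithmetic check of the after_training dedup guard.
num_boost_round = 25
frequency = 5

# Iterations checkpointed by the frequency path (0-indexed): 4, 9, 14, 19, 24.
frequency_checkpoints = [
    i for i in range(num_boost_round) if (i + 1) % frequency == 0
]
last_checkpoint_iteration = frequency_checkpoints[-1]  # 24

# In after_training, the fully trained model reports num_boosted_rounds() == 25,
# so num_boosted_rounds() - 1 == 24 equals the last frequency checkpoint and the
# end-of-training checkpoint is skipped: 5 checkpoints total, not 6.
assert num_boost_round - 1 == last_checkpoint_iteration
assert len(frequency_checkpoints) == 5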