diff --git a/python/ray/train/lightgbm/_lightgbm_utils.py b/python/ray/train/lightgbm/_lightgbm_utils.py
index 99a1309ad5b1..15c4e344bd16 100644
--- a/python/ray/train/lightgbm/_lightgbm_utils.py
+++ b/python/ray/train/lightgbm/_lightgbm_utils.py
@@ -6,7 +6,7 @@
 from lightgbm.basic import Booster
 from lightgbm.callback import CallbackEnv
 
-from ray import train
+import ray.train
 from ray.train import Checkpoint
 from ray.tune.utils import flatten_dict
 from ray.util.annotations import PublicAPI
@@ -142,25 +142,29 @@ def _get_eval_result(self, env: CallbackEnv) -> dict:
 
     @contextmanager
     def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
-        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
-            model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
-            yield Checkpoint.from_directory(temp_checkpoint_dir)
+        if ray.train.get_context().get_world_rank() in (0, None):
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
+                yield Checkpoint.from_directory(temp_checkpoint_dir)
+        else:
+            yield None
 
     def __call__(self, env: CallbackEnv) -> None:
         eval_result = self._get_eval_result(env)
         report_dict = self._get_report_dict(eval_result)
 
+        # Ex: if frequency=2, checkpoint_at_end=True and num_boost_round=11,
+        # you will checkpoint at iterations 1, 3, 5, ..., 9, and 10 (checkpoint_at_end)
+        # (iterations are 0-indexed)
         on_last_iter = env.iteration == env.end_iteration - 1
-        checkpointing_disabled = self._frequency == 0
-        # Ex: if frequency=2, checkpoint_at_end=True and num_boost_rounds=10,
-        # you will checkpoint at iterations 1, 3, 5, ..., and 9 (checkpoint_at_end)
-        # (counting from 0)
-        should_checkpoint = (
-            not checkpointing_disabled and (env.iteration + 1) % self._frequency == 0
-        ) or (on_last_iter and self._checkpoint_at_end)
+        should_checkpoint_at_end = on_last_iter and self._checkpoint_at_end
+        should_checkpoint_with_frequency = (
+            self._frequency != 0 and (env.iteration + 1) % self._frequency == 0
+        )
+        should_checkpoint = should_checkpoint_at_end or should_checkpoint_with_frequency
 
         if should_checkpoint:
             with self._get_checkpoint(model=env.model) as checkpoint:
-                train.report(report_dict, checkpoint=checkpoint)
+                ray.train.report(report_dict, checkpoint=checkpoint)
         else:
-            train.report(report_dict)
+            ray.train.report(report_dict)
diff --git a/python/ray/train/tests/test_lightgbm_trainer.py b/python/ray/train/tests/test_lightgbm_trainer.py
index fed115c7a13e..81c75a73a28f 100644
--- a/python/ray/train/tests/test_lightgbm_trainer.py
+++ b/python/ray/train/tests/test_lightgbm_trainer.py
@@ -1,4 +1,5 @@
 import math
+from unittest import mock
 
 import lightgbm as lgbm
 import pandas as pd
@@ -10,7 +11,7 @@
 from ray import tune
 from ray.train import ScalingConfig
 from ray.train.constants import TRAIN_DATASET_KEY
-from ray.train.lightgbm import LightGBMTrainer
+from ray.train.lightgbm import LightGBMTrainer, RayTrainReportCallback
 
 
 @pytest.fixture
@@ -101,10 +102,11 @@ def test_resume_from_checkpoint(ray_start_6_cpus, tmpdir):
 @pytest.mark.parametrize(
     "freq_end_expected",
     [
-        (4, True, 7),  # 4, 8, 12, 16, 20, 24, 25
-        (4, False, 6),  # 4, 8, 12, 16, 20, 24
-        (5, True, 5),  # 5, 10, 15, 20, 25
-        (0, True, 1),
+        # With num_boost_round=25 and 0-indexed iterations, the checkpoints will be at:
+        (4, True, 7),  # 3, 7, 11, 15, 19, 23, 24 (end)
+        (4, False, 6),  # 3, 7, 11, 15, 19, 23
+        (5, True, 5),  # 4, 9, 14, 19, 24
+        (0, True, 1),  # 24 (end)
         (0, False, 0),
     ],
 )
@@ -166,6 +168,26 @@ def test_validation(ray_start_6_cpus):
     )
 
 
+@pytest.mark.parametrize("rank", [None, 0, 1])
+def test_checkpoint_only_on_rank0(rank):
+    """Tests that the callback only reports checkpoints on rank 0,
+    or if the rank is not available (Tune usage)."""
+    callback = RayTrainReportCallback(frequency=2, checkpoint_at_end=True)
+
+    booster = mock.MagicMock()
+
+    with mock.patch("ray.train.get_context") as mock_get_context:
+        mock_context = mock.MagicMock()
+        mock_context.get_world_rank.return_value = rank
+        mock_get_context.return_value = mock_context
+
+        with callback._get_checkpoint(booster) as checkpoint:
+            if rank in (0, None):
+                assert checkpoint
+            else:
+                assert not checkpoint
+
+
 if __name__ == "__main__":
     import sys
 
diff --git a/python/ray/train/tests/test_xgboost_trainer.py b/python/ray/train/tests/test_xgboost_trainer.py
index acf27a4fc04a..463b8ce8c226 100644
--- a/python/ray/train/tests/test_xgboost_trainer.py
+++ b/python/ray/train/tests/test_xgboost_trainer.py
@@ -1,4 +1,4 @@
-import json
+from unittest import mock
 
 import pandas as pd
 import pytest
@@ -43,11 +43,6 @@ def ray_start_8_cpus():
 }
 
 
-def get_num_trees(booster: xgb.Booster) -> int:
-    data = [json.loads(d) for d in booster.get_dump(dump_format="json")]
-    return len(data)
-
-
 def test_fit(ray_start_4_cpus):
     train_dataset = ray.data.from_pandas(train_df)
     valid_dataset = ray.data.from_pandas(test_df)
@@ -114,12 +109,11 @@ def test_resume_from_checkpoint(ray_start_4_cpus, tmpdir):
 @pytest.mark.parametrize(
     "freq_end_expected",
     [
-        (4, True, 7),  # 4, 8, 12, 16, 20, 24, 25
-        (4, False, 6),  # 4, 8, 12, 16, 20, 24
-        # TODO(justinvyu): [simplify_xgb]
-        # Fix this checkpoint_at_end/checkpoint_frequency overlap behavior.
-        # (5, True, 5),  # 5, 10, 15, 20, 25
-        (0, True, 1),  # end
+        # With num_boost_round=25 and 0-indexed iterations, the checkpoints will be at:
+        (4, True, 7),  # 3, 7, 11, 15, 19, 23, 24 (end)
+        (4, False, 6),  # 3, 7, 11, 15, 19, 23
+        (5, True, 5),  # 4, 9, 14, 19, 24
+        (0, True, 1),  # 24 (end)
         (0, False, 0),
     ],
 )
@@ -152,6 +146,26 @@ def test_checkpoint_freq(ray_start_4_cpus, freq_end_expected):
     assert cp_paths == sorted(cp_paths), str(cp_paths)
 
 
+@pytest.mark.parametrize("rank", [None, 0, 1])
+def test_checkpoint_only_on_rank0(rank):
+    """Tests that the callback only reports checkpoints on rank 0,
+    or if the rank is not available (Tune usage)."""
+    callback = RayTrainReportCallback(frequency=2, checkpoint_at_end=True)
+
+    booster = mock.MagicMock()
+
+    with mock.patch("ray.train.get_context") as mock_get_context:
+        mock_context = mock.MagicMock()
+        mock_context.get_world_rank.return_value = rank
+        mock_get_context.return_value = mock_context
+
+        with callback._get_checkpoint(booster) as checkpoint:
+            if rank in (0, None):
+                assert checkpoint
+            else:
+                assert not checkpoint
+
+
 def test_tune(ray_start_8_cpus):
     train_dataset = ray.data.from_pandas(train_df)
     valid_dataset = ray.data.from_pandas(test_df)
diff --git a/python/ray/train/xgboost/_xgboost_utils.py b/python/ray/train/xgboost/_xgboost_utils.py
index ac8c0ad74908..459dfcf07a22 100644
--- a/python/ray/train/xgboost/_xgboost_utils.py
+++ b/python/ray/train/xgboost/_xgboost_utils.py
@@ -6,7 +6,7 @@
 
 from xgboost.core import Booster
 
-from ray import train
+import ray.train
 from ray.train import Checkpoint
 from ray.tune.utils import flatten_dict
 from ray.util.annotations import PublicAPI
@@ -118,6 +118,9 @@ def __init__(
         # so that the latest metrics can be reported with the checkpoint
         # at the end of training.
         self._evals_log = None
+        # Keep track of the last checkpoint iteration to avoid double-checkpointing
+        # when using `checkpoint_at_end=True`.
+        self._last_checkpoint_iteration = None
 
     @classmethod
     def get_model(
@@ -163,9 +166,13 @@ def _get_report_dict(self, evals_log):
 
     @contextmanager
     def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
-        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
-            model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
-            yield Checkpoint(temp_checkpoint_dir)
+        # NOTE: The world rank returns None for Tune usage without Train.
+        if ray.train.get_context().get_world_rank() in (0, None):
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
+                yield Checkpoint(temp_checkpoint_dir)
+        else:
+            yield None
 
     def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
         self._evals_log = evals_log
@@ -178,17 +185,26 @@ def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
 
         report_dict = self._get_report_dict(evals_log)
         if should_checkpoint:
+            self._last_checkpoint_iteration = epoch
             with self._get_checkpoint(model=model) as checkpoint:
-                train.report(report_dict, checkpoint=checkpoint)
+                ray.train.report(report_dict, checkpoint=checkpoint)
         else:
-            train.report(report_dict)
+            ray.train.report(report_dict)
 
-    def after_training(self, model: Booster):
+    def after_training(self, model: Booster) -> Booster:
         if not self._checkpoint_at_end:
             return model
 
+        if (
+            self._last_checkpoint_iteration is not None
+            and model.num_boosted_rounds() - 1 == self._last_checkpoint_iteration
+        ):
+            # Avoids a duplicate checkpoint if the checkpoint frequency happens
+            # to align with the last iteration.
+            return model
+
         report_dict = self._get_report_dict(self._evals_log) if self._evals_log else {}
         with self._get_checkpoint(model=model) as checkpoint:
-            train.report(report_dict, checkpoint=checkpoint)
+            ray.train.report(report_dict, checkpoint=checkpoint)
 
         return model
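
Usage sketch for the Tune path that the new `test_checkpoint_only_on_rank0` tests exercise, where `get_context().get_world_rank()` returns `None` and the `in (0, None)` check still produces checkpoints. This is a minimal illustration, not taken from this patch: it assumes `lightgbm` and `scikit-learn` are installed, and the dataset, params, and `train_fn` name are placeholders.

    import lightgbm as lgbm
    from sklearn.datasets import load_breast_cancer

    from ray import tune
    from ray.train.lightgbm import RayTrainReportCallback


    def train_fn(config):
        # Inside a Tune function trainable there are no Train workers, so
        # get_world_rank() returns None and the callback still checkpoints.
        X, y = load_breast_cancer(return_X_y=True)
        train_set = lgbm.Dataset(X, label=y)
        lgbm.train(
            {"objective": "binary", "metric": "binary_logloss"},
            train_set,
            num_boost_round=11,
            valid_sets=[train_set],
            valid_names=["train"],
            # frequency=2 checkpoints at iterations 1, 3, 5, 7, 9; the final
            # iteration 10 is covered by checkpoint_at_end=True.
            callbacks=[RayTrainReportCallback(frequency=2, checkpoint_at_end=True)],
        )


    tune.Tuner(train_fn).fit()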