diff --git a/requirements/test-requirements.txt b/requirements/test-requirements.txt index 9542962..02e5320 100644 --- a/requirements/test-requirements.txt +++ b/requirements/test-requirements.txt @@ -2,7 +2,7 @@ packaging petastorm pytest pyarrow -ray[tune, data] +ray[tune, data, default] scikit-learn modin dask @@ -10,4 +10,3 @@ dask #workaround for now protobuf<4.0.0 tensorboardX==2.2 -aiohttp diff --git a/xgboost_ray/examples/simple_tune.py b/xgboost_ray/examples/simple_tune.py index 9a3cd33..4513eb3 100644 --- a/xgboost_ray/examples/simple_tune.py +++ b/xgboost_ray/examples/simple_tune.py @@ -66,7 +66,7 @@ def main(cpus_per_actor, num_actors, num_samples): # Load the best model checkpoint. best_bst = xgboost_ray.tune.load_model( - os.path.join(analysis.best_logdir, "tuned.xgb") + os.path.join(analysis.best_trial.local_path, "tuned.xgb") ) best_bst.save_model("best_model.xgb") diff --git a/xgboost_ray/main.py b/xgboost_ray/main.py index ed0daf7..7df2e16 100644 --- a/xgboost_ray/main.py +++ b/xgboost_ray/main.py @@ -557,7 +557,7 @@ def __init__( self.checkpoint_frequency = checkpoint_frequency - self._data: Dict[RayDMatrix, xgb.DMatrix] = {} + self._data: Dict[RayDMatrix, dict] = {} self._local_n: Dict[RayDMatrix, int] = {} self._stop_event = stop_event diff --git a/xgboost_ray/tests/test_tune.py b/xgboost_ray/tests/test_tune.py index 05bf508..8397830 100644 --- a/xgboost_ray/tests/test_tune.py +++ b/xgboost_ray/tests/test_tune.py @@ -158,8 +158,13 @@ def testReplaceTuneCheckpoints(self): replaced = in_dict["callbacks"][0] self.assertTrue(isinstance(replaced, TuneReportCheckpointCallback)) - self.assertSequenceEqual(replaced._report._metrics, ["met"]) - self.assertEqual(replaced._checkpoint._filename, "test") + + if getattr(replaced, "_report", None): + self.assertSequenceEqual(replaced._report._metrics, ["met"]) + self.assertEqual(replaced._checkpoint._filename, "test") + else: + self.assertSequenceEqual(replaced._metrics, ["met"]) + self.assertEqual(replaced._filename, "test") def testEndToEndCheckpointing(self): ray_params = RayParams(cpus_per_actor=1, num_actors=2) diff --git a/xgboost_ray/tune.py b/xgboost_ray/tune.py index bc80272..fc0a531 100644 --- a/xgboost_ray/tune.py +++ b/xgboost_ray/tune.py @@ -3,6 +3,7 @@ from typing import Dict, Optional import ray +from ray.train._internal.session import get_session from ray.util.annotations import PublicAPI from xgboost_ray.session import get_rabit_rank, put_queue @@ -10,7 +11,7 @@ from xgboost_ray.xgb import xgboost as xgb try: - from ray import tune + from ray import train, tune from ray.tune import is_session_enabled from ray.tune.integration.xgboost import ( TuneReportCallback as OrigTuneReportCallback, @@ -39,30 +40,53 @@ def is_session_enabled(): flatten_dict = is_session_enabled TUNE_INSTALLED = False + if TUNE_INSTALLED: - # New style callbacks. - class TuneReportCallback(OrigTuneReportCallback): - def after_iteration(self, model, epoch: int, evals_log: Dict): - if get_rabit_rank() == 0: - report_dict = self._get_report_dict(evals_log) - put_queue(lambda: tune.report(**report_dict)) - - class _TuneCheckpointCallback(_OrigTuneCheckpointCallback): - def after_iteration(self, model, epoch: int, evals_log: Dict): - if get_rabit_rank() == 0: - put_queue( - lambda: self._create_checkpoint( - model, epoch, self._filename, self._frequency + if not hasattr(train, "report"): + + # New style callbacks. + class TuneReportCallback(OrigTuneReportCallback): + def after_iteration(self, model, epoch: int, evals_log: Dict): + if get_rabit_rank() == 0: + report_dict = self._get_report_dict(evals_log) + put_queue(lambda: tune.report(**report_dict)) + + class _TuneCheckpointCallback(_OrigTuneCheckpointCallback): + def after_iteration(self, model, epoch: int, evals_log: Dict): + if get_rabit_rank() == 0: + put_queue( + lambda: self._create_checkpoint( + model, epoch, self._filename, self._frequency + ) ) - ) - class TuneReportCheckpointCallback(OrigTuneReportCheckpointCallback): - _checkpoint_callback_cls = _TuneCheckpointCallback - _report_callbacks_cls = TuneReportCallback + class TuneReportCheckpointCallback(OrigTuneReportCheckpointCallback): + _checkpoint_callback_cls = _TuneCheckpointCallback + _report_callbacks_cls = TuneReportCallback + + else: + + class TuneReportCheckpointCallback(OrigTuneReportCheckpointCallback): + def after_iteration(self, model, epoch: int, evals_log: Dict): + if get_rabit_rank() == 0: + put_queue( + lambda: super( + TuneReportCheckpointCallback, self + ).after_iteration(model=model, epoch=epoch, evals_log=evals_log) + ) + + class TuneReportCallback(OrigTuneReportCallback): + def after_iteration(self, model, epoch: int, evals_log: Dict): + if get_rabit_rank() == 0: + put_queue( + lambda: super(TuneReportCallback, self).after_iteration( + model=model, epoch=epoch, evals_log=evals_log + ) + ) def _try_add_tune_callback(kwargs: Dict): - if TUNE_INSTALLED and is_session_enabled(): + if TUNE_INSTALLED and (is_session_enabled() or get_session()): callbacks = kwargs.get("callbacks", []) or [] new_callbacks = [] has_tune_callback = False @@ -88,10 +112,19 @@ def _try_add_tune_callback(kwargs: Dict): ) has_tune_callback = True elif isinstance(cb, OrigTuneReportCheckpointCallback): + if getattr(cb, "_report", None): + orig_metrics = cb._report._metrics + orig_filename = cb._checkpoint._filename + orig_frequency = cb._checkpoint._frequency + else: + orig_metrics = cb._metrics + orig_filename = cb._filename + orig_frequency = cb._frequency + replace_cb = TuneReportCheckpointCallback( - metrics=cb._report._metrics, - filename=cb._checkpoint._filename, - frequency=cb._checkpoint._frequency, + metrics=orig_metrics, + filename=orig_filename, + frequency=orig_frequency, ) new_callbacks.append(replace_cb) logging.warning(