Use new train.report API (#292)
We are converging on using train.report throughout the Ray library code base instead of tune.report.
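For context, a minimal sketch of what this change means inside a Tune function trainable (assuming a Ray 2.x release where ray.train.report is available; the metric below is purely illustrative):

from ray import train, tune


def trainable(config):
    for step in range(3):
        loss = 1.0 / (step + 1)  # dummy metric, for illustration only
        # Old style, being phased out:
        #   tune.report(loss=loss)
        # New style the code base is converging on:
        train.report({"loss": loss})


tune.run(trainable, num_samples=1)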
krfricke authored Aug 24, 2023
1 parent d415b49 commit 6a4685b
Showing 5 changed files with 65 additions and 28 deletions.
3 changes: 1 addition & 2 deletions requirements/test-requirements.txt
@@ -2,12 +2,11 @@ packaging
petastorm
pytest
pyarrow
ray[tune, data]
ray[tune, data, default]
scikit-learn
modin
dask

#workaround for now
protobuf<4.0.0
tensorboardX==2.2
aiohttp
2 changes: 1 addition & 1 deletion xgboost_ray/examples/simple_tune.py
@@ -66,7 +66,7 @@ def main(cpus_per_actor, num_actors, num_samples):

    # Load the best model checkpoint.
    best_bst = xgboost_ray.tune.load_model(
        os.path.join(analysis.best_logdir, "tuned.xgb")
        os.path.join(analysis.best_trial.local_path, "tuned.xgb")
    )

    best_bst.save_model("best_model.xgb")
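As a rough sketch of the surrounding usage (assuming analysis is the ExperimentAnalysis returned by tune.run in this example script; the helper name below is made up for illustration):

import os

import xgboost_ray


def load_best_model(analysis, filename="tuned.xgb"):
    # Older Ray exposes the best trial's directory as analysis.best_logdir;
    # newer releases expose it as analysis.best_trial.local_path.
    best_dir = getattr(analysis, "best_logdir", None)
    if best_dir is None:
        best_dir = analysis.best_trial.local_path
    return xgboost_ray.tune.load_model(os.path.join(best_dir, filename))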
2 changes: 1 addition & 1 deletion xgboost_ray/main.py
@@ -557,7 +557,7 @@ def __init__(

        self.checkpoint_frequency = checkpoint_frequency

        self._data: Dict[RayDMatrix, xgb.DMatrix] = {}
        self._data: Dict[RayDMatrix, dict] = {}
        self._local_n: Dict[RayDMatrix, int] = {}

        self._stop_event = stop_event
9 changes: 7 additions & 2 deletions xgboost_ray/tests/test_tune.py
@@ -158,8 +158,13 @@ def testReplaceTuneCheckpoints(self):

        replaced = in_dict["callbacks"][0]
        self.assertTrue(isinstance(replaced, TuneReportCheckpointCallback))
        self.assertSequenceEqual(replaced._report._metrics, ["met"])
        self.assertEqual(replaced._checkpoint._filename, "test")

        if getattr(replaced, "_report", None):
            self.assertSequenceEqual(replaced._report._metrics, ["met"])
            self.assertEqual(replaced._checkpoint._filename, "test")
        else:
            self.assertSequenceEqual(replaced._metrics, ["met"])
            self.assertEqual(replaced._filename, "test")

    def testEndToEndCheckpointing(self):
        ray_params = RayParams(cpus_per_actor=1, num_actors=2)
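The branching above exists because the newer TuneReportCheckpointCallback keeps its configuration directly on the callback (_metrics, _filename), while older versions nest it in separate _report/_checkpoint sub-callbacks. A hypothetical helper expressing the same check:

def _callback_config(cb):
    # Hypothetical helper mirroring the test: return (metrics, filename)
    # for either callback generation.
    if getattr(cb, "_report", None):
        # Older callbacks wrap separate report and checkpoint sub-callbacks.
        return cb._report._metrics, cb._checkpoint._filename
    # Newer callbacks store the settings on the callback itself.
    return cb._metrics, cb._filename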
77 changes: 55 additions & 22 deletions xgboost_ray/tune.py
@@ -3,14 +3,15 @@
from typing import Dict, Optional

import ray
from ray.train._internal.session import get_session
from ray.util.annotations import PublicAPI

from xgboost_ray.session import get_rabit_rank, put_queue
from xgboost_ray.util import Unavailable, force_on_current_node
from xgboost_ray.xgb import xgboost as xgb

try:
    from ray import tune
    from ray import train, tune
    from ray.tune import is_session_enabled
    from ray.tune.integration.xgboost import (
        TuneReportCallback as OrigTuneReportCallback,
@@ -39,30 +40,53 @@ def is_session_enabled():
    flatten_dict = is_session_enabled
    TUNE_INSTALLED = False


if TUNE_INSTALLED:
    # New style callbacks.
    class TuneReportCallback(OrigTuneReportCallback):
        def after_iteration(self, model, epoch: int, evals_log: Dict):
            if get_rabit_rank() == 0:
                report_dict = self._get_report_dict(evals_log)
                put_queue(lambda: tune.report(**report_dict))

    class _TuneCheckpointCallback(_OrigTuneCheckpointCallback):
        def after_iteration(self, model, epoch: int, evals_log: Dict):
            if get_rabit_rank() == 0:
                put_queue(
                    lambda: self._create_checkpoint(
                        model, epoch, self._filename, self._frequency
                    )
                )

    class TuneReportCheckpointCallback(OrigTuneReportCheckpointCallback):
        _checkpoint_callback_cls = _TuneCheckpointCallback
        _report_callbacks_cls = TuneReportCallback
    if not hasattr(train, "report"):

        # New style callbacks.
        class TuneReportCallback(OrigTuneReportCallback):
            def after_iteration(self, model, epoch: int, evals_log: Dict):
                if get_rabit_rank() == 0:
                    report_dict = self._get_report_dict(evals_log)
                    put_queue(lambda: tune.report(**report_dict))

        class _TuneCheckpointCallback(_OrigTuneCheckpointCallback):
            def after_iteration(self, model, epoch: int, evals_log: Dict):
                if get_rabit_rank() == 0:
                    put_queue(
                        lambda: self._create_checkpoint(
                            model, epoch, self._filename, self._frequency
                        )
                    )

        class TuneReportCheckpointCallback(OrigTuneReportCheckpointCallback):
            _checkpoint_callback_cls = _TuneCheckpointCallback
            _report_callbacks_cls = TuneReportCallback

    else:

        class TuneReportCheckpointCallback(OrigTuneReportCheckpointCallback):
            def after_iteration(self, model, epoch: int, evals_log: Dict):
                if get_rabit_rank() == 0:
                    put_queue(
                        lambda: super(
                            TuneReportCheckpointCallback, self
                        ).after_iteration(model=model, epoch=epoch, evals_log=evals_log)
                    )

        class TuneReportCallback(OrigTuneReportCallback):
            def after_iteration(self, model, epoch: int, evals_log: Dict):
                if get_rabit_rank() == 0:
                    put_queue(
                        lambda: super(TuneReportCallback, self).after_iteration(
                            model=model, epoch=epoch, evals_log=evals_log
                        )
                    )


def _try_add_tune_callback(kwargs: Dict):
    if TUNE_INSTALLED and is_session_enabled():
    if TUNE_INSTALLED and (is_session_enabled() or get_session()):
        callbacks = kwargs.get("callbacks", []) or []
        new_callbacks = []
        has_tune_callback = False
@@ -88,10 +112,19 @@ def _try_add_tune_callback(kwargs: Dict):
                )
                has_tune_callback = True
            elif isinstance(cb, OrigTuneReportCheckpointCallback):
                if getattr(cb, "_report", None):
                    orig_metrics = cb._report._metrics
                    orig_filename = cb._checkpoint._filename
                    orig_frequency = cb._checkpoint._frequency
                else:
                    orig_metrics = cb._metrics
                    orig_filename = cb._filename
                    orig_frequency = cb._frequency

                replace_cb = TuneReportCheckpointCallback(
                    metrics=cb._report._metrics,
                    filename=cb._checkpoint._filename,
                    frequency=cb._checkpoint._frequency,
                    metrics=orig_metrics,
                    filename=orig_filename,
                    frequency=orig_frequency,
                )
                new_callbacks.append(replace_cb)
                logging.warning(
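For orientation, a hedged end-to-end sketch of how these callbacks usually come into play: when training runs inside a Tune session, _try_add_tune_callback injects a TuneReportCheckpointCallback automatically, so a trainable typically only calls xgboost_ray.train. The dataset, parameters, and search space below are illustrative and not part of this commit:

from sklearn.datasets import load_breast_cancer

from ray import tune
from xgboost_ray import RayDMatrix, RayParams, train

ray_params = RayParams(num_actors=2, cpus_per_actor=1)


def train_xgb(config):
    data = load_breast_cancer()  # small toy dataset, for illustration
    dtrain = RayDMatrix(data.data, data.target)

    # No explicit Tune callback needed here: _try_add_tune_callback() detects
    # the Tune/Train session and adds a TuneReportCheckpointCallback.
    train(
        params={"objective": "binary:logistic", "eta": config["eta"]},
        dtrain=dtrain,
        evals=[(dtrain, "train")],
        num_boost_round=10,
        ray_params=ray_params,
    )


analysis = tune.run(
    train_xgb,
    config={"eta": tune.loguniform(1e-3, 1e-1)},
    num_samples=2,
    resources_per_trial=ray_params.get_tune_resources(),
)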
