Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use new train.report API #49

Merged
merged 4 commits into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lightgbm_ray/examples/simple_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def main(cpus_per_actor, num_actors, num_samples):

# Load the best model checkpoint.
best_bst = lightgbm_ray.tune.load_model(
os.path.join(analysis.best_logdir, "tuned.lgbm")
os.path.join(analysis.best_trial.local_path, "tuned.lgbm")
)

best_bst.save_model("best_model.lgbm")
Expand Down
2 changes: 1 addition & 1 deletion lightgbm_ray/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def _save_internal_checkpoint_callback() -> Callable:
def _callback(env: CallbackEnv) -> None:
if not is_rank_0:
return
if (
if this.checkpoint_frequency > 0 and (
env.iteration == env.end_iteration - 1
or env.iteration % this.checkpoint_frequency == 0
):
Expand Down
9 changes: 7 additions & 2 deletions lightgbm_ray/tests/test_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,13 @@ def testReplaceTuneCheckpoints(self):

replaced = in_dict["callbacks"][0]
self.assertTrue(isinstance(replaced, TuneReportCheckpointCallback))
self.assertSequenceEqual(replaced._report._metrics, ["met"])
self.assertEqual(replaced._checkpoint._filename, "test")

if getattr(replaced, "_report", None):
self.assertSequenceEqual(replaced._report._metrics, ["met"])
self.assertEqual(replaced._checkpoint._filename, "test")
else:
self.assertSequenceEqual(replaced._metrics, ["met"])
self.assertEqual(replaced._filename, "test")

def testEndToEndCheckpointing(self):
ray.init(num_cpus=4)
Expand Down
115 changes: 72 additions & 43 deletions lightgbm_ray/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
import ray
from lightgbm.basic import Booster
from lightgbm.callback import CallbackEnv
from ray.train._internal.session import get_session
from ray.util.annotations import PublicAPI
from xgboost_ray.session import put_queue
from xgboost_ray.util import force_on_current_node

try:
from ray import tune
from ray import train, tune
from ray.tune import is_session_enabled
from ray.tune.integration.lightgbm import (
TuneReportCallback as OrigTuneReportCallback,
Expand Down Expand Up @@ -49,49 +50,68 @@ def is_rank_0(self, val: bool):


if TUNE_INSTALLED:

class TuneReportCallback(_TuneLGBMRank0Mixin, OrigTuneReportCallback):
def __call__(self, env: CallbackEnv) -> None:
if not self.is_rank_0:
return
eval_result = self._get_eval_result(env)
report_dict = self._get_report_dict(eval_result)
put_queue(lambda: tune.report(**report_dict))

class _TuneCheckpointCallback(_TuneLGBMRank0Mixin, _OrigTuneCheckpointCallback):
def __call__(self, env: CallbackEnv) -> None:
if not self.is_rank_0:
return
put_queue(
lambda: self._create_checkpoint(
env.model, env.iteration, self._filename, self._frequency
if not hasattr(train, "report"):

class TuneReportCallback(_TuneLGBMRank0Mixin, OrigTuneReportCallback):
def __call__(self, env: CallbackEnv) -> None:
if not self.is_rank_0:
return
eval_result = self._get_eval_result(env)
report_dict = self._get_report_dict(eval_result)
put_queue(lambda: tune.report(**report_dict))

class _TuneCheckpointCallback(_TuneLGBMRank0Mixin, _OrigTuneCheckpointCallback):
def __call__(self, env: CallbackEnv) -> None:
if not self.is_rank_0:
return
put_queue(
lambda: self._create_checkpoint(
env.model, env.iteration, self._filename, self._frequency
)
)
)

class TuneReportCheckpointCallback(
_TuneLGBMRank0Mixin, OrigTuneReportCheckpointCallback
):
_checkpoint_callback_cls = _TuneCheckpointCallback
_report_callback_cls = TuneReportCallback

@property
def is_rank_0(self) -> bool:
try:
return self._is_rank_0
except AttributeError:
return False

@is_rank_0.setter
def is_rank_0(self, val: bool):
self._is_rank_0 = val
if hasattr(self, "_checkpoint"):
self._checkpoint.is_rank_0 = val
if hasattr(self, "_report"):
self._report.is_rank_0 = val

class TuneReportCheckpointCallback(
_TuneLGBMRank0Mixin, OrigTuneReportCheckpointCallback
):
_checkpoint_callback_cls = _TuneCheckpointCallback
_report_callback_cls = TuneReportCallback

@property
def is_rank_0(self) -> bool:
try:
return self._is_rank_0
except AttributeError:
return False

@is_rank_0.setter
def is_rank_0(self, val: bool):
self._is_rank_0 = val
if hasattr(self, "_checkpoint"):
self._checkpoint.is_rank_0 = val
if hasattr(self, "_report"):
self._report.is_rank_0 = val

else:

class TuneReportCheckpointCallback(
_TuneLGBMRank0Mixin, OrigTuneReportCheckpointCallback
):
def __call__(self, env: CallbackEnv):
if self.is_rank_0:
put_queue(
lambda: super(TuneReportCheckpointCallback, self).__call__(
env=env
)
)

class TuneReportCallback(_TuneLGBMRank0Mixin, OrigTuneReportCallback):
def __call__(self, env: CallbackEnv):
if self.is_rank_0:
put_queue(lambda: super(TuneReportCallback, self).__call__(env=env))


def _try_add_tune_callback(kwargs: Dict):
if TUNE_INSTALLED and is_session_enabled():
if TUNE_INSTALLED and is_session_enabled() or get_session():
callbacks = kwargs.get("callbacks", []) or []
new_callbacks = []
has_tune_callback = False
Expand All @@ -117,10 +137,19 @@ def _try_add_tune_callback(kwargs: Dict):
)
has_tune_callback = True
elif isinstance(cb, OrigTuneReportCheckpointCallback):
if getattr(cb, "_report", None):
orig_metrics = cb._report._metrics
orig_filename = cb._checkpoint._filename
orig_frequency = cb._checkpoint._frequency
else:
orig_metrics = cb._metrics
orig_filename = cb._filename
orig_frequency = cb._frequency

replace_cb = TuneReportCheckpointCallback(
metrics=cb._report._metrics,
filename=cb._checkpoint._filename,
frequency=cb._checkpoint._frequency,
metrics=orig_metrics,
filename=orig_filename,
frequency=orig_frequency,
)
new_callbacks.append(replace_cb)
logging.warning(
Expand Down
Loading