From f0050a010dd307628ffcfc90576c483851d81496 Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Fri, 11 Feb 2022 17:42:55 -0800 Subject: [PATCH] [Train] Allow `train` methods to be called outside of the session (#21969) Updates to address @worldveil's feedback: Include import train.torch in the docs Allow methods in session.py to be called outside of the session with sensible defaults. These will no longer raise an error. Co-authored-by: Balaji Veeramani --- python/ray/train/backend.py | 8 +-- python/ray/train/constants.py | 4 ++ .../ray/train/examples/torch_quick_start.py | 1 + python/ray/train/session.py | 46 +++++++++++--- python/ray/train/tests/test_session.py | 60 ++++++++++++++----- python/ray/util/debug.py | 6 ++ 6 files changed, 97 insertions(+), 28 deletions(-) diff --git a/python/ray/train/backend.py b/python/ray/train/backend.py index d6a890b52410..f5daeecbc6fa 100644 --- a/python/ray/train/backend.py +++ b/python/ray/train/backend.py @@ -637,10 +637,9 @@ def __len__(self): def _get_session(method_name: str): - try: - # Get the session for this worker. - return get_session() - except ValueError: + # Get the session for this worker. + session = get_session() + if not session: # Session is not initialized yet. raise TrainBackendError( f"`{method_name}` has been called " @@ -648,3 +647,4 @@ def _get_session(method_name: str): "`start_training` before " f"`{method_name}`." ) + return session diff --git a/python/ray/train/constants.py b/python/ray/train/constants.py index 0e8e9429fd5f..9c7ebf9b62a8 100644 --- a/python/ray/train/constants.py +++ b/python/ray/train/constants.py @@ -60,3 +60,7 @@ # Integer value which if set will change the placement group strategy from # PACK to SPREAD. 1 for True, 0 for False. TRAIN_ENABLE_WORKER_SPREAD_ENV = "TRAIN_ENABLE_WORKER_SPREAD" + +# The key used to identify whether we have already warned about ray.train +# functions being used outside of the session +SESSION_MISUSE_LOG_ONCE_KEY = "train_warn_session_misuse" diff --git a/python/ray/train/examples/torch_quick_start.py b/python/ray/train/examples/torch_quick_start.py index 8911a37c71ad..e152c8604610 100644 --- a/python/ray/train/examples/torch_quick_start.py +++ b/python/ray/train/examples/torch_quick_start.py @@ -49,6 +49,7 @@ def train_func(): # __torch_distributed_begin__ from ray import train +import ray.train.torch def train_func_distributed(): num_epochs = 3 diff --git a/python/ray/train/session.py b/python/ray/train/session.py index 1bf9efcd435b..3701d5165f52 100644 --- a/python/ray/train/session.py +++ b/python/ray/train/session.py @@ -22,9 +22,10 @@ HOSTNAME, DATE, RESULT_FETCH_TIMEOUT, + SESSION_MISUSE_LOG_ONCE_KEY, ) from ray.train.utils import PropagatingThread, RayDataset -from ray.util import PublicAPI +from ray.util import PublicAPI, log_once class TrainingResultType(Enum): @@ -235,6 +236,22 @@ def checkpoint(self, **kwargs): _session = None +def _warn_session_misuse(fn_name: str): + """Logs warning message on provided fn being used outside of session. + + Args: + fn_name (str): The name of the function to warn about. + """ + + if log_once(f"{SESSION_MISUSE_LOG_ONCE_KEY}-{fn_name}"): + warnings.warn( + f"`train.{fn_name}()` is meant to only be " + f"called " + "inside a training function that is executed by " + "`Trainer.run`. Returning None." + ) + + def init_session(*args, **kwargs) -> None: global _session if _session: @@ -245,15 +262,8 @@ def init_session(*args, **kwargs) -> None: _session = Session(*args, **kwargs) -def get_session() -> Session: +def get_session() -> Optional[Session]: global _session - if _session is None or not isinstance(_session, Session): - raise ValueError( - "Trying to access a Train session that has not been " - "initialized yet. Train functions like " - "`train.report()` should only be called from inside " - "the training function." - ) return _session @@ -301,6 +311,9 @@ def train_func(): If no dataset is passed into Trainer, then return None. """ session = get_session() + if session is None: + _warn_session_misuse(get_dataset_shard.__name__) + return shard = session.dataset_shard if shard is None: warnings.warn( @@ -345,6 +358,9 @@ def train_func(): intermediate results. """ session = get_session() + if session is None: + _warn_session_misuse(report.__name__) + return session.report(**kwargs) @@ -370,6 +386,8 @@ def train_func(): """ session = get_session() + if session is None: + return 0 return session.world_rank @@ -394,6 +412,8 @@ def train_func(): """ session = get_session() + if session is None: + return 0 return session.local_rank @@ -425,6 +445,9 @@ def train_func(): originally initialized with. ``None`` if neither exist. """ session = get_session() + if session is None: + _warn_session_misuse(load_checkpoint.__name__) + return return session.loaded_checkpoint @@ -451,6 +474,9 @@ def train_func(): **kwargs: Any key value pair to be checkpointed by Train. """ session = get_session() + if session is None: + _warn_session_misuse(save_checkpoint.__name__) + return session.checkpoint(**kwargs) @@ -472,4 +498,6 @@ def train_func(): trainer.shutdown() """ session = get_session() + if session is None: + return 1 return session.world_size diff --git a/python/ray/train/tests/test_session.py b/python/ray/train/tests/test_session.py index b4d1e712f56a..350a10515efd 100644 --- a/python/ray/train/tests/test_session.py +++ b/python/ray/train/tests/test_session.py @@ -3,6 +3,7 @@ import pytest import ray +from ray.train.constants import SESSION_MISUSE_LOG_ONCE_KEY from ray.train.session import ( init_session, shutdown_session, @@ -33,31 +34,30 @@ def test_init_fail(session): init_session(lambda: 1, 0) -def test_get_fail(session): +def test_shutdown(session): shutdown_session() - with pytest.raises(ValueError): - get_session() + assert not get_session() def test_world_rank(session): assert world_rank() == 0 shutdown_session() - with pytest.raises(ValueError): - world_rank() + # Make sure default to 0. + assert world_rank() == 0 def test_local_rank(session): assert local_rank() == 0 shutdown_session() - with pytest.raises(ValueError): - local_rank() + # Make sure default to 0. + assert local_rank() == 0 def test_world_size(session): assert world_size() == 1 shutdown_session() - with pytest.raises(ValueError): - world_size() + # Make sure default to 1. + assert world_size() == 1 def test_train(session): @@ -91,9 +91,6 @@ def train_func(): assert session.get_next().data["loss"] == 1 shutdown_session() - with pytest.raises(ValueError): - report(loss=2) - def test_report_fail(): def train_func(): @@ -157,9 +154,6 @@ def validate_nonzero(): session.finish() shutdown_session() - with pytest.raises(ValueError): - save_checkpoint(epoch=2) - def test_encode_data(): def train_func(): @@ -242,6 +236,42 @@ def train_2(): shutdown_session() +def reset_log_once_with_str(str_to_append=None): + key = SESSION_MISUSE_LOG_ONCE_KEY + if str_to_append: + key += f"-{str_to_append}" + ray.util.debug.reset_log_once(key) + + +@pytest.mark.parametrize( + "fn", [load_checkpoint, save_checkpoint, report, get_dataset_shard] +) +def test_warn(fn): + """Checks if calling train functions outside of session raises warning.""" + + with pytest.warns(UserWarning) as record: + fn() + + assert fn.__name__ in record[0].message.args[0] + + reset_log_once_with_str(fn.__name__) + + +def test_warn_once(): + """Checks if session misuse warning is only shown once per function.""" + + with pytest.warns(UserWarning) as record: + assert not load_checkpoint() + assert not load_checkpoint() + assert not save_checkpoint(x=2) + assert not report(x=2) + assert not report(x=3) + assert not get_dataset_shard() + + # Should only warn once. + assert len(record) == 4 + + if __name__ == "__main__": import pytest import sys diff --git a/python/ray/util/debug.py b/python/ray/util/debug.py index 180d4358c01c..189a8d80807c 100644 --- a/python/ray/util/debug.py +++ b/python/ray/util/debug.py @@ -44,3 +44,9 @@ def enable_periodic_logging(): global _periodic_log _periodic_log = True + + +def reset_log_once(key): + """Resets log_once for the provided key.""" + + _logged.discard(key)