diff --git a/python/ray/train/examples/torch_quick_start.py b/python/ray/train/examples/torch_quick_start.py index 27f784071ffb..533227602bdb 100644 --- a/python/ray/train/examples/torch_quick_start.py +++ b/python/ray/train/examples/torch_quick_start.py @@ -49,6 +49,7 @@ def train_func(): # __torch_distributed_begin__ from ray import train +import ray.train.torch def train_func_distributed(): num_epochs = 3 diff --git a/python/ray/train/session.py b/python/ray/train/session.py index 55a04c5aedba..9b91d086a5e7 100644 --- a/python/ray/train/session.py +++ b/python/ray/train/session.py @@ -236,11 +236,6 @@ def init_session(*args, **kwargs) -> None: def get_session() -> Session: global _session - if _session is None or not isinstance(_session, Session): - raise ValueError("Trying to access a Train session that has not been " - "initialized yet. Train functions like " - "`train.report()` should only be called from inside " - "the training function.") return _session @@ -289,6 +284,11 @@ def train_func(): If no dataset is passed into Trainer, then return None. """ session = get_session() + if session is None: + warnings.warn("`train.get_dataset_shard()` is meant to only be called " + "inside a training function that is executed by " + "`Trainer.run`.") + return shard = session.dataset_shard if shard is None: warnings.warn("No dataset passed in. Returning None. Make sure to " @@ -330,6 +330,11 @@ def train_func(): intermediate results. """ session = get_session() + if session is None: + warnings.warn("`train.report()` is meant to only be called " + "inside a training function that is executed by " + "`Trainer.run`.") + return session.report(**kwargs) @@ -355,6 +360,8 @@ def train_func(): """ session = get_session() + if session is None: + return 0 return session.world_rank @@ -379,6 +386,8 @@ def train_func(): """ session = get_session() + if session is None: + return 0 return session.local_rank @@ -410,6 +419,11 @@ def train_func(): originally initialized with. ``None`` if neither exist. """ session = get_session() + if session is None: + warnings.warn("`train.load_checkpoint()` is meant to only be called " + "inside a training function that is executed by " + "`Trainer.run`.") + return return session.loaded_checkpoint @@ -436,6 +450,11 @@ def train_func(): **kwargs: Any key value pair to be checkpointed by Train. """ session = get_session() + if session is None: + warnings.warn("`train.save_checkpoint()` is meant to only be called " + "inside a training function that is executed by " + "`Trainer.run`.") + return session.checkpoint(**kwargs) @@ -457,4 +476,6 @@ def train_func(): trainer.shutdown() """ session = get_session() + if session is None: + return 1 return session.world_size diff --git a/python/ray/train/tests/test_session.py b/python/ray/train/tests/test_session.py index f368c8384dd3..37b225d8b173 100644 --- a/python/ray/train/tests/test_session.py +++ b/python/ray/train/tests/test_session.py @@ -23,31 +23,30 @@ def test_init_fail(session): init_session(lambda: 1, 0) -def test_get_fail(session): +def test_shutdown(session): shutdown_session() - with pytest.raises(ValueError): - get_session() + assert not get_session def test_world_rank(session): assert world_rank() == 0 shutdown_session() - with pytest.raises(ValueError): - world_rank() + # Make sure default to 0. + assert world_rank() == 0 def test_local_rank(session): assert local_rank() == 0 shutdown_session() - with pytest.raises(ValueError): - local_rank() + # Make sure default to 0. + assert local_rank() == 0 def test_world_size(session): assert world_size() == 1 shutdown_session() - with pytest.raises(ValueError): - world_size() + # Make sure default to 1. + assert world_size() == 1 def test_train(session): @@ -81,8 +80,8 @@ def train_func(): assert session.get_next().data["loss"] == 1 shutdown_session() - with pytest.raises(ValueError): - report(loss=2) + # Should not raise error outside of session. + report(loss=2) def test_report_fail(): @@ -150,8 +149,8 @@ def validate_nonzero(): session.finish() shutdown_session() - with pytest.raises(ValueError): - save_checkpoint(epoch=2) + # Should not raise error outside of session. + save_checkpoint(epoch=2) def test_encode_data():