Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
amogkam committed Jan 29, 2022
1 parent fe1bf02 commit 34d336a
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 18 deletions.
1 change: 1 addition & 0 deletions python/ray/train/examples/torch_quick_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def train_func():
# __torch_distributed_begin__

from ray import train
import ray.train.torch

def train_func_distributed():
num_epochs = 3
Expand Down
31 changes: 26 additions & 5 deletions python/ray/train/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,6 @@ def init_session(*args, **kwargs) -> None:

def get_session() -> Session:
global _session
if _session is None or not isinstance(_session, Session):
raise ValueError("Trying to access a Train session that has not been "
"initialized yet. Train functions like "
"`train.report()` should only be called from inside "
"the training function.")
return _session


Expand Down Expand Up @@ -289,6 +284,11 @@ def train_func():
If no dataset is passed into Trainer, then return None.
"""
session = get_session()
if session is None:
warnings.warn("`train.get_dataset_shard()` is meant to only be called "
"inside a training function that is executed by "
"`Trainer.run`.")
return
shard = session.dataset_shard
if shard is None:
warnings.warn("No dataset passed in. Returning None. Make sure to "
Expand Down Expand Up @@ -330,6 +330,11 @@ def train_func():
intermediate results.
"""
session = get_session()
if session is None:
warnings.warn("`train.report()` is meant to only be called "
"inside a training function that is executed by "
"`Trainer.run`.")
return
session.report(**kwargs)


Expand All @@ -355,6 +360,8 @@ def train_func():
"""
session = get_session()
if session is None:
return 0
return session.world_rank


Expand All @@ -379,6 +386,8 @@ def train_func():
"""
session = get_session()
if session is None:
return 0
return session.local_rank


Expand Down Expand Up @@ -410,6 +419,11 @@ def train_func():
originally initialized with. ``None`` if neither exist.
"""
session = get_session()
if session is None:
warnings.warn("`train.load_checkpoint()` is meant to only be called "
"inside a training function that is executed by "
"`Trainer.run`.")
return
return session.loaded_checkpoint


Expand All @@ -436,6 +450,11 @@ def train_func():
**kwargs: Any key value pair to be checkpointed by Train.
"""
session = get_session()
if session is None:
warnings.warn("`train.save_checkpoint()` is meant to only be called "
"inside a training function that is executed by "
"`Trainer.run`.")
return
session.checkpoint(**kwargs)


Expand All @@ -457,4 +476,6 @@ def train_func():
trainer.shutdown()
"""
session = get_session()
if session is None:
return 1
return session.world_size
25 changes: 12 additions & 13 deletions python/ray/train/tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,30 @@ def test_init_fail(session):
init_session(lambda: 1, 0)


def test_get_fail(session):
def test_shutdown(session):
shutdown_session()
with pytest.raises(ValueError):
get_session()
assert not get_session


def test_world_rank(session):
assert world_rank() == 0
shutdown_session()
with pytest.raises(ValueError):
world_rank()
# Make sure default to 0.
assert world_rank() == 0


def test_local_rank(session):
assert local_rank() == 0
shutdown_session()
with pytest.raises(ValueError):
local_rank()
# Make sure default to 0.
assert local_rank() == 0


def test_world_size(session):
assert world_size() == 1
shutdown_session()
with pytest.raises(ValueError):
world_size()
# Make sure default to 1.
assert world_size() == 1


def test_train(session):
Expand Down Expand Up @@ -81,8 +80,8 @@ def train_func():
assert session.get_next().data["loss"] == 1
shutdown_session()

with pytest.raises(ValueError):
report(loss=2)
# Should not raise error outside of session.
report(loss=2)


def test_report_fail():
Expand Down Expand Up @@ -150,8 +149,8 @@ def validate_nonzero():
session.finish()
shutdown_session()

with pytest.raises(ValueError):
save_checkpoint(epoch=2)
# Should not raise error outside of session.
save_checkpoint(epoch=2)


def test_encode_data():
Expand Down

0 comments on commit 34d336a

Please sign in to comment.