Skip to content

Commit

Permalink
[Train] Allow train methods to be called outside of the session (ra…
Browse files Browse the repository at this point in the history
…y-project#21969)

Updates to address @worldveil's feedback:

Include import train.torch in the docs
Allow methods in session.py to be called outside of the session with sensible defaults. These will no longer raise an error.

Co-authored-by: Balaji Veeramani <[email protected]>
  • Loading branch information
2 people authored and simonsays1980 committed Feb 27, 2022
1 parent 57e4548 commit f0050a0
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 28 deletions.
8 changes: 4 additions & 4 deletions python/ray/train/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,14 +637,14 @@ def __len__(self):


def _get_session(method_name: str):
try:
# Get the session for this worker.
return get_session()
except ValueError:
# Get the session for this worker.
session = get_session()
if not session:
# Session is not initialized yet.
raise TrainBackendError(
f"`{method_name}` has been called "
"before `start_training`. Please call "
"`start_training` before "
f"`{method_name}`."
)
return session
4 changes: 4 additions & 0 deletions python/ray/train/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,7 @@
# Integer value which if set will change the placement group strategy from
# PACK to SPREAD. 1 for True, 0 for False.
TRAIN_ENABLE_WORKER_SPREAD_ENV = "TRAIN_ENABLE_WORKER_SPREAD"

# The key used to identify whether we have already warned about ray.train
# functions being used outside of the session
SESSION_MISUSE_LOG_ONCE_KEY = "train_warn_session_misuse"
1 change: 1 addition & 0 deletions python/ray/train/examples/torch_quick_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def train_func():
# __torch_distributed_begin__

from ray import train
import ray.train.torch

def train_func_distributed():
num_epochs = 3
Expand Down
46 changes: 37 additions & 9 deletions python/ray/train/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
HOSTNAME,
DATE,
RESULT_FETCH_TIMEOUT,
SESSION_MISUSE_LOG_ONCE_KEY,
)
from ray.train.utils import PropagatingThread, RayDataset
from ray.util import PublicAPI
from ray.util import PublicAPI, log_once


class TrainingResultType(Enum):
Expand Down Expand Up @@ -235,6 +236,22 @@ def checkpoint(self, **kwargs):
_session = None


def _warn_session_misuse(fn_name: str):
"""Logs warning message on provided fn being used outside of session.
Args:
fn_name (str): The name of the function to warn about.
"""

if log_once(f"{SESSION_MISUSE_LOG_ONCE_KEY}-{fn_name}"):
warnings.warn(
f"`train.{fn_name}()` is meant to only be "
f"called "
"inside a training function that is executed by "
"`Trainer.run`. Returning None."
)


def init_session(*args, **kwargs) -> None:
global _session
if _session:
Expand All @@ -245,15 +262,8 @@ def init_session(*args, **kwargs) -> None:
_session = Session(*args, **kwargs)


def get_session() -> Session:
def get_session() -> Optional[Session]:
global _session
if _session is None or not isinstance(_session, Session):
raise ValueError(
"Trying to access a Train session that has not been "
"initialized yet. Train functions like "
"`train.report()` should only be called from inside "
"the training function."
)
return _session


Expand Down Expand Up @@ -301,6 +311,9 @@ def train_func():
If no dataset is passed into Trainer, then return None.
"""
session = get_session()
if session is None:
_warn_session_misuse(get_dataset_shard.__name__)
return
shard = session.dataset_shard
if shard is None:
warnings.warn(
Expand Down Expand Up @@ -345,6 +358,9 @@ def train_func():
intermediate results.
"""
session = get_session()
if session is None:
_warn_session_misuse(report.__name__)
return
session.report(**kwargs)


Expand All @@ -370,6 +386,8 @@ def train_func():
"""
session = get_session()
if session is None:
return 0
return session.world_rank


Expand All @@ -394,6 +412,8 @@ def train_func():
"""
session = get_session()
if session is None:
return 0
return session.local_rank


Expand Down Expand Up @@ -425,6 +445,9 @@ def train_func():
originally initialized with. ``None`` if neither exist.
"""
session = get_session()
if session is None:
_warn_session_misuse(load_checkpoint.__name__)
return
return session.loaded_checkpoint


Expand All @@ -451,6 +474,9 @@ def train_func():
**kwargs: Any key value pair to be checkpointed by Train.
"""
session = get_session()
if session is None:
_warn_session_misuse(save_checkpoint.__name__)
return
session.checkpoint(**kwargs)


Expand All @@ -472,4 +498,6 @@ def train_func():
trainer.shutdown()
"""
session = get_session()
if session is None:
return 1
return session.world_size
60 changes: 45 additions & 15 deletions python/ray/train/tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

import ray
from ray.train.constants import SESSION_MISUSE_LOG_ONCE_KEY
from ray.train.session import (
init_session,
shutdown_session,
Expand Down Expand Up @@ -33,31 +34,30 @@ def test_init_fail(session):
init_session(lambda: 1, 0)


def test_get_fail(session):
def test_shutdown(session):
shutdown_session()
with pytest.raises(ValueError):
get_session()
assert not get_session()


def test_world_rank(session):
assert world_rank() == 0
shutdown_session()
with pytest.raises(ValueError):
world_rank()
# Make sure default to 0.
assert world_rank() == 0


def test_local_rank(session):
assert local_rank() == 0
shutdown_session()
with pytest.raises(ValueError):
local_rank()
# Make sure default to 0.
assert local_rank() == 0


def test_world_size(session):
assert world_size() == 1
shutdown_session()
with pytest.raises(ValueError):
world_size()
# Make sure default to 1.
assert world_size() == 1


def test_train(session):
Expand Down Expand Up @@ -91,9 +91,6 @@ def train_func():
assert session.get_next().data["loss"] == 1
shutdown_session()

with pytest.raises(ValueError):
report(loss=2)


def test_report_fail():
def train_func():
Expand Down Expand Up @@ -157,9 +154,6 @@ def validate_nonzero():
session.finish()
shutdown_session()

with pytest.raises(ValueError):
save_checkpoint(epoch=2)


def test_encode_data():
def train_func():
Expand Down Expand Up @@ -242,6 +236,42 @@ def train_2():
shutdown_session()


def reset_log_once_with_str(str_to_append=None):
key = SESSION_MISUSE_LOG_ONCE_KEY
if str_to_append:
key += f"-{str_to_append}"
ray.util.debug.reset_log_once(key)


@pytest.mark.parametrize(
"fn", [load_checkpoint, save_checkpoint, report, get_dataset_shard]
)
def test_warn(fn):
"""Checks if calling train functions outside of session raises warning."""

with pytest.warns(UserWarning) as record:
fn()

assert fn.__name__ in record[0].message.args[0]

reset_log_once_with_str(fn.__name__)


def test_warn_once():
"""Checks if session misuse warning is only shown once per function."""

with pytest.warns(UserWarning) as record:
assert not load_checkpoint()
assert not load_checkpoint()
assert not save_checkpoint(x=2)
assert not report(x=2)
assert not report(x=3)
assert not get_dataset_shard()

# Should only warn once.
assert len(record) == 4


if __name__ == "__main__":
import pytest
import sys
Expand Down
6 changes: 6 additions & 0 deletions python/ray/util/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,9 @@ def enable_periodic_logging():

global _periodic_log
_periodic_log = True


def reset_log_once(key):
"""Resets log_once for the provided key."""

_logged.discard(key)

0 comments on commit f0050a0

Please sign in to comment.