Format Python code with Black
bveeramani committed Jan 30, 2022
1 parent 34d336a · commit 8d45502
Showing 2 changed files with 85 additions and 58 deletions.
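
Note (added for context, not part of the commit): the changes below are mechanical; Black rewrites each file in place. A minimal sketch of reproducing one hunk with Black's Python API, assuming `pip install black` and default settings (88-character lines). The exact Black version used for this commit is not recorded, so output can vary slightly across versions.

    import black

    # Sketch only: the pre-commit form of the `init_session` error message,
    # copied from the removed lines in the diff below.
    src = (
        'raise ValueError("A Train session is already in use. Do not call "\n'
        '                 "`init_session()` manually.")\n'
    )

    # format_str applies the same rules as the `black` CLI; black.Mode() holds
    # the defaults. The result should correspond to the "+" lines of the
    # matching hunk below (Black keeps implicitly concatenated strings split).
    print(black.format_str(src, mode=black.Mode()))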
python/ray/train/session.py (60 additions, 39 deletions)
@@ -12,8 +12,17 @@

 import ray
 from ray.train.constants import (
-    DETAILED_AUTOFILLED_KEYS, TIME_THIS_ITER_S, PID, TIMESTAMP, TIME_TOTAL_S,
-    NODE_IP, TRAINING_ITERATION, HOSTNAME, DATE, RESULT_FETCH_TIMEOUT)
+    DETAILED_AUTOFILLED_KEYS,
+    TIME_THIS_ITER_S,
+    PID,
+    TIMESTAMP,
+    TIME_TOTAL_S,
+    NODE_IP,
+    TRAINING_ITERATION,
+    HOSTNAME,
+    DATE,
+    RESULT_FETCH_TIMEOUT,
+)
 from ray.train.utils import PropagatingThread, RayDataset
 from ray.util import PublicAPI

@@ -32,21 +41,22 @@ class TrainingResult:
 class Session:
     """Holds information for training on each worker."""

-    def __init__(self,
-                 training_func: Callable,
-                 world_rank: int,
-                 local_rank: int,
-                 world_size: int,
-                 dataset_shard: Optional[RayDataset] = None,
-                 checkpoint: Optional[Dict] = None,
-                 encode_data_fn: Callable = None,
-                 detailed_autofilled_metrics: bool = False):
+    def __init__(
+        self,
+        training_func: Callable,
+        world_rank: int,
+        local_rank: int,
+        world_size: int,
+        dataset_shard: Optional[RayDataset] = None,
+        checkpoint: Optional[Dict] = None,
+        encode_data_fn: Callable = None,
+        detailed_autofilled_metrics: bool = False,
+    ):

         self.dataset_shard = dataset_shard

         # The Thread object that is running the training function.
-        self.training_thread = PropagatingThread(
-            target=training_func, daemon=True)
+        self.training_thread = PropagatingThread(target=training_func, daemon=True)
         self.world_rank = world_rank
         self.local_rank = local_rank
         self.world_size = world_size
@@ -115,8 +125,7 @@ def get_next(self) -> Optional[TrainingResult]:
         # While training is still ongoing, attempt to get the result.
         while result is None and self.training_thread.is_alive():
             try:
-                result = self.result_queue.get(
-                    block=True, timeout=RESULT_FETCH_TIMEOUT)
+                result = self.result_queue.get(block=True, timeout=RESULT_FETCH_TIMEOUT)
             except queue.Empty:
                 pass

@@ -127,7 +136,8 @@
             # termination of the thread runner.
             try:
                 result = self.result_queue.get(
-                    block=False, timeout=RESULT_FETCH_TIMEOUT)
+                    block=False, timeout=RESULT_FETCH_TIMEOUT
+                )
             except queue.Empty:
                 pass

@@ -157,7 +167,7 @@ def _auto_fill_metrics(self, result: dict) -> dict:
             PID: os.getpid(),
             HOSTNAME: platform.node(),
             NODE_IP: self.local_ip,
-            TRAINING_ITERATION: self.iteration
+            TRAINING_ITERATION: self.iteration,
         }

         if not self.detailed_autofilled_metrics:
@@ -211,8 +221,7 @@ def checkpoint(self, **kwargs):
         if self.world_rank != 0:
             kwargs = {}
         else:
-            kwargs = self._encode_data_fn(
-                self._auto_fill_checkpoint_metrics(kwargs))
+            kwargs = self._encode_data_fn(self._auto_fill_checkpoint_metrics(kwargs))

         result = TrainingResult(TrainingResultType.CHECKPOINT, kwargs)
         # Add result to a thread-safe queue.
@@ -229,8 +238,10 @@ def checkpoint(self, **kwargs):
 def init_session(*args, **kwargs) -> None:
     global _session
     if _session:
-        raise ValueError("A Train session is already in use. Do not call "
-                         "`init_session()` manually.")
+        raise ValueError(
+            "A Train session is already in use. Do not call "
+            "`init_session()` manually."
+        )
     _session = Session(*args, **kwargs)


@@ -246,8 +257,7 @@ def shutdown_session():


 @PublicAPI(stability="beta")
-def get_dataset_shard(
-        dataset_name: Optional[str] = None) -> Optional[RayDataset]:
+def get_dataset_shard(dataset_name: Optional[str] = None) -> Optional[RayDataset]:
     """Returns the Ray Dataset or DatasetPipeline shard for this worker.

     You should call ``to_torch()`` or ``to_tf()`` on this shard to convert
@@ -285,22 +295,27 @@ def train_func():
     """
     session = get_session()
     if session is None:
-        warnings.warn("`train.get_dataset_shard()` is meant to only be called "
-                      "inside a training function that is executed by "
-                      "`Trainer.run`.")
+        warnings.warn(
+            "`train.get_dataset_shard()` is meant to only be called "
+            "inside a training function that is executed by "
+            "`Trainer.run`."
+        )
         return
     shard = session.dataset_shard
     if shard is None:
-        warnings.warn("No dataset passed in. Returning None. Make sure to "
-                      "pass in a Ray Dataset to Trainer.run to use this "
-                      "function.")
+        warnings.warn(
+            "No dataset passed in. Returning None. Make sure to "
+            "pass in a Ray Dataset to Trainer.run to use this "
+            "function."
+        )
     elif isinstance(shard, dict):
         if not dataset_name:
             raise RuntimeError(
                 "Multiple datasets were passed into ``Trainer``, "
                 "but no ``dataset_name`` is passed into "
                 "``get_dataset_shard``. Please specify which "
-                "dataset shard to retrieve.")
+                "dataset shard to retrieve."
+            )
         return shard[dataset_name]
     return shard

@@ -331,9 +346,11 @@ def train_func():
     """
     session = get_session()
     if session is None:
-        warnings.warn("`train.report()` is meant to only be called "
-                      "inside a training function that is executed by "
-                      "`Trainer.run`.")
+        warnings.warn(
+            "`train.report()` is meant to only be called "
+            "inside a training function that is executed by "
+            "`Trainer.run`."
+        )
         return
     session.report(**kwargs)

@@ -420,9 +437,11 @@ def train_func():
     """
     session = get_session()
     if session is None:
-        warnings.warn("`train.load_checkpoint()` is meant to only be called "
-                      "inside a training function that is executed by "
-                      "`Trainer.run`.")
+        warnings.warn(
+            "`train.load_checkpoint()` is meant to only be called "
+            "inside a training function that is executed by "
+            "`Trainer.run`."
+        )
         return
     return session.loaded_checkpoint

@@ -451,9 +470,11 @@ def train_func():
     """
     session = get_session()
     if session is None:
-        warnings.warn("`train.save_checkpoint()` is meant to only be called "
-                      "inside a training function that is executed by "
-                      "`Trainer.run`.")
+        warnings.warn(
+            "`train.save_checkpoint()` is meant to only be called "
+            "inside a training function that is executed by "
+            "`Trainer.run`."
+        )
         return
     session.checkpoint(**kwargs)

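Aside (added for context, not part of the commit): the public helpers reformatted above (`report`, `save_checkpoint`, `load_checkpoint`) are meant to be called inside a training function driven by `Trainer.run`, per the warnings in the diff. A minimal usage sketch, assuming the Ray Train API current at the time of this commit (roughly Ray 1.9/1.10, with `pip install "ray[train]"` and PyTorch installed); the backend choice, worker count, and metric values are illustrative:

    from ray import train
    from ray.train import Trainer


    def train_func():
        # load_checkpoint() returns None on a fresh run
        # (see session.loaded_checkpoint above).
        checkpoint = train.load_checkpoint() or {"epoch": 0}
        for epoch in range(checkpoint["epoch"], 2):
            train.report(loss=1.0 / (epoch + 1))  # routed through Session.report
            train.save_checkpoint(epoch=epoch + 1)  # routed through Session.checkpoint


    trainer = Trainer(backend="torch", num_workers=2)
    trainer.start()
    trainer.run(train_func)
    trainer.shutdown()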
python/ray/train/tests/test_session.py (25 additions, 19 deletions)
@@ -3,9 +3,19 @@
 import pytest

 import ray
-from ray.train.session import init_session, shutdown_session, \
-    get_session, world_rank, local_rank, report, save_checkpoint, \
-    TrainingResultType, load_checkpoint, get_dataset_shard, world_size
+from ray.train.session import (
+    init_session,
+    shutdown_session,
+    get_session,
+    world_rank,
+    local_rank,
+    report,
+    save_checkpoint,
+    TrainingResultType,
+    load_checkpoint,
+    get_dataset_shard,
+    world_size,
+)


 @pytest.fixture(scope="function")
@@ -62,7 +72,8 @@ def test_get_dataset_shard():
         world_rank=0,
         local_rank=0,
         world_size=1,
-        dataset_shard=dataset)
+        dataset_shard=dataset,
+    )
     assert get_dataset_shard() == dataset
     shutdown_session()

@@ -72,8 +83,7 @@ def train_func():
         for i in range(2):
             report(loss=i)

-    init_session(
-        training_func=train_func, world_rank=0, local_rank=0, world_size=1)
+    init_session(training_func=train_func, world_rank=0, local_rank=0, world_size=1)
     session = get_session()
     session.start()
     assert session.get_next().data["loss"] == 0
Expand All @@ -90,8 +100,7 @@ def train_func():
report(i)
return 1

init_session(
training_func=train_func, world_rank=0, local_rank=0, world_size=1)
init_session(training_func=train_func, world_rank=0, local_rank=0, world_size=1)
session = get_session()
session.start()
assert session.get_next() is None
@@ -125,8 +134,7 @@ def validate_zero(expected):
         assert next.type == TrainingResultType.CHECKPOINT
         assert next.data["epoch"] == expected

-    init_session(
-        training_func=train_func, world_rank=0, local_rank=0, world_size=1)
+    init_session(training_func=train_func, world_rank=0, local_rank=0, world_size=1)
     session = get_session()
     session.start()
     validate_zero(0)
@@ -140,8 +148,7 @@ def validate_nonzero():
         assert next.type == TrainingResultType.CHECKPOINT
         assert next.data == {}

-    init_session(
-        training_func=train_func, world_rank=1, local_rank=1, world_size=1)
+    init_session(training_func=train_func, world_rank=1, local_rank=1, world_size=1)
     session = get_session()
     session.start()
     validate_nonzero()
@@ -172,7 +179,8 @@ def validate_encoded(result_type: TrainingResultType):
         world_rank=0,
         local_rank=0,
         world_size=1,
-        encode_data_fn=encode_checkpoint)
+        encode_data_fn=encode_checkpoint,
+    )

     session = get_session()
     session.start()
@@ -191,8 +199,7 @@ def train_func():
             checkpoint = load_checkpoint()
             assert checkpoint["epoch"] == i

-    init_session(
-        training_func=train_func, world_rank=0, local_rank=0, world_size=1)
+    init_session(training_func=train_func, world_rank=0, local_rank=0, world_size=1)
     session = get_session()
     session.start()
     for i in range(2):
Expand All @@ -206,10 +213,10 @@ def test_locking():

def train_1():
import _thread

_thread.interrupt_main()

init_session(
training_func=train_1, world_rank=0, local_rank=0, world_size=1)
init_session(training_func=train_1, world_rank=0, local_rank=0, world_size=1)
session = get_session()
with pytest.raises(KeyboardInterrupt):
session.start()
@@ -220,8 +227,7 @@ def train_2():
             report(loss=i)
         train_1()

-    init_session(
-        training_func=train_2, world_rank=0, local_rank=0, world_size=1)
+    init_session(training_func=train_2, world_rank=0, local_rank=0, world_size=1)
     session = get_session()
     session.start()
     time.sleep(3)
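Note (added for context, not part of the commit): to keep files in this formatted state, a repository typically gates CI on `black --check`. A minimal sketch over the two files changed here, assuming the `black` executable is on PATH and the script runs from the repository root:

    import subprocess
    import sys

    # `black --check` makes no changes: it exits 0 if the files are already
    # formatted and 1 if reformatting would alter them.
    result = subprocess.run(
        [
            "black",
            "--check",
            "python/ray/train/session.py",
            "python/ray/train/tests/test_session.py",
        ]
    )
    sys.exit(result.returncode)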
