diff --git a/python/ray/train/session.py b/python/ray/train/session.py
index 9b91d086a5e7..bac120fb78a4 100644
--- a/python/ray/train/session.py
+++ b/python/ray/train/session.py
@@ -12,8 +12,17 @@
 import ray
 
 from ray.train.constants import (
-    DETAILED_AUTOFILLED_KEYS, TIME_THIS_ITER_S, PID, TIMESTAMP, TIME_TOTAL_S,
-    NODE_IP, TRAINING_ITERATION, HOSTNAME, DATE, RESULT_FETCH_TIMEOUT)
+    DETAILED_AUTOFILLED_KEYS,
+    TIME_THIS_ITER_S,
+    PID,
+    TIMESTAMP,
+    TIME_TOTAL_S,
+    NODE_IP,
+    TRAINING_ITERATION,
+    HOSTNAME,
+    DATE,
+    RESULT_FETCH_TIMEOUT,
+)
 
 from ray.train.utils import PropagatingThread, RayDataset
 from ray.util import PublicAPI
@@ -32,21 +41,22 @@ class TrainingResult:
 class Session:
     """Holds information for training on each worker."""
 
-    def __init__(self,
-                 training_func: Callable,
-                 world_rank: int,
-                 local_rank: int,
-                 world_size: int,
-                 dataset_shard: Optional[RayDataset] = None,
-                 checkpoint: Optional[Dict] = None,
-                 encode_data_fn: Callable = None,
-                 detailed_autofilled_metrics: bool = False):
+    def __init__(
+        self,
+        training_func: Callable,
+        world_rank: int,
+        local_rank: int,
+        world_size: int,
+        dataset_shard: Optional[RayDataset] = None,
+        checkpoint: Optional[Dict] = None,
+        encode_data_fn: Callable = None,
+        detailed_autofilled_metrics: bool = False,
+    ):
         self.dataset_shard = dataset_shard
 
         # The Thread object that is running the training function.
-        self.training_thread = PropagatingThread(
-            target=training_func, daemon=True)
+        self.training_thread = PropagatingThread(target=training_func, daemon=True)
         self.world_rank = world_rank
         self.local_rank = local_rank
         self.world_size = world_size
@@ -115,8 +125,7 @@ def get_next(self) -> Optional[TrainingResult]:
         # While training is still ongoing, attempt to get the result.
         while result is None and self.training_thread.is_alive():
             try:
-                result = self.result_queue.get(
-                    block=True, timeout=RESULT_FETCH_TIMEOUT)
+                result = self.result_queue.get(block=True, timeout=RESULT_FETCH_TIMEOUT)
             except queue.Empty:
                 pass
 
@@ -127,7 +136,8 @@ def get_next(self) -> Optional[TrainingResult]:
         # termination of the thread runner.
         try:
             result = self.result_queue.get(
-                block=False, timeout=RESULT_FETCH_TIMEOUT)
+                block=False, timeout=RESULT_FETCH_TIMEOUT
+            )
         except queue.Empty:
             pass
 
@@ -157,7 +167,7 @@ def _auto_fill_metrics(self, result: dict) -> dict:
             PID: os.getpid(),
             HOSTNAME: platform.node(),
             NODE_IP: self.local_ip,
-            TRAINING_ITERATION: self.iteration
+            TRAINING_ITERATION: self.iteration,
         }
 
         if not self.detailed_autofilled_metrics:
@@ -211,8 +221,7 @@ def checkpoint(self, **kwargs):
         if self.world_rank != 0:
             kwargs = {}
         else:
-            kwargs = self._encode_data_fn(
-                self._auto_fill_checkpoint_metrics(kwargs))
+            kwargs = self._encode_data_fn(self._auto_fill_checkpoint_metrics(kwargs))
 
         result = TrainingResult(TrainingResultType.CHECKPOINT, kwargs)
         # Add result to a thread-safe queue.
@@ -229,8 +238,10 @@ def checkpoint(self, **kwargs):
 def init_session(*args, **kwargs) -> None:
     global _session
     if _session:
-        raise ValueError("A Train session is already in use. Do not call "
-                         "`init_session()` manually.")
+        raise ValueError(
+            "A Train session is already in use. Do not call "
+            "`init_session()` manually."
+        )
 
     _session = Session(*args, **kwargs)
 
@@ -246,8 +257,7 @@ def shutdown_session():
 
 @PublicAPI(stability="beta")
-def get_dataset_shard(
-        dataset_name: Optional[str] = None) -> Optional[RayDataset]:
+def get_dataset_shard(dataset_name: Optional[str] = None) -> Optional[RayDataset]:
     """Returns the Ray Dataset or DatasetPipeline shard for this worker.
 
     You should call ``to_torch()`` or ``to_tf()`` on this shard to convert
@@ -285,22 +295,27 @@ def train_func():
     """
     session = get_session()
     if session is None:
-        warnings.warn("`train.get_dataset_shard()` is meant to only be called "
-                      "inside a training function that is executed by "
-                      "`Trainer.run`.")
+        warnings.warn(
+            "`train.get_dataset_shard()` is meant to only be called "
+            "inside a training function that is executed by "
+            "`Trainer.run`."
+        )
         return
 
     shard = session.dataset_shard
     if shard is None:
-        warnings.warn("No dataset passed in. Returning None. Make sure to "
-                      "pass in a Ray Dataset to Trainer.run to use this "
-                      "function.")
+        warnings.warn(
+            "No dataset passed in. Returning None. Make sure to "
+            "pass in a Ray Dataset to Trainer.run to use this "
+            "function."
+        )
     elif isinstance(shard, dict):
         if not dataset_name:
             raise RuntimeError(
                 "Multiple datasets were passed into ``Trainer``, "
                 "but no ``dataset_name`` is passed into "
                 "``get_dataset_shard``. Please specify which "
-                "dataset shard to retrieve.")
+                "dataset shard to retrieve."
+            )
         return shard[dataset_name]
     return shard
 
@@ -331,9 +346,11 @@ def train_func():
     """
     session = get_session()
     if session is None:
-        warnings.warn("`train.report()` is meant to only be called "
-                      "inside a training function that is executed by "
-                      "`Trainer.run`.")
+        warnings.warn(
+            "`train.report()` is meant to only be called "
+            "inside a training function that is executed by "
+            "`Trainer.run`."
+        )
         return
 
     session.report(**kwargs)
@@ -420,9 +437,11 @@ def train_func():
     """
     session = get_session()
     if session is None:
-        warnings.warn("`train.load_checkpoint()` is meant to only be called "
-                      "inside a training function that is executed by "
-                      "`Trainer.run`.")
+        warnings.warn(
+            "`train.load_checkpoint()` is meant to only be called "
+            "inside a training function that is executed by "
+            "`Trainer.run`."
+        )
         return
 
     return session.loaded_checkpoint
@@ -451,9 +470,11 @@ def train_func():
     """
     session = get_session()
     if session is None:
-        warnings.warn("`train.save_checkpoint()` is meant to only be called "
-                      "inside a training function that is executed by "
-                      "`Trainer.run`.")
+        warnings.warn(
+            "`train.save_checkpoint()` is meant to only be called "
+            "inside a training function that is executed by "
+            "`Trainer.run`."
+        )
         return
 
     session.checkpoint(**kwargs)
diff --git a/python/ray/train/tests/test_session.py b/python/ray/train/tests/test_session.py
index 37b225d8b173..24902c08a7db 100644
--- a/python/ray/train/tests/test_session.py
+++ b/python/ray/train/tests/test_session.py
@@ -3,9 +3,19 @@
 import pytest
 
 import ray
-from ray.train.session import init_session, shutdown_session, \
-    get_session, world_rank, local_rank, report, save_checkpoint, \
-    TrainingResultType, load_checkpoint, get_dataset_shard, world_size
+from ray.train.session import (
+    init_session,
+    shutdown_session,
+    get_session,
+    world_rank,
+    local_rank,
+    report,
+    save_checkpoint,
+    TrainingResultType,
+    load_checkpoint,
+    get_dataset_shard,
+    world_size,
+)
 
 
 @pytest.fixture(scope="function")
@@ -62,7 +72,8 @@ def test_get_dataset_shard():
         world_rank=0,
         local_rank=0,
         world_size=1,
-        dataset_shard=dataset)
+        dataset_shard=dataset,
+    )
     assert get_dataset_shard() == dataset
     shutdown_session()
 
@@ -72,8 +83,7 @@ def train_func():
         for i in range(2):
             report(loss=i)
 
-    init_session(
-        training_func=train_func, world_rank=0, local_rank=0, world_size=1)
+    init_session(training_func=train_func, world_rank=0, local_rank=0, world_size=1)
     session = get_session()
     session.start()
     assert session.get_next().data["loss"] == 0
@@ -90,8 +100,7 @@ def train_func():
         for i in range(2):
             report(i)
         return 1
 
-    init_session(
-        training_func=train_func, world_rank=0, local_rank=0, world_size=1)
+    init_session(training_func=train_func, world_rank=0, local_rank=0, world_size=1)
     session = get_session()
     session.start()
     assert session.get_next() is None
@@ -125,8 +134,7 @@ def validate_zero(expected):
         assert next.type == TrainingResultType.CHECKPOINT
         assert next.data["epoch"] == expected
 
-    init_session(
-        training_func=train_func, world_rank=0, local_rank=0, world_size=1)
+    init_session(training_func=train_func, world_rank=0, local_rank=0, world_size=1)
     session = get_session()
     session.start()
     validate_zero(0)
@@ -140,8 +148,7 @@ def validate_nonzero():
         assert next.type == TrainingResultType.CHECKPOINT
         assert next.data == {}
 
-    init_session(
-        training_func=train_func, world_rank=1, local_rank=1, world_size=1)
+    init_session(training_func=train_func, world_rank=1, local_rank=1, world_size=1)
     session = get_session()
     session.start()
     validate_nonzero()
@@ -172,7 +179,8 @@ def validate_encoded(result_type: TrainingResultType):
         world_rank=0,
         local_rank=0,
         world_size=1,
-        encode_data_fn=encode_checkpoint)
+        encode_data_fn=encode_checkpoint,
+    )
 
     session = get_session()
     session.start()
@@ -191,8 +199,7 @@ def train_func():
             checkpoint = load_checkpoint()
             assert checkpoint["epoch"] == i
 
-    init_session(
-        training_func=train_func, world_rank=0, local_rank=0, world_size=1)
+    init_session(training_func=train_func, world_rank=0, local_rank=0, world_size=1)
     session = get_session()
     session.start()
     for i in range(2):
@@ -206,10 +213,10 @@ def test_locking():
 
     def train_1():
         import _thread
+
         _thread.interrupt_main()
 
-    init_session(
-        training_func=train_1, world_rank=0, local_rank=0, world_size=1)
+    init_session(training_func=train_1, world_rank=0, local_rank=0, world_size=1)
     session = get_session()
     with pytest.raises(KeyboardInterrupt):
         session.start()
@@ -220,8 +227,7 @@ def train_2():
             report(loss=i)
         train_1()
 
-    init_session(
-        training_func=train_2, world_rank=0, local_rank=0, world_size=1)
+    init_session(training_func=train_2, world_rank=0, local_rank=0, world_size=1)
     session = get_session()
     session.start()
     time.sleep(3)
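
For reviewers who want to see the reformatted API in motion, here is a minimal, hypothetical driver-side sketch (not part of the diff) that mirrors the report round trip exercised by the tests above. It assumes the patched `ray.train.session` module is importable and that `report()` produces results of type `TrainingResultType.REPORT`; treat it as an illustration, not an official usage example.

```python
from ray.train.session import (
    TrainingResultType,
    get_session,
    init_session,
    report,
    shutdown_session,
)


def train_func():
    # Runs on the worker thread; each report() enqueues one result for
    # the driver side to drain via Session.get_next().
    for epoch in range(2):
        report(loss=1.0 / (epoch + 1))


init_session(training_func=train_func, world_rank=0, local_rank=0, world_size=1)
session = get_session()
session.start()

# get_next() waits in RESULT_FETCH_TIMEOUT increments while the training
# thread is alive, and returns None once it exits and the queue is drained.
result = session.get_next()
while result is not None:
    assert result.type == TrainingResultType.REPORT  # assumed result type
    print(result.data["loss"])  # auto-filled keys (PID, HOSTNAME, ...) ride along
    result = session.get_next()

shutdown_session()
```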
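
The checkpoint path touched in the `session.py` hunks (rank-0 gating plus `_encode_data_fn`) can be exercised the same way. This sketch mirrors the `load_checkpoint` test above; the premise that `save_checkpoint()` pauses the worker thread until the driver consumes the CHECKPOINT result is inferred from `test_locking` and should be read as an assumption.

```python
from ray.train.session import (
    TrainingResultType,
    get_session,
    init_session,
    load_checkpoint,
    save_checkpoint,
    shutdown_session,
)


def train_func():
    # save_checkpoint() hands the data to the driver; once the driver has
    # consumed it, load_checkpoint() reflects the most recently saved state.
    save_checkpoint(epoch=0)
    assert load_checkpoint()["epoch"] == 0


init_session(training_func=train_func, world_rank=0, local_rank=0, world_size=1)
session = get_session()
session.start()

result = session.get_next()
assert result.type == TrainingResultType.CHECKPOINT
# On rank 0 the payload carries the saved keys plus auto-filled metadata.
print(result.data["epoch"])

shutdown_session()
```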