From e2348ba294b791c8a304e23a2b6c2fae28f689fe Mon Sep 17 00:00:00 2001 From: Amog Kamsetty Date: Mon, 25 Apr 2022 14:37:46 -0700 Subject: [PATCH] separate to own module --- python/ray/ml/train/data_parallel_trainer.py | 3 +- python/ray/train/backend.py | 58 +++++++++--- python/ray/train/session.py | 3 +- python/ray/train/trainer.py | 30 +++---- python/ray/train/utils.py | 92 -------------------- 5 files changed, 62 insertions(+), 124 deletions(-) diff --git a/python/ray/ml/train/data_parallel_trainer.py b/python/ray/ml/train/data_parallel_trainer.py index 496c68a2b217..f1521e14d30e 100644 --- a/python/ray/ml/train/data_parallel_trainer.py +++ b/python/ray/ml/train/data_parallel_trainer.py @@ -14,7 +14,8 @@ from ray.train import BackendConfig, TrainingIterator from ray.train.backend import BackendExecutor from ray.train.checkpoint import TuneCheckpointManager -from ray.train.utils import construct_train_func, _RayDatasetSpec +from ray.train.utils import construct_train_func +from ray.train.dataset_spec import _RayDatasetSpec from ray.util.annotations import DeveloperAPI logger = logging.getLogger(__name__) diff --git a/python/ray/train/backend.py b/python/ray/train/backend.py index af6af3816a88..01773bca2cd8 100644 --- a/python/ray/train/backend.py +++ b/python/ray/train/backend.py @@ -1,7 +1,7 @@ import logging import os from collections import defaultdict -from typing import Callable, TypeVar, List, Optional, Dict, Type, Tuple +from typing import Callable, TypeVar, List, Optional, Dict, Union, Type, Tuple import ray from ray.exceptions import RayActorError @@ -14,7 +14,8 @@ ) from ray.train.session import TrainingResult from ray.train.session import init_session, get_session, shutdown_session -from ray.train.utils import _RayDatasetSpec, check_for_failure, Singleton +from ray.train.utils import check_for_failure, Singleton +from ray.train.dataset_spec import RayDataset from ray.train.worker_group import WorkerGroup from ray.util.annotations import DeveloperAPI from ray.util.placement_group import get_current_placement_group, remove_placement_group @@ -314,10 +315,42 @@ def _create_local_rank_map(self) -> Dict: ip_dict[node_ip] += 1 return rank_mapping + def _get_dataset_shards(self, dataset_or_dict): + + if dataset_or_dict is None: + # Return None for each shard. + return [None] * len(self.worker_group) + + def split_dataset(dataset_or_pipeline): + actors = [worker.actor for worker in self.worker_group.workers] + return dataset_or_pipeline.split( + len(self.worker_group), equal=True, locality_hints=actors + ) + + if isinstance(dataset_or_dict, dict): + # Return a smaller dict for each shard. + dataset_shards = [{} for _ in range(len(self.worker_group))] + # TODO(amog): Update Backend to accept a generic function with logic on + # how to split dataset, instead of having to support _NO-SHARD in key. + for key, dataset in dataset_or_dict.items(): + if "_NO-SHARD" in key: + # Do not shard this dataset. + split_datasets = [dataset] * len(self.worker_group) + key = key.replace("_NO-SHARD", "") + else: + split_datasets = split_dataset(dataset) + assert len(split_datasets) == len(self.worker_group) + for i in range(len(split_datasets)): + dataset_shards[i][key] = split_datasets[i] + return dataset_shards + else: + # return a smaller RayDataset for each shard. 
+ return split_dataset(dataset_or_dict) + def start_training( self, train_func: Callable[[], T], - dataset_spec: _RayDatasetSpec = None, + dataset: Optional[Union[RayDataset, Dict[str, RayDataset]]] = None, checkpoint: Optional[Dict] = None, ) -> None: """Executes a training function on all workers in a separate thread. @@ -325,11 +358,17 @@ def start_training( ``finish_training`` should be called after this. Args: - train_func: The training function to run on each worker. - dataset_spec: A specification for the Ray Dataset to be - passed to the training workers, and the logic on how to shard the Ray - Dataset. - checkpoint: The checkpoint data that + train_func (Callable): The training function to run on each worker. + dataset (Optional[Union[Dataset, DatasetPipeline]]) + Distributed Ray Dataset or DatasetPipeline to pass into + worker, which can be accessed from the training function via + ``train.get_dataset_shard()``. Sharding will automatically be + handled by the Trainer. Multiple Datasets can be passed in as + a ``Dict`` that maps each name key to a Dataset value, + and each Dataset can be accessed from the training function + by passing in a `dataset_name` argument to + ``train.get_dataset_shard()``. + checkpoint (Optional[Dict]): The checkpoint data that should be loaded onto each worker and accessed by the training function via ``train.load_checkpoint()``. If this is ``None`` then no checkpoint will be loaded. @@ -368,8 +407,7 @@ def initialize_session( ) if self.dataset_shards is None: - actors = [worker.actor for worker in self.worker_group.workers] - self.dataset_shards = dataset_spec.get_dataset_shards(actors) + self.dataset_shards = self._get_dataset_shards(dataset) local_rank_map = self._create_local_rank_map() diff --git a/python/ray/train/session.py b/python/ray/train/session.py index 603aa5501cbb..5bb09e93db13 100644 --- a/python/ray/train/session.py +++ b/python/ray/train/session.py @@ -25,7 +25,8 @@ RESULT_FETCH_TIMEOUT, SESSION_MISUSE_LOG_ONCE_KEY, ) -from ray.train.utils import PropagatingThread, RayDataset +from ray.train.utils import PropagatingThread +from ray.train.dataset_spec import RayDataset from ray.util import PublicAPI, log_once diff --git a/python/ray/train/trainer.py b/python/ray/train/trainer.py index 86e07e47b7bb..740cc87fb242 100644 --- a/python/ray/train/trainer.py +++ b/python/ray/train/trainer.py @@ -15,12 +15,8 @@ ) from ray.train.callbacks.callback import TrainingCallback from ray.train.session import TrainingResultType -from ray.train.utils import ( - RayDataset, - construct_train_func, - ActorWrapper, - _RayDatasetSpec, -) +from ray.train.utils import construct_train_func, ActorWrapper +from ray.train.dataset_spec import RayDataset from ray.train.checkpoint import ( CheckpointStrategy, TuneCheckpointManager, @@ -325,14 +321,12 @@ def run( train_func = construct_train_func(train_func, config) - dataset_spec = _RayDatasetSpec(dataset_or_dict=dataset) - try: iterator = TrainingIterator( backend_executor=self._backend_executor, backend_config=self._backend_config, train_func=train_func, - dataset_spec=dataset_spec, + dataset=dataset, checkpoint_manager=self.checkpoint_manager, checkpoint=checkpoint, checkpoint_strategy=checkpoint_strategy, @@ -404,14 +398,12 @@ def train_func(config): train_func = construct_train_func(train_func, config) - dataset_spec = _RayDatasetSpec(dataset_or_dict=dataset) - return TrainingIterator( backend_executor=self._backend_executor, backend_config=self._backend_config, train_func=train_func, run_dir=self.latest_run_dir, 
- dataset_spec=dataset_spec, + dataset=dataset, checkpoint_manager=self.checkpoint_manager, checkpoint=checkpoint, checkpoint_strategy=checkpoint_strategy, @@ -643,7 +635,7 @@ def __init__( backend_executor: Union[BackendExecutor, ActorWrapper], backend_config: BackendConfig, train_func: Union[Callable[[], T], Callable[[Dict[str, Any]], T]], - dataset_spec: _RayDatasetSpec, + dataset: Optional[Union[RayDataset, Dict[str, RayDataset]]], checkpoint_manager: CheckpointManager, checkpoint: Optional[Union[Dict, str, Path]], checkpoint_strategy: Optional[CheckpointStrategy], @@ -652,14 +644,14 @@ def __init__( self._backend_executor = backend_executor self._backend = backend_config.backend_cls() self._train_func = train_func - self._dataset_spec = dataset_spec + self._dataset = dataset self._run_dir = run_dir self._checkpoint_manager = checkpoint_manager self._checkpoint_strategy = checkpoint_strategy self._start_training( train_func=train_func, run_dir=run_dir, - dataset_spec=self._dataset_spec, + dataset=dataset, checkpoint=checkpoint, checkpoint_strategy=checkpoint_strategy, ) @@ -674,7 +666,7 @@ def _start_training( self, train_func, run_dir, - dataset_spec, + dataset, checkpoint, checkpoint_strategy, latest_checkpoint_id=None, @@ -687,9 +679,7 @@ def _start_training( checkpoint_dict = self._checkpoint_manager._load_checkpoint(checkpoint) self._run_with_error_handling( lambda: self._backend_executor.start_training( - train_func=train_func, - dataset_config=dataset_spec, - checkpoint=checkpoint_dict, + train_func=train_func, dataset=dataset, checkpoint=checkpoint_dict ) ) @@ -708,7 +698,7 @@ def _run_with_error_handling(self, func: Callable): self._start_training( self._train_func, self._run_dir, - self._dataset_spec, + self._dataset, self._checkpoint_manager.latest_checkpoint, self._checkpoint_strategy, latest_checkpoint_id=self._checkpoint_manager.latest_checkpoint_id, diff --git a/python/ray/train/utils.py b/python/ray/train/utils.py index 051f854912cc..8bb12d650bcf 100644 --- a/python/ray/train/utils.py +++ b/python/ray/train/utils.py @@ -1,5 +1,4 @@ import abc -from dataclasses import dataclass import inspect import os import logging @@ -11,7 +10,6 @@ Dict, List, Any, - TYPE_CHECKING, Union, Callable, TypeVar, @@ -24,11 +22,6 @@ from ray.types import ObjectRef from ray.util.ml_utils.util import find_free_port -if TYPE_CHECKING: - from ray.data import Dataset - from ray.data.dataset_pipeline import DatasetPipeline - -RayDataset = Union["Dataset", "DatasetPipeline"] T = TypeVar("T") logger = logging.getLogger(__name__) @@ -173,88 +166,3 @@ def __getattr__(self, item): # actor. actor_method = getattr(self.actor, item) return lambda *args, **kwargs: ray.get(actor_method.remote(*args, **kwargs)) - - -@dataclass -class _RayDatasetSpec: - """Configuration for Ray Datasets to pass to the training workers. - - dataset_or_dict: An optional Ray Dataset (or DatasetPipeline) or a dictionary of - datasets to be sharded across all the training workers, which can be accessed - from the training function via ``train.get_dataset_shard()``. Multiple Datasets - can be passed in as a ``Dict`` that maps each name key to a Dataset value, - and each Dataset can be accessed from the training function by passing in a - `dataset_name` argument to ``train.get_dataset_shard()``. - dataset_split_fn: An optional callable to specify how the provided ``dataset`` - should be split across the training workers. It is expected to take in two - arguments. 
The first one is the ``dataset``, just as is passed in to the - ``_RayDatasetSpec``. The second argument is a list of the ActorHandles of the - training workers (to use as locality hints). The Callable is expected to - return a list of RayDatasets or a list of dictionaries of RayDatasets, - with the length of the list equal to the length of the list of actor handles. - If None is provided, the provided Ray Dataset(s) will be simply be split using - the actor handles as locality hints. - - """ - - dataset_or_dict: Optional[Union[RayDataset, Dict[str, RayDataset]]] - dataset_split_fn: Optional[ - Callable[ - [Union[RayDataset, Dict[str, RayDataset]], List[ActorHandle]], - List[Union[RayDataset, Dict[str, RayDataset]]], - ] - ] = None - - def _default_split_fn( - self, training_worker_handles: List[ActorHandle] - ) -> List[Optional[Union[RayDataset, Dict[str, RayDataset]]]]: - def split_dataset(dataset_or_pipeline): - return dataset_or_pipeline.split( - len(training_worker_handles), - equal=True, - locality_hints=training_worker_handles, - ) - - if isinstance(self.dataset_or_dict, dict): - # Return a smaller dict for each shard. - dataset_shards = [{} for _ in range(len(self.worker_group))] - for key, dataset in self.dataset_or_dict.items(): - split_datasets = split_dataset(dataset) - assert len(split_datasets) == len(self.worker_group) - for i in range(len(split_datasets)): - dataset_shards[i][key] = split_datasets[i] - return dataset_shards - else: - # return a smaller RayDataset for each shard. - return split_dataset(self.dataset_or_dict) - - def get_dataset_shards( - self, training_worker_handles: List[ActorHandle] - ) -> List[Optional[Union[RayDataset, Dict[str, RayDataset]]]]: - """Returns Dataset splits based off the spec and the given training workers - - Args: - training_worker_handles: A list of the training worker actor handles. - - Returns: - A list of RayDataset shards or list of dictionaries of RayDataset shards, - one for each training worker. - - """ - if self.dataset_or_dict is None: - # If no Dataset is provided, return None for each shard. - return [None] * len(training_worker_handles) - - if self.dataset_split_fn is None: - return self._default_split_fn(training_worker_handles) - else: - splits = self.dataset_split_fn( - self.dataset_or_dict, training_worker_handles - ) - if not len(splits) == len(training_worker_handles): - raise RuntimeError( - "The list of Datasets returned by the " - f"`dataset_split_fn`: {len(splits)} does not match " - f"the number of training workers: {len(training_worker_handles)}" - ) - return splits
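
A minimal sketch (an illustration, not part of the diff) of the sharding behavior that the new BackendExecutor._get_dataset_shards implements above: a single Ray Dataset, or each entry of a dict of Datasets, is split equally across the training workers, while dict entries whose key contains "_NO-SHARD" are replicated to every worker unsplit and the marker is stripped from the key. The worker count, the dataset names, and the "stats_NO-SHARD" key below are assumed for illustration, and the worker actors and locality hints used by the real code are omitted.

import ray

# Sketch only: mirrors the splitting rules of _get_dataset_shards using plain
# ray.data Datasets, with no WorkerGroup actors or locality hints.
ray.init(ignore_reinit_error=True)

num_workers = 2  # stands in for len(self.worker_group)
datasets = {
    "train": ray.data.range(8),
    "stats_NO-SHARD": ray.data.range(4),  # hypothetical key; replicated, not split
}

dataset_shards = [{} for _ in range(num_workers)]
for key, dataset in datasets.items():
    if "_NO-SHARD" in key:
        # Do not shard this dataset; every worker sees the full copy.
        split_datasets = [dataset] * num_workers
        key = key.replace("_NO-SHARD", "")
    else:
        # Default path: split the dataset equally across the workers.
        split_datasets = dataset.split(num_workers, equal=True)
    assert len(split_datasets) == num_workers
    for i, shard in enumerate(split_datasets):
        dataset_shards[i][key] = shard

for i, shard_dict in enumerate(dataset_shards):
    print(f"worker {i}:", {name: ds.count() for name, ds in shard_dict.items()})
# Expected: each worker gets 4 "train" rows plus the full 4-row "stats" dataset.

When a single Dataset rather than a dict is passed to start_training, only the else branch applies and the result is a flat list of per-worker shards, matching the final return split_dataset(dataset_or_dict) in the patched method.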