[train] wrap BackendExecutor in ray.remote() #20123

Merged · 20 commits · Nov 13, 2021
7 changes: 5 additions & 2 deletions doc/examples/datasets_train/datasets_train.py
@@ -373,7 +373,10 @@ def train_func(config):

# Checkpoint model.
if is_distributed:
train.save_checkpoint(model_state_dict=net.module.state_dict())
import copy
model_copy = copy.deepcopy(net.module)
train.save_checkpoint(
model_state_dict=model_copy.cpu().state_dict())
else:
torch.save(net.state_dict(), f"models/model-epoch-{epoch}.torch")

@@ -386,7 +389,7 @@ def train_func(config):

if is_distributed:
if train.world_rank() == 0:
return net.module
return net.module.cpu()
else:
return None
else:
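Both hunks above move model weights off the training device before they leave the worker: the checkpoint now serializes a deep copy that has been moved to CPU, and the returned module is moved to CPU as well. A minimal sketch of the checkpointing half of that pattern, written for illustration rather than taken from the example itself:

import copy

import torch.nn as nn
from ray import train


def save_cpu_checkpoint(net: nn.Module) -> None:
    # Must run inside a Train session (i.e. from a train_func), as in the
    # example above. Unwrap DistributedDataParallel if present, then
    # checkpoint a CPU copy: the live model keeps its training device and
    # the saved state dict contains only CPU tensors.
    if isinstance(net, nn.parallel.DistributedDataParallel):
        net = net.module
    model_copy = copy.deepcopy(net)
    train.save_checkpoint(model_state_dict=model_copy.cpu().state_dict())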
8 changes: 8 additions & 0 deletions python/ray/train/BUILD
@@ -85,6 +85,14 @@ py_test(
deps = [":train_lib"]
)

py_test(
name = "test_examples",
size = "large",
srcs = ["tests/test_examples.py"],
tags = ["team:ml", "exclusive"],
deps = [":train_lib"]
)

py_test(
name = "test_gpu",
size = "large",
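The new `test_examples` target points at `tests/test_examples.py`, which is not shown in this diff. As a purely hypothetical illustration of the kind of smoke test such a target runs (using the public `Trainer` API of this Ray release; this is not the actual file's contents):

from ray.train import Trainer


def train_func():
    # Trivial training function; a real example test would exercise one of
    # the shipped examples instead.
    return 1


def test_example_runs():
    trainer = Trainer(backend="torch", num_workers=2)
    trainer.start()
    results = trainer.run(train_func)
    trainer.shutdown()
    assert results == [1, 1]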
160 changes: 24 additions & 136 deletions python/ray/train/backends/backend.py
@@ -8,25 +8,17 @@
import ray
from ray.exceptions import RayActorError
from ray.ray_constants import env_integer
from ray.train.checkpoint import CheckpointManager, CheckpointStrategy, \
TuneCheckpointManager
from ray.train.constants import ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, \
TUNE_INSTALLED, ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV, \
TRAIN_ENABLE_WORKER_SPREAD_ENV, \
TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV
from ray.train.session import TrainingResultType, TrainingResult
ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV, \
TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV, TRAIN_ENABLE_WORKER_SPREAD_ENV
from ray.train.session import TrainingResult
from ray.train.session import init_session, get_session, shutdown_session
from ray.train.utils import RayDataset
from ray.train.utils import check_for_failure
from ray.train.worker_group import WorkerGroup
from ray.util.placement_group import get_current_placement_group, \
remove_placement_group

if TUNE_INSTALLED:
from ray import tune
else:
tune = None

T = TypeVar("T")

logger = logging.getLogger(__name__)
@@ -67,15 +59,6 @@ class BackendExecutor:
and ``num_gpus_per_worker``.
max_retries (int): Number of retries when Ray actors fail.
Defaults to 3. Set to -1 for unlimited retries.

Attributes:
latest_checkpoint_dir (Optional[Path]): Path to the file directory for
the checkpoints from the latest run. Configured through
``start_training``
best_checkpoint_path (Optional[Path]): Path to the best persisted
checkpoint from the latest run.
latest_checkpoint (Optional[Dict]): The latest saved checkpoint. This
checkpoint may not be saved to disk.
"""

def __init__(
@@ -99,16 +82,9 @@ def __init__(
self._initialization_hook = None
self._placement_group = None

if tune is not None and tune.is_session_enabled():
self.checkpoint_manager = TuneCheckpointManager()
else:
self.checkpoint_manager = CheckpointManager()

self.worker_group = InactiveWorkerGroup()
self.dataset_shards = None

self.checkpoint_manager.on_init()

def start(self,
initialization_hook: Optional[Callable[[], None]] = None,
train_cls: Optional[Type] = None,
@@ -304,10 +280,7 @@ def start_training(
train_func: Callable[[], T],
run_dir: Path,
dataset: Optional[Union[RayDataset, Dict[str, RayDataset]]] = None,
checkpoint: Optional[Union[Dict, str, Path]] = None,
checkpoint_strategy: Optional[CheckpointStrategy] = None,
latest_checkpoint_id: Optional[int] = None,
) -> None:
checkpoint: Optional[Dict] = None) -> None:
"""Executes a training function on all workers in a separate thread.

``finish_training`` should be called after this.
@@ -324,22 +297,11 @@
and each Dataset can be accessed from the training function
by passing in a `dataset_name` argument to
``train.get_dataset_shard()``.
checkpoint (Optional[Dict|str|Path]): The checkpoint data that
checkpoint (Optional[Dict]): The checkpoint data that
should be loaded onto each worker and accessed by the
training function via ``train.load_checkpoint()``. If this is a
``str`` or ``Path`` then the value is expected to be a path
to a file that contains a serialized checkpoint dict. If this
training function via ``train.load_checkpoint()``. If this
is ``None`` then no checkpoint will be loaded.
checkpoint_strategy (Optional[CheckpointStrategy]): The
configurations for saving checkpoints.
latest_checkpoint_id (Optional[int]): The checkpoint id of the
most recently saved checkpoint.
"""
self.checkpoint_manager.on_start_training(
checkpoint_strategy=checkpoint_strategy,
run_dir=run_dir,
latest_checkpoint_id=latest_checkpoint_id)

use_detailed_autofilled_metrics = env_integer(
ENABLE_DETAILED_AUTOFILLED_METRICS_ENV, 0)

@@ -365,8 +327,6 @@ def initialize_session(train_func, world_rank, local_rank, checkpoint,
if self.dataset_shards is None:
self.dataset_shards = self._get_dataset_shards(dataset)

checkpoint_dict = self.checkpoint_manager._load_checkpoint(checkpoint)

local_rank_map = self._create_local_rank_map()

futures = []
@@ -379,7 +339,7 @@ def initialize_session(train_func, world_rank, local_rank, checkpoint,
local_rank=local_rank_map[index],
train_func=train_func,
dataset_shard=self.dataset_shards[index],
checkpoint=checkpoint_dict))
checkpoint=checkpoint))

self.get_with_failure_handling(futures)

@@ -390,7 +350,7 @@ def train_async():

self.worker_group.execute_async(train_async)

def _get_next_results(self) -> Optional[List[TrainingResult]]:
def get_next_results(self) -> Optional[List[TrainingResult]]:
"""Fetches the next ``TrainingResult`` from each worker.

Each ``TrainingResult`` is expected to correspond to the same step from
@@ -451,49 +411,8 @@ def get_next():
"each worker.")
return results

def fetch_next_result(self) -> Optional[List[Dict]]:
"""Fetch next results produced by ``train.report()`` from each worker.

Assumes ``start_training`` has already been called.

Returns:
A list of dictionaries of values passed to ``train.report()`` from
each worker. Each item corresponds to an intermediate result
a single worker. If there are no more items to fetch,
returns None.
"""

while True:
results = self._get_next_results()
if results is None:
return None
first_result = results[0]
result_type = first_result.type
if result_type is TrainingResultType.REPORT:
result_data = [r.data for r in results]
return result_data
elif result_type is TrainingResultType.CHECKPOINT:
self.checkpoint_manager._process_checkpoint(results)
# Iterate until next REPORT call or training has finished.
else:
raise TrainBackendError(f"Unexpected result type: "
f"{result_type}. "
f"Expected one of "
f"{[type in TrainingResultType]}")

def finish_training(self) -> List[T]:
"""Finish training and return final results. Propagate any exceptions.

Blocks until training is finished on all workers.

Assumes `start_training` has already been called.

Returns:
A list of return values from calling ``train_func`` on each worker.
Each item corresponds to the return value from a single worker.
"""

def pause_reporting():
def pause_reporting(self):
def pause_session_reporting():
# Get the session for this worker.
try:
session = get_session()
@@ -506,6 +425,14 @@ def pause_reporting():

return session.pause_reporting()

# Disable workers from enqueuing results from `train.report()`.
# Results will not be processed during the execution of `finish`.
# Note: Reported results may still be enqueued at this point,
# and should be handled appropriately.
futures = self.worker_group.execute_async(pause_session_reporting)
self.get_with_failure_handling(futures)

def finish_training(self):
def end_training():
# Get the session for this worker.
try:
@@ -527,23 +454,6 @@ def end_training():

return output

# Disable workers from enqueuing results from `train.report()`.
# Results will not be processed during the execution of `finish`.
# Note: Reported results may still be enqueued at this point,
# and should be handled appropriately.
futures = self.worker_group.execute_async(pause_reporting)
self.get_with_failure_handling(futures)

# Finish up processing checkpoints. Reporting has been disabled.
while True:
results = self._get_next_results()
if results is None:
break
result_type = results[0].type
# Process checkpoints and ignore other result types.
if result_type is TrainingResultType.CHECKPOINT:
self.checkpoint_manager._process_checkpoint(results)

futures = self.worker_group.execute_async(end_training)
results = self.get_with_failure_handling(futures)
return results
@@ -594,37 +504,9 @@ def shutdown(self):

self.dataset_shards = None

@property
def is_started(self):
return not isinstance(self.worker_group, InactiveWorkerGroup)

@property
def latest_checkpoint_dir(self) -> Optional[Path]:
"""Path to the latest checkpoint directory."""
return self.checkpoint_manager.latest_checkpoint_dir

@property
def best_checkpoint_path(self) -> Optional[Path]:
"""Path to the best persisted checkpoint."""
return self.checkpoint_manager.best_checkpoint_path

@property
def latest_checkpoint_id(self) -> Optional[int]:
"""The checkpoint id of most recently saved checkpoint.

If no checkpoint has been saved yet, then return None.
"""
checkpoint_id = self.checkpoint_manager._latest_checkpoint_id
if checkpoint_id == 0:
return None
else:
return checkpoint_id

@property
def latest_checkpoint(self) -> Optional[Dict]:
"""Latest checkpoint object."""
return self.checkpoint_manager.latest_checkpoint

def _restart(self):
self.worker_group.shutdown()
if self._initialization_hook is not None:
@@ -646,6 +528,12 @@ def _increment_failures(self):
"`max_retries` arg in your `Trainer`.") \
from None

def get_worker_group(self):
return self.worker_group

def _get_num_failures(self):
return self._num_failures


class Backend(metaclass=abc.ABCMeta):
"""Metaclass for distributed communication backend.
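The methods made public above (`get_next_results`, `pause_reporting`, `finish_training`, `get_worker_group`) are what enable the PR's headline change: the Trainer can hold the executor behind a Ray actor handle and do checkpoint handling on the driver side. A rough sketch of that driving loop; the real wiring lives in the Trainer, which is not part of this diff, and the `backend_config`, `train_func`, and `run_dir` arguments are placeholders:

import ray
from ray.train.backends.backend import BackendExecutor


def run_via_remote_executor(backend_config, train_func, run_dir, num_workers=2):
    # Turn BackendExecutor into an actor class and construct it remotely,
    # per the PR title "wrap BackendExecutor in ray.remote()".
    RemoteExecutor = ray.remote(num_cpus=0)(BackendExecutor)
    executor = RemoteExecutor.remote(backend_config, num_workers=num_workers)

    ray.get(executor.start.remote())
    ray.get(executor.start_training.remote(train_func, run_dir=run_dir))

    # Drain intermediate results; checkpoint processing now happens on the
    # caller's side rather than inside the executor.
    while True:
        results = ray.get(executor.get_next_results.remote())
        if results is None:
            break
        # ... hand results to a driver-side checkpoint manager / callbacks ...

    ray.get(executor.pause_reporting.remote())
    outputs = ray.get(executor.finish_training.remote())
    ray.get(executor.shutdown.remote())
    return outputs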
12 changes: 12 additions & 0 deletions python/ray/train/checkpoint.py
@@ -274,6 +274,18 @@ def best_checkpoint_path(self) -> Optional[Path]:
else:
return None

@property
def latest_checkpoint_id(self) -> Optional[int]:
"""The checkpoint id of most recently saved checkpoint.

If no checkpoint has been saved yet, then return None.
"""
checkpoint_id = self._latest_checkpoint_id
if checkpoint_id == 0:
return None
else:
return checkpoint_id


class TuneCheckpointManager(CheckpointManager):
def create_logdir(self, log_dir: Optional[Union[str, Path]]):
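The new property returns `None` until the first checkpoint has been written, instead of exposing the counter's starting value of 0. A tiny illustration of that contract from a caller's perspective; `manager` stands in for however the driver holds its `CheckpointManager`:

def has_saved_checkpoint(manager) -> bool:
    # latest_checkpoint_id is None until the first checkpoint is written,
    # so a None check is enough to detect a fresh run.
    return manager.latest_checkpoint_id is not None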
4 changes: 3 additions & 1 deletion python/ray/train/examples/train_linear_example.py
@@ -48,7 +48,9 @@ def validate_epoch(dataloader, model, loss_fn, device):
pred = model(X)
loss += loss_fn(pred, y).item()
loss /= num_batches
result = {"model": model.state_dict(), "loss": loss}
import copy
model_copy = copy.deepcopy(model)
result = {"model": model_copy.cpu().state_dict(), "loss": loss}
return result


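Same pattern as the documentation example above: the validation result now carries a CPU state dict. The deep copy is what keeps the live model on its training device while only the copy is moved off it, and it keeps GPU tensors out of results that may be serialized back to a process without a GPU. A minimal sketch of the helper this boils down to (the function name is illustrative, not from the example):

import copy

import torch.nn as nn


def cpu_state_dict(model: nn.Module) -> dict:
    # Copy first, then move only the copy: calling .cpu() on the live model
    # would relocate its parameters mid-training.
    model_copy = copy.deepcopy(model)
    return model_copy.cpu().state_dict()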