[sgd] Distributed Training via PyTorch (#4797)
Implements distributed SGD using distributed PyTorch.
pschafhalter authored and richardliaw committed Jun 2, 2019
1 parent 88bab5d commit c2ade07
Showing 11 changed files with 751 additions and 23 deletions.
23 changes: 1 addition & 22 deletions ci/jenkins_tests/run_multi_node_tests.sh
@@ -31,25 +31,4 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=60G --memory=60G $DOCKER_SHA \
######################## SGD TESTS #################################

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/test_sgd.py --num-iters=2 \
--batch-size=1 --strategy=simple

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/test_sgd.py --num-iters=2 \
--batch-size=1 --strategy=ps

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/test_save_and_restore.py --num-iters=2 \
--batch-size=1 --strategy=simple

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/test_save_and_restore.py --num-iters=2 \
--batch-size=1 --strategy=ps

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/mnist_example.py --num-iters=1 \
--num-workers=1 --devices-per-worker=1 --strategy=ps

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
python /ray/python/ray/experimental/sgd/mnist_example.py --num-iters=1 \
--num-workers=1 --devices-per-worker=1 --strategy=ps --tune
python -m pytest /ray/python/ray/experimental/sgd/tests
4 changes: 4 additions & 0 deletions doc/source/conf.py
@@ -53,6 +53,10 @@
"tensorflow.python",
"tensorflow.python.client",
"tensorflow.python.util",
"torch",
"torch.distributed",
"torch.nn",
"torch.utils.data",
]
for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock()
48 changes: 48 additions & 0 deletions doc/source/distributed_training.rst
@@ -0,0 +1,48 @@
Distributed Training (Experimental)
===================================


Ray includes abstractions for distributed model training that integrate with
deep learning frameworks, such as PyTorch.

Ray Train is built on top of the Ray task and actor abstractions to provide
seamless integration into existing Ray applications.

PyTorch Interface
-----------------

To use Ray Train with PyTorch, pass model, data, and optimizer creator functions to the
``ray.experimental.sgd.pytorch.PyTorchTrainer`` class.
Training is then driven by calling ``trainer.train()`` repeatedly.

.. code-block:: python

    model_creator = lambda config: YourPyTorchModel()
    data_creator = lambda config: (YourTrainingSet(), YourValidationSet())

    trainer = PyTorchTrainer(
        model_creator,
        data_creator,
        optimizer_creator=utils.sgd_mse_optimizer,
        config={"lr": 1e-4},
        num_replicas=2,
        resources_per_replica=Resources(num_gpus=1),
        batch_size=16,
        backend="auto")

    for i in range(NUM_EPOCHS):
        trainer.train()

Under the hood, Ray Train creates *replicas* of your model (controlled by
``num_replicas``), each managed by a worker.
Each replica can manage multiple devices such as GPUs (controlled by ``resources_per_replica``),
which allows large models to be trained across multiple GPUs.
The ``PyTorchTrainer`` class coordinates the distributed computation and the training procedure across all replicas.
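
As a concrete illustration, the following is a minimal sketch of the three creator
functions the trainer expects. The network, datasets, and parameter names below are
hypothetical and are only meant to match the signatures documented above.

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.utils.data import TensorDataset

    def model_creator(config):
        # Builds a fresh torch.nn.Module; one copy is created per replica.
        return nn.Linear(10, 1)

    def data_creator(config):
        # Returns (training_set, validation_set) as torch Datasets.
        train = TensorDataset(torch.randn(256, 10), torch.randn(256, 1))
        validation = TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
        return train, validation

    def optimizer_creator(model, config):
        # Returns (criterion, optimizer) built from the model and the config.
        criterion = nn.MSELoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=config.get("lr", 1e-4))
        return criterion, optimizer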

The full documentation for ``PyTorchTrainer`` is as follows:

.. autoclass:: ray.experimental.sgd.pytorch.PyTorchTrainer
:members:

.. automethod:: __init__
3 changes: 2 additions & 1 deletion doc/source/index.rst
@@ -42,7 +42,7 @@ Ray comes with libraries that accelerate deep learning and reinforcement learning

- `Tune`_: Scalable Hyperparameter Search
- `RLlib`_: Scalable Reinforcement Learning
- `Distributed Training <distributed_sgd.html>`__
- `Distributed Training <distributed_training.html>`__

.. _`Tune`: tune.html
.. _`RLlib`: rllib.html
@@ -107,6 +107,7 @@
:maxdepth: 1
:caption: Other Libraries

distributed_training.rst
distributed_sgd.rst
pandas_on_ray.rst

8 changes: 8 additions & 0 deletions python/ray/experimental/sgd/pytorch/__init__.py
@@ -0,0 +1,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ray.experimental.sgd.pytorch.pytorch_trainer import PyTorchTrainer
from ray.experimental.sgd.pytorch.utils import Resources

__all__ = ["PyTorchTrainer", "Resources"]
182 changes: 182 additions & 0 deletions python/ray/experimental/sgd/pytorch/pytorch_runner.py
@@ -0,0 +1,182 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import os
import torch
import torch.distributed as dist
import torch.utils.data

import ray
from ray.experimental.sgd.pytorch import utils

logger = logging.getLogger(__name__)


class PyTorchRunner(object):
"""Manages a distributed PyTorch model replica"""

def __init__(self,
model_creator,
data_creator,
optimizer_creator,
config=None,
batch_size=16,
backend="gloo"):
"""Initializes the runner.
Args:
model_creator (dict -> torch.nn.Module): creates the model using
the config.
data_creator (dict -> Dataset, Dataset): creates the training and
validation data sets using the config.
optimizer_creator (torch.nn.Module, dict -> loss, optimizer):
creates the loss and optimizer using the model and the config.
config (dict): configuration passed to 'model_creator',
'data_creator', and 'optimizer_creator'.
batch_size (int): batch size used in an update.
backend (string): backend used by distributed PyTorch.
"""

self.model_creator = model_creator
self.data_creator = data_creator
self.optimizer_creator = optimizer_creator
self.config = {} if config is None else config
self.batch_size = batch_size
self.backend = backend
self.verbose = True

self.epoch = 0
self._timers = {
k: utils.TimerStat(window_size=1)
for k in [
"setup_proc", "setup_model", "get_state", "set_state",
"validation", "training"
]
}

def setup(self, url, world_rank, world_size):
"""Connects to the distributed PyTorch backend and initializes the model.
Args:
url (str): the URL used to connect to distributed PyTorch.
world_rank (int): the index of the runner.
world_size (int): the total number of runners.
"""
self._setup_distributed_pytorch(url, world_rank, world_size)
self._setup_training()

def _setup_distributed_pytorch(self, url, world_rank, world_size):
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
with self._timers["setup_proc"]:
self.world_rank = world_rank
logger.debug(
"Connecting to {} world_rank: {} world_size: {}".format(
url, world_rank, world_size))
logger.debug("using {}".format(self.backend))
dist.init_process_group(
backend=self.backend,
init_method=url,
rank=world_rank,
world_size=world_size)

def _setup_training(self):
logger.debug("Creating model")
self.model = self.model_creator(self.config)
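# Wrap the model so that gradients are synchronized across replicas
# during the backward pass (GPU and CPU variants, respectively).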
if torch.cuda.is_available():
self.model = torch.nn.parallel.DistributedDataParallel(
self.model.cuda())
else:
self.model = torch.nn.parallel.DistributedDataParallelCPU(
self.model)

logger.debug("Creating optimizer")
self.criterion, self.optimizer = self.optimizer_creator(
self.model, self.config)

if torch.cuda.is_available():
self.criterion = self.criterion.cuda()

logger.debug("Creating dataset")
self.training_set, self.validation_set = self.data_creator(self.config)

# TODO: make num_workers configurable
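# The DistributedSampler partitions the dataset across replicas so that
# each runner trains on a disjoint shard of the data every epoch.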
self.train_sampler = torch.utils.data.distributed.DistributedSampler(
self.training_set)
self.train_loader = torch.utils.data.DataLoader(
self.training_set,
batch_size=self.batch_size,
shuffle=(self.train_sampler is None),
num_workers=2,
pin_memory=False,
sampler=self.train_sampler)

self.validation_sampler = (
torch.utils.data.distributed.DistributedSampler(
self.validation_set))
self.validation_loader = torch.utils.data.DataLoader(
self.validation_set,
batch_size=self.batch_size,
shuffle=(self.validation_sampler is None),
num_workers=2,
pin_memory=False,
sampler=self.validation_sampler)

def get_node_ip(self):
"""Returns the IP address of the current node"""
return ray.services.get_node_ip_address()

def step(self):
"""Runs a training epoch and updates the model parameters"""
logger.debug("Starting step")
self.train_sampler.set_epoch(self.epoch)

logger.debug("Begin Training Epoch {}".format(self.epoch + 1))
with self._timers["training"]:
train_stats = utils.train(self.train_loader, self.model,
self.criterion, self.optimizer)
train_stats["epoch"] = self.epoch

self.epoch += 1

train_stats.update(self.stats())
return train_stats

def validate(self):
"""Evaluates the model on the validation data set"""
with self._timers["validation"]:
validation_stats = utils.validate(self.validation_loader,
self.model, self.criterion)

validation_stats.update(self.stats())
return validation_stats

def stats(self):
"""Returns a dictionary of statistics collected"""
stats = {"epoch": self.epoch}
for k, t in self._timers.items():
stats[k + "_time_mean"] = t.mean
stats[k + "_time_total"] = t.sum
t.reset()
return stats

def get_state(self):
"""Returns the state of the runner"""
return {
"epoch": self.epoch,
"model": self.model.state_dict(),
"optimizer": self.optimizer.state_dict(),
"stats": self.stats()
}

def set_state(self, state):
"""Sets the state of the model"""
# TODO: restore timer stats
self.model.load_state_dict(state["model"])
self.optimizer.load_state_dict(state["optimizer"])
self.epoch = state["stats"]["epoch"]

def shutdown(self):
"""Attempts to shut down the worker"""
dist.destroy_process_group()
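
The PyTorchTrainer that coordinates these runners is added elsewhere in this commit and is
not shown in this excerpt. Purely as a hedged illustration, a driver could coordinate the
runners roughly as follows; the RemoteRunner wrapper, the port choice, and the run_epochs
helper are assumptions for illustration, not the actual implementation.

import ray

# Hypothetical sketch: wrap PyTorchRunner in a Ray actor and drive one
# synchronous training step per epoch across all replicas.
RemoteRunner = ray.remote(PyTorchRunner)

def run_epochs(model_creator, data_creator, optimizer_creator,
               num_replicas=2, num_epochs=1):
    runners = [
        RemoteRunner.remote(model_creator, data_creator, optimizer_creator)
        for _ in range(num_replicas)
    ]
    # Build the rendezvous URL from the first runner's node address.
    # The port here is chosen arbitrarily for illustration.
    ip = ray.get(runners[0].get_node_ip.remote())
    url = "tcp://{}:{}".format(ip, 12345)
    # Every replica must call setup() so init_process_group can rendezvous.
    ray.get([
        runner.setup.remote(url, rank, num_replicas)
        for rank, runner in enumerate(runners)
    ])
    for _ in range(num_epochs):
        stats = ray.get([runner.step.remote() for runner in runners])
    return stats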