diff --git a/.buildkite/pipeline.ml.yml b/.buildkite/pipeline.ml.yml
index 844737fdaecc..86e78889135b 100644
--- a/.buildkite/pipeline.ml.yml
+++ b/.buildkite/pipeline.ml.yml
@@ -2,7 +2,7 @@
   conditions: ["RAY_CI_ML_AFFECTED"]
   commands:
     - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT
-    - DATA_PROCESSING_TESTING=1 ./ci/travis/install-dependencies.sh
+    - DATA_PROCESSING_TESTING=1 INSTALL_HOROVOD=1 ./ci/travis/install-dependencies.sh
     - bazel test --config=ci $(./scripts/bazel_export_options) --build_tests_only --test_tag_filters=-gpu python/ray/ml/...
 
 - label: ":brain: RLlib: Learning discr. actions TF2-static-graph (from rllib/tuned_examples/*.yaml)"
diff --git a/python/ray/ml/BUILD b/python/ray/ml/BUILD
index f7b68f75d614..70040b30549d 100644
--- a/python/ray/ml/BUILD
+++ b/python/ray/ml/BUILD
@@ -129,6 +129,14 @@ py_test(
     deps = [":ml_lib"]
 )
 
+py_test(
+    name = "test_horovod_trainer",
+    size = "large",
+    srcs = ["tests/test_horovod_trainer.py"],
+    tags = ["team:ml", "exclusive"],
+    deps = [":ml_lib"]
+)
+
 py_test(
     name = "test_lightgbm_predictor",
     size = "small",
diff --git a/python/ray/ml/examples/horovod/__init__.py b/python/ray/ml/examples/horovod/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/python/ray/ml/examples/horovod/horovod_pytorch_example.py b/python/ray/ml/examples/horovod/horovod_pytorch_example.py
new file mode 100644
index 000000000000..9b20132a4fc5
--- /dev/null
+++ b/python/ray/ml/examples/horovod/horovod_pytorch_example.py
@@ -0,0 +1,267 @@
+import argparse
+from filelock import FileLock
+import horovod.torch as hvd
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torch.utils.data.distributed
+from torchvision import datasets, transforms
+
+import ray
+from ray import train
+from ray.ml.train.integrations.horovod import HorovodTrainer
+
+
+def metric_average(val, name):
+    tensor = torch.tensor(val)
+    avg_tensor = hvd.allreduce(tensor, name=name)
+    return avg_tensor.item()
+
+
+class Net(nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
+        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
+        self.conv2_drop = nn.Dropout2d()
+        self.fc1 = nn.Linear(320, 50)
+        self.fc2 = nn.Linear(50, 10)
+
+    def forward(self, x):
+        x = F.relu(F.max_pool2d(self.conv1(x), 2))
+        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
+        x = x.view(-1, 320)
+        x = F.relu(self.fc1(x))
+        x = F.dropout(x, training=self.training)
+        x = self.fc2(x)
+        return F.log_softmax(x, dim=1)
+
+
+def setup(config):
+    data_dir = config.get("data_dir", None)
+    seed = config.get("seed", 42)
+    batch_size = config.get("batch_size", 64)
+    use_adasum = config.get("use_adasum", False)
+    lr = config.get("lr", 0.01)
+    momentum = config.get("momentum", 0.5)
+    use_cuda = config.get("use_cuda", False)
+
+    # Horovod: initialize library.
+    hvd.init()
+    torch.manual_seed(seed)
+
+    if use_cuda:
+        # Horovod: pin GPU to local rank.
+        torch.cuda.set_device(hvd.local_rank())
+        torch.cuda.manual_seed(seed)
+
+    # Horovod: limit # of CPU threads to be used per worker.
+    torch.set_num_threads(1)
+
+    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
+    data_dir = data_dir or "~/data"
+    with FileLock(os.path.expanduser("~/.horovod_lock")):
+        train_dataset = datasets.MNIST(
+            data_dir,
+            train=True,
+            download=True,
+            transform=transforms.Compose(
+                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+            ),
+        )
+    # Horovod: use DistributedSampler to partition the training data.
+    train_sampler = torch.utils.data.distributed.DistributedSampler(
+        train_dataset, num_replicas=hvd.size(), rank=hvd.rank()
+    )
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=batch_size, sampler=train_sampler, **kwargs
+    )
+
+    model = Net()
+
+    # By default, Adasum doesn't need scaling up learning rate.
+    lr_scaler = hvd.size() if not use_adasum else 1
+
+    if use_cuda:
+        # Move model to GPU.
+        model.cuda()
+        # If using GPU Adasum allreduce, scale learning rate by local_size.
+        if use_adasum and hvd.nccl_built():
+            lr_scaler = hvd.local_size()
+
+    # Horovod: scale learning rate by lr_scaler.
+    optimizer = optim.SGD(model.parameters(), lr=lr * lr_scaler, momentum=momentum)
+
+    # Horovod: wrap optimizer with DistributedOptimizer.
+    optimizer = hvd.DistributedOptimizer(
+        optimizer,
+        named_parameters=model.named_parameters(),
+        op=hvd.Adasum if use_adasum else hvd.Average,
+    )
+
+    return model, optimizer, train_loader, train_sampler
+
+
+def train_epoch(
+    model, optimizer, train_sampler, train_loader, epoch, log_interval, use_cuda
+):
+    loss = None
+    model.train()
+    # Horovod: set epoch to sampler for shuffling.
+    train_sampler.set_epoch(epoch)
+    for batch_idx, (data, target) in enumerate(train_loader):
+        if use_cuda:
+            data, target = data.cuda(), target.cuda()
+        optimizer.zero_grad()
+        output = model(data)
+        loss = F.nll_loss(output, target)
+        loss.backward()
+        optimizer.step()
+        if batch_idx % log_interval == 0:
+            # Horovod: use train_sampler to determine the number of
+            # examples in this worker's partition.
+            print(
+                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
+                    epoch,
+                    batch_idx * len(data),
+                    len(train_sampler),
+                    100.0 * batch_idx / len(train_loader),
+                    loss.item(),
+                )
+            )
+    return loss.item() if loss is not None else None
+
+
+def train_func(config):
+    num_epochs = config.get("num_epochs", 10)
+    log_interval = config.get("log_interval", 10)
+    use_cuda = config.get("use_cuda", False)
+    save_model_as_dict = config.get("save_model_as_dict", False)
+
+    model, optimizer, train_loader, train_sampler = setup(config)
+
+    results = []
+    for epoch in range(num_epochs):
+        loss = train_epoch(
+            model, optimizer, train_sampler, train_loader, epoch, log_interval, use_cuda
+        )
+        results.append(loss)
+    if save_model_as_dict:
+        train.save_checkpoint(model=model.state_dict())
+    else:
+        train.save_checkpoint(model=model)
+    print("losses of each epoch:")
+    print(results)
+    return results
+
+
+def main(num_workers, use_gpu, kwargs):
+    trainer = HorovodTrainer(
+        train_loop_per_worker=train_func,
+        train_loop_config={
+            "num_epochs": kwargs["num_epochs"],
+            "log_interval": kwargs["log_interval"],
+            "use_cuda": kwargs["use_cuda"],
+        },
+        scaling_config={"num_workers": num_workers, "use_gpu": use_gpu},
+    )
+    result = trainer.fit()
+    print(result)
+
+
+if __name__ == "__main__":
+    # Training settings
+    parser = argparse.ArgumentParser(
+        description="PyTorch MNIST Example",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=64,
+        metavar="N",
+        help="input batch size for training (default: 64)",
+    )
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=5,
+        metavar="N",
+        help="number of epochs to train (default: 5)",
+    )
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=0.01,
+        metavar="LR",
+        help="learning rate (default: 0.01)",
+    )
+    parser.add_argument(
+        "--momentum",
+        type=float,
+        default=0.5,
+        metavar="M",
+        help="SGD momentum (default: 0.5)",
+    )
+    parser.add_argument(
+        "--use-gpu", action="store_true", default=False, help="enables CUDA training"
+    )
+    parser.add_argument(
+        "--seed", type=int, default=42, metavar="S", help="random seed (default: 42)"
+    )
+    parser.add_argument(
+        "--log-interval",
+        type=int,
+        default=10,
+        metavar="N",
+        help="how many batches to wait before logging training status",
+    )
+    parser.add_argument(
+        "--use-adasum",
+        action="store_true",
+        default=False,
+        help="use adasum algorithm to do reduction",
+    )
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=2,
+        help="Number of Ray workers to use for training.",
+    )
+    parser.add_argument(
+        "--data-dir",
+        help="location of the training dataset in the local filesystem ("
+        "will be downloaded if needed)",
+    )
+    parser.add_argument(
+        "--address",
+        required=False,
+        type=str,
+        default=None,
+        help="Address of Ray cluster.",
+    )
+
+    args = parser.parse_args()
+
+    if args.address:
+        ray.init(args.address)
+    else:
+        ray.init()
+
+    use_cuda = args.use_gpu
+
+    kwargs = {
+        "data_dir": args.data_dir,
+        "seed": args.seed,
+        "use_cuda": use_cuda,
+        "batch_size": args.batch_size,
+        "use_adasum": args.use_adasum,
+        "lr": args.lr,
+        "momentum": args.momentum,
+        "num_epochs": args.num_epochs,
+        "log_interval": args.log_interval,
+    }
+
+    main(num_workers=args.num_workers, use_gpu=use_cuda, kwargs=kwargs)
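
Stripped of argument parsing, the entry point added above reduces to the trainer call below. This is a minimal sketch, assuming a local Ray runtime; the import paths and config keys (``num_epochs``, ``log_interval``, ``use_cuda``) are the ones defined in this file, while the specific values are illustrative only.

.. code-block:: python

    # Minimal sketch: drive the example's train_func with the new HorovodTrainer.
    import ray
    from ray.ml.examples.horovod.horovod_pytorch_example import train_func
    from ray.ml.train.integrations.horovod import HorovodTrainer

    ray.init()
    trainer = HorovodTrainer(
        train_loop_per_worker=train_func,
        train_loop_config={"num_epochs": 1, "log_interval": 10, "use_cuda": False},
        scaling_config={"num_workers": 2},
    )
    # result.checkpoint holds the model saved via train.save_checkpoint().
    result = trainer.fit()
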
diff --git a/python/ray/ml/tests/test_horovod_trainer.py b/python/ray/ml/tests/test_horovod_trainer.py
new file mode 100644
index 000000000000..c0946ea28493
--- /dev/null
+++ b/python/ray/ml/tests/test_horovod_trainer.py
@@ -0,0 +1,109 @@
+import pytest
+import torch
+import torch.nn
+from torch.utils.data import DataLoader
+from torchvision import datasets
+from torchvision.transforms import transforms
+
+import ray
+from ray.ml.examples.horovod.horovod_pytorch_example import (
+    train_func as hvd_train_func,
+    Net,
+)
+from ray.ml.predictors.integrations.torch import TorchPredictor
+from ray.ml.train.integrations.horovod import HorovodTrainer
+
+
+@pytest.fixture
+def ray_start_4_cpus():
+    address_info = ray.init(num_cpus=4)
+    yield address_info
+    # The code after the yield will run as teardown code.
+    ray.shutdown()
+
+
+def run_image_prediction(model: torch.nn.Module, images: torch.Tensor) -> torch.Tensor:
+    model.eval()
+    with torch.no_grad():
+        return torch.exp(model(images)).argmax(dim=1)
+
+
+def test_horovod(ray_start_4_cpus):
+    def train_func(config):
+        result = hvd_train_func(config)
+        assert len(result) == epochs
+        assert result[-1] < result[0]
+
+    num_workers = 1
+    epochs = 10
+    scaling_config = {"num_workers": num_workers}
+    config = {"num_epochs": epochs, "save_model_as_dict": False}
+    trainer = HorovodTrainer(
+        train_loop_per_worker=train_func,
+        train_loop_config=config,
+        scaling_config=scaling_config,
+    )
+    result = trainer.fit()
+    predictor = TorchPredictor.from_checkpoint(result.checkpoint)
+
+    # Find some test data to run on.
+    test_set = datasets.MNIST(
+        "./data",
+        train=False,
+        download=True,
+        transform=transforms.Compose(
+            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+        ),
+    )
+
+    test_dataloader = DataLoader(test_set, batch_size=10)
+    test_dataloader_iter = iter(test_dataloader)
+    images, labels = next(
+        test_dataloader_iter
+    )  # only running a batch inference of 10 images
+    predicted_labels = run_image_prediction(predictor.model, images)
+    assert torch.equal(predicted_labels, labels)
+
+
+def test_horovod_state_dict(ray_start_4_cpus):
+    def train_func(config):
+        result = hvd_train_func(config)
+        assert len(result) == epochs
+        assert result[-1] < result[0]
+
+    num_workers = 2
+    epochs = 10
+    scaling_config = {"num_workers": num_workers}
+    config = {"num_epochs": epochs, "save_model_as_dict": True}
+    trainer = HorovodTrainer(
+        train_loop_per_worker=train_func,
+        train_loop_config=config,
+        scaling_config=scaling_config,
+    )
+    result = trainer.fit()
+    predictor = TorchPredictor.from_checkpoint(result.checkpoint, model=Net())
+
+    # Find some test data to run on.
+    test_set = datasets.MNIST(
+        "./data",
+        train=False,
+        download=True,
+        transform=transforms.Compose(
+            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+        ),
+    )
+
+    test_dataloader = DataLoader(test_set, batch_size=10)
+    test_dataloader_iter = iter(test_dataloader)
+    images, labels = next(
+        test_dataloader_iter
+    )  # only running a batch inference of 10 images
+    predicted_labels = run_image_prediction(predictor.model, images)
+    assert torch.equal(predicted_labels, labels)
+
+
+if __name__ == "__main__":
+    import pytest
+    import sys
+
+    sys.exit(pytest.main(["-v", "-x", __file__]))
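
Both tests follow the same restore path; the only difference is the checkpoint format selected via ``save_model_as_dict``. A sketch of the distinction, using only names from the test file above (``result`` stands for the ``trainer.fit()`` output):

.. code-block:: python

    # save_model_as_dict=False: the whole nn.Module is stored in the
    # checkpoint, so no model argument is needed at restore time.
    predictor = TorchPredictor.from_checkpoint(result.checkpoint)

    # save_model_as_dict=True: only a state_dict is stored, so an instantiated
    # module of the same architecture must be supplied to receive the weights.
    predictor = TorchPredictor.from_checkpoint(result.checkpoint, model=Net())
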
diff --git a/python/ray/ml/train/integrations/horovod/__init__.py b/python/ray/ml/train/integrations/horovod/__init__.py
new file mode 100644
index 000000000000..3df5710e7db5
--- /dev/null
+++ b/python/ray/ml/train/integrations/horovod/__init__.py
@@ -0,0 +1,3 @@
+from ray.ml.train.integrations.horovod.horovod_trainer import HorovodTrainer
+
+__all__ = ["HorovodTrainer"]
diff --git a/python/ray/ml/train/integrations/horovod/horovod_trainer.py b/python/ray/ml/train/integrations/horovod/horovod_trainer.py
new file mode 100644
index 000000000000..39f92bec072d
--- /dev/null
+++ b/python/ray/ml/train/integrations/horovod/horovod_trainer.py
@@ -0,0 +1,185 @@
+from typing import Dict, Callable, Optional, Union
+
+from ray.ml.config import ScalingConfig, RunConfig
+from ray.ml.trainer import GenDataset
+from ray.ml.preprocessor import Preprocessor
+from ray.ml.checkpoint import Checkpoint
+
+from ray.ml.train.data_parallel_trainer import DataParallelTrainer
+from ray.train.horovod import HorovodConfig
+
+
+class HorovodTrainer(DataParallelTrainer):
+    """A Trainer for data parallel Horovod training.
+
+    This Trainer runs the function ``train_loop_per_worker`` on multiple Ray
+    Actors. These actors already have the necessary Horovod setup configured
+    for distributed Horovod training.
+
+    The ``train_loop_per_worker`` function is expected to take in either 0 or 1
+    arguments:
+
+    .. code-block:: python
+
+        def train_loop_per_worker():
+            ...
+
+    .. code-block:: python
+
+        def train_loop_per_worker(config: Dict):
+            ...
+
+    If ``train_loop_per_worker`` accepts an argument, then
+    ``train_loop_config`` will be passed in as the argument. This is useful if you
+    want to tune the values in ``train_loop_config`` as hyperparameters.
+
+    If the ``datasets`` dict contains a training dataset (denoted by
+    the "train" key), then it will be split into multiple dataset
+    shards that can then be accessed by ``ray.train.get_dataset_shard("train")`` inside
+    ``train_loop_per_worker``. All the other datasets will not be split and
+    ``ray.train.get_dataset_shard(...)`` will return the entire Dataset.
+
+    Inside the ``train_loop_per_worker`` function, you can use any of the
+    :ref:`Ray Train function utils `.
+
+    .. code-block:: python
+
+        def train_loop_per_worker():
+            # Report intermediate results for callbacks or logging.
+            train.report(...)
+
+            # Checkpoints the provided args as restorable state.
+            train.save_checkpoint(...)
+
+            # Returns dict of last saved checkpoint.
+            train.load_checkpoint()
+
+            # Returns the Ray Dataset shard for the given key.
+            train.get_dataset_shard("my_dataset")
+
+            # Returns the total number of workers executing training.
+            train.get_world_size()
+
+            # Returns the rank of this worker.
+            train.get_world_rank()
+
+            # Returns the rank of the worker on the current node.
+            train.get_local_rank()
+
+    You can use ``TensorflowPredictor`` or ``TorchPredictor`` in conjunction with
+    HorovodTrainer. You must save the model under the "model" kwarg in
+    ``train.save_checkpoint()`` so that it can be used by the corresponding
+    predictor.
+
+    Example:
+
+    .. code-block:: python
+
+        import ray
+        import ray.train as train
+        import ray.train.torch  # Need this to use `train.torch.get_device()`.
+        import horovod.torch as hvd
+        import torch
+        import torch.nn as nn
+        from ray.ml.train.integrations.horovod import HorovodTrainer
+
+        input_size = 1
+        layer_size = 15
+        output_size = 1
+        num_epochs = 3
+
+        class NeuralNetwork(nn.Module):
+            def __init__(self):
+                super(NeuralNetwork, self).__init__()
+                self.layer1 = nn.Linear(input_size, layer_size)
+                self.relu = nn.ReLU()
+                self.layer2 = nn.Linear(layer_size, output_size)
+
+            def forward(self, input):
+                return self.layer2(self.relu(self.layer1(input)))
+
+        def train_loop_per_worker():
+            hvd.init()
+            dataset_shard = train.get_dataset_shard("train")
+            model = NeuralNetwork()
+            device = train.torch.get_device()
+            model.to(device)
+            loss_fn = nn.MSELoss()
+            lr_scaler = 1
+            optimizer = torch.optim.SGD(model.parameters(), lr=0.1 * lr_scaler)
+            # Horovod: wrap optimizer with DistributedOptimizer.
+            optimizer = hvd.DistributedOptimizer(
+                optimizer,
+                named_parameters=model.named_parameters(),
+                op=hvd.Average,
+            )
+            for epoch in range(num_epochs):
+                model.train()
+                for inputs, labels in iter(
+                    dataset_shard.to_torch(
+                        label_column="y",
+                        label_column_dtype=torch.float,
+                        feature_column_dtypes=torch.float,
+                        batch_size=32,
+                    )
+                ):
+                    inputs = inputs.to(device)
+                    labels = labels.to(device)
+                    outputs = model(inputs)
+                    loss = loss_fn(outputs, labels)
+                    optimizer.zero_grad()
+                    loss.backward()
+                    optimizer.step()
+                    print(f"epoch: {epoch}, loss: {loss.item()}")
+                train.save_checkpoint(model=model.state_dict())
+
+        train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
+        scaling_config = {"num_workers": 3}
+        # If using GPUs, use the below scaling config instead.
+        # scaling_config = {"num_workers": 3, "use_gpu": True}
+        trainer = HorovodTrainer(
+            train_loop_per_worker=train_loop_per_worker,
+            scaling_config=scaling_config,
+            datasets={"train": train_dataset},
+        )
+        result = trainer.fit()
+
+    Args:
+        train_loop_per_worker: The training function to execute.
+            This can either take in no arguments or a ``config`` dict.
+        train_loop_config: Configurations to pass into
+            ``train_loop_per_worker`` if it accepts an argument.
+        horovod_config: Configuration for setting up the Horovod backend.
+            If set to None, a default configuration is used. This replaces the
+            ``backend_config`` arg of ``DataParallelTrainer``.
+        scaling_config: Configuration for how to scale data parallel training.
+        run_config: Configuration for the execution of the training run.
+        datasets: Any Ray Datasets to use for training. Use
+            the key "train" to denote which dataset is the training
+            dataset. If a ``preprocessor`` is provided and has not already been fit,
+            it will be fit on the training dataset. All datasets will be transformed
+            by the ``preprocessor`` if one is provided.
+        preprocessor: A ray.ml.preprocessor.Preprocessor to preprocess the
+            provided datasets.
+        resume_from_checkpoint: A checkpoint to resume training from.
+    """
+
+    def __init__(
+        self,
+        train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
+        train_loop_config: Optional[Dict] = None,
+        horovod_config: Optional[HorovodConfig] = None,
+        scaling_config: Optional[ScalingConfig] = None,
+        run_config: Optional[RunConfig] = None,
+        datasets: Optional[Dict[str, GenDataset]] = None,
+        preprocessor: Optional[Preprocessor] = None,
+        resume_from_checkpoint: Optional[Checkpoint] = None,
+    ):
+        super().__init__(
+            train_loop_per_worker,
+            train_loop_config=train_loop_config,
+            backend_config=horovod_config or HorovodConfig(),
+            scaling_config=scaling_config,
+            run_config=run_config,
+            datasets=datasets,
+            preprocessor=preprocessor,
+            resume_from_checkpoint=resume_from_checkpoint,
+        )
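
``HorovodTrainer`` is a thin subclass of ``DataParallelTrainer`` that fixes the backend, so a non-default backend setup is passed as ``horovod_config`` rather than ``backend_config``. A minimal sketch, using only the default ``HorovodConfig()`` shown in this diff (the stub ``train_loop_per_worker`` is a placeholder):

.. code-block:: python

    from ray.train.horovod import HorovodConfig
    from ray.ml.train.integrations.horovod import HorovodTrainer

    def train_loop_per_worker():
        ...  # see the docstring example above

    # Equivalent to leaving horovod_config=None: the constructor substitutes
    # HorovodConfig() and forwards it as DataParallelTrainer's backend_config.
    trainer = HorovodTrainer(
        train_loop_per_worker=train_loop_per_worker,
        horovod_config=HorovodConfig(),
        scaling_config={"num_workers": 3},
    )
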
+ """ + + def __init__( + self, + train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]], + train_loop_config: Optional[Dict] = None, + horovod_config: Optional[HorovodConfig] = None, + scaling_config: Optional[ScalingConfig] = None, + run_config: Optional[RunConfig] = None, + datasets: Optional[Dict[str, GenDataset]] = None, + preprocessor: Optional[Preprocessor] = None, + resume_from_checkpoint: Optional[Checkpoint] = None, + ): + super().__init__( + train_loop_per_worker, + train_loop_config=train_loop_config, + backend_config=horovod_config or HorovodConfig(), + scaling_config=scaling_config, + run_config=run_config, + datasets=datasets, + preprocessor=preprocessor, + resume_from_checkpoint=resume_from_checkpoint, + ) diff --git a/python/ray/train/examples/horovod/horovod_example.py b/python/ray/train/examples/horovod/horovod_example.py index 75851ccdb964..cb578b1fb18f 100644 --- a/python/ray/train/examples/horovod/horovod_example.py +++ b/python/ray/train/examples/horovod/horovod_example.py @@ -140,9 +140,6 @@ def train_func(config): log_interval = config.get("log_interval", 10) use_cuda = config.get("use_cuda", False) - if use_cuda: - torch.cuda.set_device(hvd.local_rank()) - model, optimizer, train_loader, train_sampler = setup(config) results = []