diff --git a/ray-operator/config/samples/pytorch-mnist/ray-job.pytorch-mnist.yaml b/ray-operator/config/samples/pytorch-mnist/ray-job.pytorch-mnist.yaml
new file mode 100644
index 0000000000..3027aa8e6f
--- /dev/null
+++ b/ray-operator/config/samples/pytorch-mnist/ray-job.pytorch-mnist.yaml
@@ -0,0 +1,57 @@
+apiVersion: ray.io/v1
+kind: RayJob
+metadata:
+  name: rayjob-pytorch-mnist
+spec:
+  shutdownAfterJobFinishes: false
+  entrypoint: python ray-operator/config/samples/pytorch-mnist/ray_train_pytorch_mnist.py
+  runtimeEnvYAML: |
+    pip:
+      - torch
+      - torchvision
+    working_dir: "https://github.com/ray-project/kuberay/archive/master.zip"
+
+  # rayClusterSpec specifies the RayCluster instance to be created by the RayJob controller.
+  rayClusterSpec:
+    rayVersion: '2.9.0'
+    headGroupSpec:
+      rayStartParams: {}
+      # Pod template
+      template:
+        spec:
+          containers:
+            - name: ray-head
+              image: rayproject/ray:2.9.0
+              ports:
+                - containerPort: 6379
+                  name: gcs-server
+                - containerPort: 8265 # Ray dashboard
+                  name: dashboard
+                - containerPort: 10001
+                  name: client
+              resources:
+                limits:
+                  cpu: "2"
+                  memory: "4Gi"
+                requests:
+                  cpu: "2"
+                  memory: "4Gi"
+    workerGroupSpecs:
+      - replicas: 4
+        minReplicas: 1
+        maxReplicas: 5
+        groupName: small-group
+        rayStartParams: {}
+        # Pod template
+        template:
+          spec:
+            containers:
+              - name: ray-worker
+                image: rayproject/ray:2.9.0
+                resources:
+                  limits:
+                    cpu: "2"
+                    memory: "4Gi"
+                  requests:
+                    cpu: "2"
+                    memory: "4Gi"
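The `runtimeEnvYAML` block above maps directly onto the `runtime_env` dictionary used by the Ray Jobs API, and the `entrypoint` is resolved relative to the extracted `working_dir` archive. As a point of comparison only (not part of this sample), here is a minimal sketch of submitting the same entrypoint by hand with `ray.job_submission.JobSubmissionClient`, assuming the Ray dashboard (port 8265) has been exposed locally, for example with `kubectl port-forward` to the head Pod:

```python
from ray.job_submission import JobSubmissionClient

# Assumes the Ray dashboard is reachable at 127.0.0.1:8265,
# e.g. after a `kubectl port-forward` to the head Pod of the RayCluster.
client = JobSubmissionClient("http://127.0.0.1:8265")

job_id = client.submit_job(
    entrypoint="python ray-operator/config/samples/pytorch-mnist/ray_train_pytorch_mnist.py",
    runtime_env={
        "pip": ["torch", "torchvision"],
        "working_dir": "https://github.com/ray-project/kuberay/archive/master.zip",
    },
)
print(f"Submitted job: {job_id}")
```

The RayJob controller performs this submission automatically; the sketch only illustrates how the YAML fields correspond to the programmatic API.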
+""" +import os +from typing import Dict + +import torch +from filelock import FileLock +from torch import nn +from torch.utils.data import DataLoader +from torchvision import datasets, transforms +from torchvision.transforms import Normalize, ToTensor +from tqdm import tqdm + +import ray.train +from ray.train import ScalingConfig +from ray.train.torch import TorchTrainer + + +def get_dataloaders(batch_size): + # Transform to normalize the input images + transform = transforms.Compose([ToTensor(), Normalize((0.5,), (0.5,))]) + + with FileLock(os.path.expanduser("~/data.lock")): + # Download training data from open datasets + training_data = datasets.FashionMNIST( + root="~/data", + train=True, + download=True, + transform=transform, + ) + + # Download test data from open datasets + test_data = datasets.FashionMNIST( + root="~/data", + train=False, + download=True, + transform=transform, + ) + + # Create data loaders + train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True) + test_dataloader = DataLoader(test_data, batch_size=batch_size) + + return train_dataloader, test_dataloader + + +# Model Definition +class NeuralNetwork(nn.Module): + def __init__(self): + super(NeuralNetwork, self).__init__() + self.flatten = nn.Flatten() + self.linear_relu_stack = nn.Sequential( + nn.Linear(28 * 28, 512), + nn.ReLU(), + nn.Dropout(0.25), + nn.Linear(512, 512), + nn.ReLU(), + nn.Dropout(0.25), + nn.Linear(512, 10), + nn.ReLU(), + ) + + def forward(self, x): + x = self.flatten(x) + logits = self.linear_relu_stack(x) + return logits + + +def train_func_per_worker(config: Dict): + lr = config["lr"] + epochs = config["epochs"] + batch_size = config["batch_size_per_worker"] + + # Get dataloaders inside the worker training function + train_dataloader, test_dataloader = get_dataloaders(batch_size=batch_size) + + # [1] Prepare Dataloader for distributed training + # Shard the datasets among workers and move batches to the correct device + # ======================================================================= + train_dataloader = ray.train.torch.prepare_data_loader(train_dataloader) + test_dataloader = ray.train.torch.prepare_data_loader(test_dataloader) + + model = NeuralNetwork() + + # [2] Prepare and wrap your model with DistributedDataParallel + # Move the model to the correct GPU/CPU device + # ============================================================ + model = ray.train.torch.prepare_model(model) + + loss_fn = nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9) + + # Model training loop + for epoch in range(epochs): + if ray.train.get_context().get_world_size() > 1: + # Required for the distributed sampler to shuffle properly across epochs. 
+
+
+def train_func_per_worker(config: Dict):
+    lr = config["lr"]
+    epochs = config["epochs"]
+    batch_size = config["batch_size_per_worker"]
+
+    # Get dataloaders inside the worker training function
+    train_dataloader, test_dataloader = get_dataloaders(batch_size=batch_size)
+
+    # [1] Prepare Dataloader for distributed training
+    # Shard the datasets among workers and move batches to the correct device
+    # =======================================================================
+    train_dataloader = ray.train.torch.prepare_data_loader(train_dataloader)
+    test_dataloader = ray.train.torch.prepare_data_loader(test_dataloader)
+
+    model = NeuralNetwork()
+
+    # [2] Prepare and wrap your model with DistributedDataParallel
+    # Move the model to the correct GPU/CPU device
+    # ============================================================
+    model = ray.train.torch.prepare_model(model)
+
+    loss_fn = nn.CrossEntropyLoss()
+    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
+
+    # Model training loop
+    for epoch in range(epochs):
+        if ray.train.get_context().get_world_size() > 1:
+            # Required for the distributed sampler to shuffle properly across epochs.
+            train_dataloader.sampler.set_epoch(epoch)
+
+        model.train()
+        for X, y in tqdm(train_dataloader, desc=f"Train Epoch {epoch}"):
+            pred = model(X)
+            loss = loss_fn(pred, y)
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+        model.eval()
+        test_loss, num_correct, num_total = 0, 0, 0
+        with torch.no_grad():
+            for X, y in tqdm(test_dataloader, desc=f"Test Epoch {epoch}"):
+                pred = model(X)
+                loss = loss_fn(pred, y)
+
+                test_loss += loss.item()
+                num_total += y.shape[0]
+                num_correct += (pred.argmax(1) == y).sum().item()
+
+        test_loss /= len(test_dataloader)
+        accuracy = num_correct / num_total
+
+        # [3] Report metrics to Ray Train
+        # ===============================
+        ray.train.report(metrics={"loss": test_loss, "accuracy": accuracy})
+
+
+def train_fashion_mnist(num_workers=2, use_gpu=False):
+    global_batch_size = 32
+
+    train_config = {
+        "lr": 1e-3,
+        "epochs": 10,
+        "batch_size_per_worker": global_batch_size // num_workers,
+    }
+
+    # Configure computation resources
+    scaling_config = ScalingConfig(
+        num_workers=num_workers,
+        use_gpu=use_gpu,
+        resources_per_worker={"CPU": 2},
+    )
+
+    # Initialize a Ray TorchTrainer
+    trainer = TorchTrainer(
+        train_loop_per_worker=train_func_per_worker,
+        train_loop_config=train_config,
+        scaling_config=scaling_config,
+    )
+
+    # [4] Start distributed training
+    # Run `train_func_per_worker` on all workers
+    # =============================================
+    result = trainer.fit()
+    print(f"Training result: {result}")
+
+
+if __name__ == "__main__":
+    train_fashion_mnist(num_workers=4)
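With `num_workers=4` and `resources_per_worker={"CPU": 2}`, the run requests 8 CPUs in total, which matches the 4 worker Pods with 2 CPUs each defined in the RayJob manifest above. `trainer.fit()` returns a `ray.train.Result`, whose `metrics` field holds the values from the last `ray.train.report(...)` call in the training loop. As an illustrative extension (not part of the sample), the end of `train_fashion_mnist` could surface those numbers directly:

```python
# Hypothetical extension at the end of train_fashion_mnist();
# `result` is the ray.train.Result returned by trainer.fit() above.
final_metrics = result.metrics
print(f"Final test loss: {final_metrics['loss']:.4f}")
print(f"Final test accuracy: {final_metrics['accuracy']:.4f}")
```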