diff --git a/ray-operator/config/samples/pytorch-mnist/ray-job.pytorch-mnist.yaml b/ray-operator/config/samples/pytorch-mnist/ray-job.pytorch-mnist.yaml index 3027aa8e6f..567dea311e 100644 --- a/ray-operator/config/samples/pytorch-mnist/ray-job.pytorch-mnist.yaml +++ b/ray-operator/config/samples/pytorch-mnist/ray-job.pytorch-mnist.yaml @@ -10,6 +10,9 @@ spec: - torch - torchvision working_dir: "https://github.com/ray-project/kuberay/archive/master.zip" + env_vars: + NUM_WORKERS: "4" + CPUS_PER_WORKER: "2" # rayClusterSpec specifies the RayCluster instance to be created by the RayJob controller. rayClusterSpec: diff --git a/ray-operator/config/samples/pytorch-mnist/ray_train_pytorch_mnist.py b/ray-operator/config/samples/pytorch-mnist/ray_train_pytorch_mnist.py index 158a7489c9..2d0844ccb8 100644 --- a/ray-operator/config/samples/pytorch-mnist/ray_train_pytorch_mnist.py +++ b/ray-operator/config/samples/pytorch-mnist/ray_train_pytorch_mnist.py @@ -129,7 +129,7 @@ def train_func_per_worker(config: Dict): ray.train.report(metrics={"loss": test_loss, "accuracy": accuracy}) -def train_fashion_mnist(num_workers=2, use_gpu=False): +def train_fashion_mnist(num_workers=4, cpus_per_worker=2, use_gpu=False): global_batch_size = 32 train_config = { @@ -142,7 +142,7 @@ def train_fashion_mnist(num_workers=2, use_gpu=False): scaling_config = ScalingConfig( num_workers=num_workers, use_gpu=use_gpu, - resources_per_worker={"CPU": 2} + resources_per_worker={"CPU": cpus_per_worker} ) # Initialize a Ray TorchTrainer @@ -160,4 +160,6 @@ def train_fashion_mnist(num_workers=2, use_gpu=False): if __name__ == "__main__": - train_fashion_mnist(num_workers=4) + num_workers = int(os.getenv("NUM_WORKERS", "4")) + cpus_per_worker = int(os.getenv("CPUS_PER_WORKER", "2")) + train_fashion_mnist(num_workers=num_workers, cpus_per_worker=cpus_per_worker)