Skip to content

Commit

Permalink
[Perf] Add NUM_WORKERS and CPUS_PER_WORKER env to the mnist workload (r…
Browse the repository at this point in the history
  • Loading branch information
rueian authored May 8, 2024
1 parent f38951f commit f77ee03
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ spec:
- torch
- torchvision
working_dir: "https://github.com/ray-project/kuberay/archive/master.zip"
+ env_vars:
+ NUM_WORKERS: "4"
+ CPUS_PER_WORKER: "2"
# rayClusterSpec specifies the RayCluster instance to be created by the RayJob controller.
rayClusterSpec:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def train_func_per_worker(config: Dict):
ray.train.report(metrics={"loss": test_loss, "accuracy": accuracy})


- def train_fashion_mnist(num_workers=2, use_gpu=False):
+ def train_fashion_mnist(num_workers=4, cpus_per_worker=2, use_gpu=False):
global_batch_size = 32

train_config = {
Expand All @@ -142,7 +142,7 @@ def train_fashion_mnist(num_workers=2, use_gpu=False):
scaling_config = ScalingConfig(
num_workers=num_workers,
use_gpu=use_gpu,
- resources_per_worker={"CPU": 2}
+ resources_per_worker={"CPU": cpus_per_worker}
)

# Initialize a Ray TorchTrainer
Expand All @@ -160,4 +160,6 @@ def train_fashion_mnist(num_workers=2, use_gpu=False):


if __name__ == "__main__":
- train_fashion_mnist(num_workers=4)
+ num_workers = int(os.getenv("NUM_WORKERS", "4"))
+ cpus_per_worker = int(os.getenv("CPUS_PER_WORKER", "2"))
+ train_fashion_mnist(num_workers=num_workers, cpus_per_worker=cpus_per_worker)

0 comments on commit f77ee03

Please sign in to comment.