From b9f0209f8f540fb43262a48c2622d308e94436c3 Mon Sep 17 00:00:00 2001 From: ryanaoleary <113500783+ryanaoleary@users.noreply.github.com> Date: Thu, 7 Nov 2024 07:17:16 +0000 Subject: [PATCH] Update v6e-256 KubeRay Sample (#2466) * Update v6e-256 sample Signed-off-by: Ryan O'Leary * Remove multi-slice env vars and block on Jax call Signed-off-by: Ryan O'Leary * Add multihost_utils import Signed-off-by: Ryan O'Leary * Add print statement Signed-off-by: Ryan O'Leary * Print TPU_WORKER_ID Signed-off-by: Ryan O'Leary * Fix print Signed-off-by: Ryan O'Leary * move barrier after jax device_count call Signed-off-by: Ryan O'Leary * Add missing import Signed-off-by: Ryan O'Leary * test using sleep for RayJob Signed-off-by: Ryan O'Leary * Clean up tpu_list_devices script Signed-off-by: Ryan O'Leary * Remove securityContext from spec Signed-off-by: Ryan O'Leary * Remove privileged security context Signed-off-by: ryanaoleary * Remove unneeded securityContexts Signed-off-by: ryanaoleary * Add back in sync_global_devices Signed-off-by: ryanaoleary --------- Signed-off-by: Ryan O'Leary Signed-off-by: ryanaoleary --- .../ray-cluster.tpu-v6e-16-multihost.yaml | 2 -- .../ray-cluster.tpu-v6e-256-multihost.yaml | 2 -- .../samples/ray-cluster.tpu-v6e-singlehost.yaml | 2 -- .../samples/ray-job.tpu-v6e-256-multihost.yaml | 16 ++-------------- .../config/samples/tpu/tpu_list_devices.py | 10 +++++++++- 5 files changed, 11 insertions(+), 21 deletions(-) diff --git a/ray-operator/config/samples/ray-cluster.tpu-v6e-16-multihost.yaml b/ray-operator/config/samples/ray-cluster.tpu-v6e-16-multihost.yaml index 2b9db54754..485d85d1de 100644 --- a/ray-operator/config/samples/ray-cluster.tpu-v6e-16-multihost.yaml +++ b/ray-operator/config/samples/ray-cluster.tpu-v6e-16-multihost.yaml @@ -38,8 +38,6 @@ spec: rayStartParams: {} template: spec: - securityContext: - runAsUser: 0 containers: - name: ray-worker image: rayproject/ray:2.37.0-py310 diff --git a/ray-operator/config/samples/ray-cluster.tpu-v6e-256-multihost.yaml b/ray-operator/config/samples/ray-cluster.tpu-v6e-256-multihost.yaml index 8d4b029f20..c3b2876426 100644 --- a/ray-operator/config/samples/ray-cluster.tpu-v6e-256-multihost.yaml +++ b/ray-operator/config/samples/ray-cluster.tpu-v6e-256-multihost.yaml @@ -38,8 +38,6 @@ spec: rayStartParams: {} template: spec: - securityContext: - runAsUser: 0 containers: - name: ray-worker image: rayproject/ray:2.37.0-py310 diff --git a/ray-operator/config/samples/ray-cluster.tpu-v6e-singlehost.yaml b/ray-operator/config/samples/ray-cluster.tpu-v6e-singlehost.yaml index a8698429b6..1199a20d34 100644 --- a/ray-operator/config/samples/ray-cluster.tpu-v6e-singlehost.yaml +++ b/ray-operator/config/samples/ray-cluster.tpu-v6e-singlehost.yaml @@ -40,8 +40,6 @@ spec: rayStartParams: {} template: spec: - securityContext: - runAsUser: 0 containers: - name: ray-worker image: rayproject/ray:2.37.0-py310 diff --git a/ray-operator/config/samples/ray-job.tpu-v6e-256-multihost.yaml b/ray-operator/config/samples/ray-job.tpu-v6e-256-multihost.yaml index a1b349118f..09e45f735a 100644 --- a/ray-operator/config/samples/ray-job.tpu-v6e-256-multihost.yaml +++ b/ray-operator/config/samples/ray-job.tpu-v6e-256-multihost.yaml @@ -40,11 +40,10 @@ spec: maxReplicas: 1 numOfHosts: 64 groupName: tpu-group - rayStartParams: {} + rayStartParams: + resources: '"{\"TPU\": 4}"' template: spec: - securityContext: - runAsUser: 0 containers: - name: ray-worker image: rayproject/ray:2.37.0-py310 @@ -58,19 +57,8 @@ spec: google.com/tpu: "4" memory: 200G env: - - name: NODE_IP - valueFrom: - fieldRef: - fieldPath: status.hostIP - - name: VBAR_CONTROL_SERVICE_URL - value: $(NODE_IP):8353 - name: JAX_PLATFORMS value: tpu,cpu - - name: ENABLE_PJRT_COMPATIBILITY - value: "true" - ports: - - containerPort: 8081 - name: mxla nodeSelector: cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice cloud.google.com/gke-tpu-topology: 16x16 diff --git a/ray-operator/config/samples/tpu/tpu_list_devices.py b/ray-operator/config/samples/tpu/tpu_list_devices.py index 9f40f7674c..ea65f2df45 100644 --- a/ray-operator/config/samples/tpu/tpu_list_devices.py +++ b/ray-operator/config/samples/tpu/tpu_list_devices.py @@ -1,12 +1,20 @@ +import os import ray import jax +import time + +from jax.experimental import multihost_utils ray.init() @ray.remote(resources={"TPU": 4}) def tpu_cores(): - return "TPU cores:" + str(jax.device_count()) + multihost_utils.sync_global_devices("sync") + cores = "TPU cores:" + str(jax.device_count()) + print("TPU Worker: " + os.environ.get("TPU_WORKER_ID")) + return cores num_workers = int(ray.available_resources()["TPU"]) // 4 +print(f"Number of TPU Workers: {num_workers}") result = [tpu_cores.remote() for _ in range(num_workers)] print(ray.get(result))