Update v6e-256 KubeRay Sample (#2466)
* Update v6e-256 sample

Signed-off-by: Ryan O'Leary <[email protected]>

* Remove multi-slice env vars and block on Jax call

Signed-off-by: Ryan O'Leary <[email protected]>

* Add multihost_utils import

Signed-off-by: Ryan O'Leary <[email protected]>

* Add print statement

Signed-off-by: Ryan O'Leary <[email protected]>

* Print TPU_WORKER_ID

Signed-off-by: Ryan O'Leary <[email protected]>

* Fix print

Signed-off-by: Ryan O'Leary <[email protected]>

* move barrier after jax device_count call

Signed-off-by: Ryan O'Leary <[email protected]>

* Add missing import

Signed-off-by: Ryan O'Leary <[email protected]>

* test using sleep for RayJob

Signed-off-by: Ryan O'Leary <[email protected]>

* Clean up tpu_list_devices script

Signed-off-by: Ryan O'Leary <[email protected]>

* Remove securityContext from spec

Signed-off-by: Ryan O'Leary <[email protected]>

* Remove privileged security context

Signed-off-by: ryanaoleary <[email protected]>

* Remove unneeded securityContexts

Signed-off-by: ryanaoleary <[email protected]>

* Add back in sync_global_devices

Signed-off-by: ryanaoleary <[email protected]>

---------

Signed-off-by: Ryan O'Leary <[email protected]>
Signed-off-by: ryanaoleary <[email protected]>
ryanaoleary authored Nov 7, 2024
1 parent 4fc1799 commit b9f0209
Showing 5 changed files with 11 additions and 21 deletions.
@@ -38,8 +38,6 @@ spec:
rayStartParams: {}
template:
  spec:
    securityContext:
      runAsUser: 0
    containers:
    - name: ray-worker
      image: rayproject/ray:2.37.0-py310
@@ -38,8 +38,6 @@ spec:
rayStartParams: {}
template:
  spec:
    securityContext:
      runAsUser: 0
    containers:
    - name: ray-worker
      image: rayproject/ray:2.37.0-py310
@@ -40,8 +40,6 @@ spec:
rayStartParams: {}
template:
  spec:
    securityContext:
      runAsUser: 0
    containers:
    - name: ray-worker
      image: rayproject/ray:2.37.0-py310
16 changes: 2 additions & 14 deletions ray-operator/config/samples/ray-job.tpu-v6e-256-multihost.yaml
@@ -40,11 +40,10 @@ spec:
maxReplicas: 1
numOfHosts: 64
groupName: tpu-group
rayStartParams: {}
rayStartParams:
  resources: '"{\"TPU\": 4}"'
template:
  spec:
    securityContext:
      runAsUser: 0
    containers:
    - name: ray-worker
      image: rayproject/ray:2.37.0-py310
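
The nested quoting in `resources: '"{\"TPU\": 4}"'` is easy to misread. The following is a minimal sketch of how the layers unwrap, assuming the value is spliced into a shell-parsed `ray start --resources=...` command line; that is an assumption about how the start command is assembled, not something this diff shows. `build_command` is a hypothetical helper used only for illustration.

```python
import json
import shlex

# Value of rayStartParams.resources once YAML has stripped the outer single quotes.
yaml_value = '"{\\"TPU\\": 4}"'


def build_command(resources: str) -> str:
    """Hypothetical helper: splice the parameter into a ray start command line."""
    return f"ray start --resources={resources}"


# POSIX shell-style parsing removes the double quotes and the backslash escapes.
args = shlex.split(build_command(yaml_value))
resources_json = args[-1].split("=", 1)[1]

# What remains is plain JSON mapping the custom resource name to a count.
resources = json.loads(resources_json)
print(resources)
```

Under that assumption each worker advertises a custom `TPU` resource of 4, which is the amount `@ray.remote(resources={"TPU": 4})` requests per task in `tpu_list_devices.py`.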
@@ -58,19 +57,8 @@ spec:
      google.com/tpu: "4"
      memory: 200G
  env:
  - name: NODE_IP
    valueFrom:
      fieldRef:
        fieldPath: status.hostIP
  - name: VBAR_CONTROL_SERVICE_URL
    value: $(NODE_IP):8353
  - name: JAX_PLATFORMS
    value: tpu,cpu
  - name: ENABLE_PJRT_COMPATIBILITY
    value: "true"
  ports:
  - containerPort: 8081
    name: mxla
nodeSelector:
  cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
  cloud.google.com/gke-tpu-topology: 16x16
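
The sample's `numOfHosts: 64` follows from the `16x16` topology and the `google.com/tpu: "4"` limit per host. A small sketch of that arithmetic is below; `hosts_for_topology` is a hypothetical helper name and is not part of KubeRay.

```python
from math import prod


def hosts_for_topology(topology: str, chips_per_host: int) -> int:
    """Hypothetical helper: number of hosts needed to cover a TPU topology."""
    # "16x16" -> [16, 16]; the product is the total chip count in the slice.
    chips = prod(int(dim) for dim in topology.split("x"))
    if chips % chips_per_host != 0:
        raise ValueError(
            f"{chips} chips cannot be split evenly across hosts of {chips_per_host}"
        )
    return chips // chips_per_host


# 16 * 16 = 256 chips, 4 chips per host.
print(hosts_for_topology("16x16", 4))
```

If the topology or the per-host chip count changes, `numOfHosts` has to change with it, otherwise the worker group will not match the slice.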
10 changes: 9 additions & 1 deletion ray-operator/config/samples/tpu/tpu_list_devices.py
@@ -1,12 +1,20 @@
import os
import ray
import jax
import time

from jax.experimental import multihost_utils

ray.init()

@ray.remote(resources={"TPU": 4})
def tpu_cores():
    return "TPU cores:" + str(jax.device_count())
    multihost_utils.sync_global_devices("sync")
    cores = "TPU cores:" + str(jax.device_count())
    print("TPU Worker: " + os.environ.get("TPU_WORKER_ID"))
    return cores

num_workers = int(ray.available_resources()["TPU"]) // 4
print(f"Number of TPU Workers: {num_workers}")
result = [tpu_cores.remote() for _ in range(num_workers)]
print(ray.get(result))
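
Two lines in the script rely on values being present: `ray.available_resources()["TPU"]` raises `KeyError` when no TPU resource is registered, and `"TPU Worker: " + os.environ.get("TPU_WORKER_ID")` raises `TypeError` when the variable is unset, because `get` then returns `None`. The sketch below shows one defensive variant in plain Python, without Ray or JAX; `num_tpu_workers` and `worker_label` are hypothetical names and are not part of this commit.

```python
import os

CHIPS_PER_HOST = 4  # matches resources={"TPU": 4} in the sample script


def num_tpu_workers(available: dict, chips_per_host: int = CHIPS_PER_HOST) -> int:
    """Hypothetical helper: hosts implied by a resource mapping such as the one Ray reports."""
    # A missing "TPU" key counts as zero instead of raising KeyError.
    return int(available.get("TPU", 0)) // chips_per_host


def worker_label(environ=os.environ) -> str:
    """Hypothetical helper: label a worker even when TPU_WORKER_ID is unset."""
    # The default keeps the string concatenation from failing on None.
    return "TPU Worker: " + environ.get("TPU_WORKER_ID", "unknown")


print(num_tpu_workers({"TPU": 256.0}))
print(worker_label({"TPU_WORKER_ID": "3"}))
```

Whether a silent fallback is preferable to failing fast depends on the use case; for a sample meant to verify the cluster, an explicit error may be the better signal.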
