From 579041d3fed45a98fe25bc1fa8acc9b02fa16fdd Mon Sep 17 00:00:00 2001 From: Wei-Lin Chiang Date: Fri, 4 Aug 2023 00:18:00 -0700 Subject: [PATCH 1/3] fix tpu bug --- sky/backends/cloud_vm_ray_backend.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 690d9ab19c1..f976e517a73 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -2248,7 +2248,7 @@ def _update_stable_ssh_ports(self, max_attempts: int = 1) -> None: ports = [head_port] + worker_ports else: # Use port 22 for other clouds - ports = [22] * self.launched_nodes + ports = [22] * self.num_node_ips self.stable_ssh_ports = ports def _update_stable_cluster_ips(self, max_attempts: int = 1) -> None: @@ -2350,6 +2350,16 @@ def head_ssh_port(self): return external_ssh_ports[0] return None + @property + def num_node_ips(self): + is_tpu_vm_pod = tpu_utils.is_tpu_vm_pod(self.launched_resources) + if is_tpu_vm_pod: + num_ips = tpu_utils.get_num_tpu_devices( + self.launched_resources) + else: + num_ips = self.launched_nodes + return num_ips + def __setstate__(self, state): self._version = self._VERSION From 59bcc4ef3ce609edfb667029ba32dccfad7d3f56 Mon Sep 17 00:00:00 2001 From: Wei-Lin Chiang Date: Fri, 4 Aug 2023 00:25:08 -0700 Subject: [PATCH 2/3] format --- sky/backends/cloud_vm_ray_backend.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index f976e517a73..730af407562 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -2354,8 +2354,7 @@ def head_ssh_port(self): def num_node_ips(self): is_tpu_vm_pod = tpu_utils.is_tpu_vm_pod(self.launched_resources) if is_tpu_vm_pod: - num_ips = tpu_utils.get_num_tpu_devices( - self.launched_resources) + num_ips = tpu_utils.get_num_tpu_devices(self.launched_resources) else: num_ips = self.launched_nodes return num_ips From 4ef9fd9be7e68c02abd5d24040ee39447126760e Mon Sep 17 00:00:00 2001 From: Wei-Lin Chiang Date: Fri, 4 Aug 2023 15:38:38 -0700 Subject: [PATCH 3/3] update --- sky/backends/cloud_vm_ray_backend.py | 3 ++- sky/utils/tpu_utils.py | 5 ++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 730af407562..941a60ec299 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -2351,7 +2351,8 @@ def head_ssh_port(self): return None @property - def num_node_ips(self): + def num_node_ips(self) -> int: + """Returns number of IPs of the cluster, correctly handling TPU Pod.""" is_tpu_vm_pod = tpu_utils.is_tpu_vm_pod(self.launched_resources) if is_tpu_vm_pod: num_ips = tpu_utils.get_num_tpu_devices(self.launched_resources) diff --git a/sky/utils/tpu_utils.py b/sky/utils/tpu_utils.py index c811b938a0f..369ef94e6a6 100644 --- a/sky/utils/tpu_utils.py +++ b/sky/utils/tpu_utils.py @@ -30,10 +30,9 @@ def is_tpu_vm_pod(resources: Optional[resources_lib.Resources]) -> bool: return acc not in ['tpu-v2-8', 'tpu-v3-8', 'tpu-v4-8'] -def get_num_tpu_devices( - resources: Optional[resources_lib.Resources]) -> Optional[int]: +def get_num_tpu_devices(resources: Optional[resources_lib.Resources]) -> int: if resources is None or not is_tpu(resources): - return None + raise ValueError('resources must be a valid TPU resource.') acc, _ = list(resources.accelerators.items())[0] num_tpu_devices = int(int(acc.split('-')[2]) / 8) return num_tpu_devices