Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TPUs fixes #1360

Merged
merged 3 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 72 additions & 19 deletions src/dstack/_internal/core/backends/gcp/compute.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import concurrent.futures
import json
from collections import defaultdict
from typing import Callable, Dict, List, Optional
Expand All @@ -21,6 +22,7 @@
ComputeError,
ComputeResourceNotFoundError,
NoCapacityError,
ProvisioningError,
)
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.gateways import GatewayComputeConfiguration
Expand Down Expand Up @@ -164,32 +166,32 @@ def create_instance(
spot=instance_offer.instance.resources.spot,
labels=labels,
subnetwork=subnetwork,
allocate_public_ip=allocate_public_ip,
)
create_node_request = tpu_v2.CreateNodeRequest(
parent=f"projects/{self.config.project_id}/locations/{zone}",
node_id=instance_id,
node=tpu_node,
)
try:
# TPUs may get created and then deleted immediately in case of no capacity.
# We call wait_for_operation() only to get the capacity error and try another option.
# If the request succeeds, we'll probably time out and update_provisioning_data() will get the hostname.
operation = self.tpu_client.create_node(request=create_node_request)
gcp_resources.wait_for_operation(
operation, verbose_name="tpu instance creation"
)
gcp_resources.wait_for_operation(operation, timeout=5)
except (
google.api_core.exceptions.ServiceUnavailable,
google.api_core.exceptions.NotFound,
google.api_core.exceptions.ResourceExhausted,
):
continue
node_request = tpu_v2.GetNodeRequest(
name=f"projects/dstack/locations/{zone}/nodes/{instance_id}",
)
instance = self.tpu_client.get_node(request=node_request)
except concurrent.futures.TimeoutError:
pass
return JobProvisioningData(
backend=instance_offer.backend,
instance_type=instance_offer.instance,
instance_id=instance_id,
hostname=instance.network_endpoints[0].access_config.external_ip,
hostname=None,
internal_ip=None,
region=zone,
price=instance_offer.price,
Expand Down Expand Up @@ -240,27 +242,19 @@ def create_instance(
allocate_public_ip=allocate_public_ip,
)
try:
operation = self.instances_client.insert(request=request)
gcp_resources.wait_for_extended_operation(operation, "instance creation")
self.instances_client.insert(request=request)
except (
google.api_core.exceptions.ServiceUnavailable,
google.api_core.exceptions.NotFound,
):
continue
instance = self.instances_client.get(
project=self.config.project_id, zone=zone, instance=instance_name
)
if allocate_public_ip:
hostname = instance.network_interfaces[0].access_configs[0].nat_i_p
else:
hostname = instance.network_interfaces[0].network_i_p
return JobProvisioningData(
backend=instance_offer.backend,
instance_type=instance_offer.instance,
instance_id=instance_name,
public_ip_enabled=allocate_public_ip,
hostname=hostname,
internal_ip=instance.network_interfaces[0].network_i_p,
hostname=None,
internal_ip=None,
region=instance_offer.region,
price=instance_offer.price,
username="ubuntu",
Expand All @@ -271,6 +265,65 @@ def create_instance(
)
raise NoCapacityError()

def update_provisioning_data(
    self,
    provisioning_data: JobProvisioningData,
    project_ssh_public_key: str,
    project_ssh_private_key: str,
):
    """Fill in ``hostname`` and ``internal_ip`` once the instance is up.

    ``create_instance()`` returns provisioning data with ``hostname=None``;
    this method looks the instance up again and, when it is ready, mutates
    ``provisioning_data`` in place.

    Args:
        provisioning_data: Data returned by ``create_instance()``. Its
            ``backend_data`` (a JSON string, if set) supplies ``zone`` and
            the optional ``is_tpu`` flag; otherwise ``region`` is used as
            the zone and the instance is treated as a regular VM.
        project_ssh_public_key: Not used by this implementation.
        project_ssh_private_key: Not used by this implementation.

    Returns:
        None. Returns without touching ``provisioning_data`` while the
        instance is still being created, so the caller can retry later.

    Raises:
        ProvisioningError: If the instance no longer exists or is in a
            state from which no IP address can be obtained.
    """
    allocate_public_ip = self.config.allocate_public_ips
    zone = provisioning_data.region
    is_tpu = False
    if provisioning_data.backend_data is not None:
        backend_data_dict = json.loads(provisioning_data.backend_data)
        zone = backend_data_dict["zone"]
        is_tpu = backend_data_dict.get("is_tpu", False)

    if is_tpu:
        # Use the configured project rather than a hard-coded one so the
        # lookup targets the same project the node was created in.
        node_request = tpu_v2.GetNodeRequest(
            name=(
                f"projects/{self.config.project_id}/locations/{zone}"
                f"/nodes/{provisioning_data.instance_id}"
            ),
        )
        try:
            instance = self.tpu_client.get_node(request=node_request)
        except google.api_core.exceptions.NotFound as e:
            raise ProvisioningError(
                "Failed to get instance IP address. Instance not found."
            ) from e

        # See states https://cloud.google.com/python/docs/reference/tpu/latest/google.cloud.tpu_v2.types.Node.State
        node_states = tpu_v2.Node.State
        if instance.state in (node_states.STATE_UNSPECIFIED, node_states.CREATING):
            # Not ready yet; leave provisioning_data untouched.
            return
        if instance.state == node_states.READY:
            endpoint = instance.network_endpoints[0]
            if allocate_public_ip:
                hostname = endpoint.access_config.external_ip
            else:
                hostname = endpoint.ip_address
            provisioning_data.hostname = hostname
            provisioning_data.internal_ip = endpoint.ip_address
            return
        raise ProvisioningError(
            f"Failed to get instance IP address. Instance state: {instance.state}"
        )

    try:
        instance = self.instances_client.get(
            project=self.config.project_id, zone=zone, instance=provisioning_data.instance_id
        )
    except google.api_core.exceptions.NotFound as e:
        raise ProvisioningError(
            "Failed to get instance IP address. Instance not found."
        ) from e

    if instance.status in ["PROVISIONING", "STAGING"]:
        # Not ready yet; leave provisioning_data untouched.
        return
    if instance.status == "RUNNING":
        interface = instance.network_interfaces[0]
        if allocate_public_ip:
            hostname = interface.access_configs[0].nat_i_p
        else:
            hostname = interface.network_i_p
        provisioning_data.hostname = hostname
        provisioning_data.internal_ip = interface.network_i_p
        return
    raise ProvisioningError(
        f"Failed to get instance IP address. Instance status: {instance.status}"
    )

def run_job(
self,
run: Run,
Expand Down
10 changes: 7 additions & 3 deletions src/dstack/_internal/core/backends/gcp/resources.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import concurrent.futures
import random
import re
import string
Expand Down Expand Up @@ -311,6 +312,7 @@ def create_tpu_node_struct(
spot: bool,
labels: Dict[str, str],
subnetwork: Optional[str] = None,
allocate_public_ip: bool = True,
) -> tpu_v2.Node:
node = tpu_v2.Node()
if spot:
Expand All @@ -319,7 +321,7 @@ def create_tpu_node_struct(
node.runtime_version = "tpu-ubuntu2204-base"
# subnetwork determines the network, so network shouldn't be specified
node.network_config = tpu_v2.NetworkConfig(
enable_external_ips=True,
enable_external_ips=allocate_public_ip,
subnetwork=subnetwork,
)
ssh_keys = "\n".join(f"ubuntu:{key}" for key in authorized_keys)
Expand Down Expand Up @@ -351,10 +353,12 @@ def wait_for_extended_operation(
def wait_for_operation(operation: Operation, verbose_name: str = "operation", timeout: int = 300):
    """Block until ``operation`` finishes and return its result.

    Args:
        operation: The long-running operation to wait on.
        verbose_name: Human-readable name used only in debug log messages.
        timeout: Value passed to ``operation.result()`` as ``timeout``.

    Returns:
        Whatever ``operation.result()`` returns.

    Raises:
        concurrent.futures.TimeoutError: Re-raised unchanged when waiting
            times out, so callers can tell "still running" from "failed".
        Exception: On any other failure, the operation's own exception if
            it has one, otherwise the exception that was caught. The type
            is preserved (not wrapped) so callers can catch specific API
            errors.
    """
    try:
        result = operation.result(timeout=timeout)
    except concurrent.futures.TimeoutError as e:
        # A timeout is handled before the generic branch and re-raised as is,
        # without consulting operation.exception().
        logger.debug("Error during %s: %s", verbose_name, e)
        raise
    except Exception as e:
        # Write only debug logs here.
        # The unexpected errors will be propagated and logged appropriately by the caller.
        logger.debug("Error during %s: %s", verbose_name, e)
        logger.debug("Operation ID: %s", operation)
        raise operation.exception() or e
    return result
9 changes: 4 additions & 5 deletions src/dstack/_internal/core/models/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,10 @@ def pretty_format(self) -> str:
resources["disk_size"] = f"{self.disk.size_mib / 1024:.1f}GB"
if self.gpus:
gpu = self.gpus[0]
resources.update(
gpu_name=gpu.name,
gpu_count=len(self.gpus),
gpu_memory=f"{gpu.memory_mib / 1024:.0f}GB",
)
resources["gpu_name"] = gpu.name
resources["gpu_count"] = len(self.gpus)
if gpu.memory_mib > 0:
resources["gpu_memory"] = (f"{gpu.memory_mib / 1024:.0f}GB",)
return pretty_resources(**resources)


Expand Down
Loading