
[GCP] Adopt new provisioner to stop/down clusters #2199

Merged · 18 commits · Jul 15, 2023
4 changes: 2 additions & 2 deletions sky/adaptors/gcp.py
@@ -1,15 +1,15 @@
"""GCP cloud adaptors"""

# pylint: disable=import-outside-toplevel
from functools import wraps
import functools

googleapiclient = None
google = None


def import_package(func):

@wraps(func)
@functools.wraps(func)
def wrapper(*args, **kwargs):
global googleapiclient, google
if googleapiclient is None or google is None:
51 changes: 10 additions & 41 deletions sky/backends/cloud_vm_ray_backend.py
@@ -71,7 +71,7 @@
_NODES_LAUNCHING_PROGRESS_TIMEOUT = {
clouds.AWS: 90,
clouds.Azure: 90,
clouds.GCP: 120,
clouds.GCP: 240,
Collaborator: Why does this need to change?

Collaborator (Author): This is a known issue: 120 seconds is not enough when launching multiple nodes on GCP. We need to increase the timeout to make sure `sky launch --cloud gcp --num-nodes 8` works.

clouds.Lambda: 150,
clouds.IBM: 160,
clouds.Local: 90,
@@ -3350,16 +3350,14 @@ def teardown_no_lock(self,
cloud = handle.launched_resources.cloud
config = common_utils.read_yaml(handle.cluster_yaml)
cluster_name = handle.cluster_name
use_tpu_vm = config['provider'].get('_has_tpus', False)

# Avoid possibly unbound warnings. Code below must overwrite these vars:
returncode = 0
stdout = ''
stderr = ''

# Use the new provisioner for AWS.
if isinstance(cloud, clouds.AWS):
region = config['provider']['region']
if isinstance(cloud, (clouds.AWS, clouds.GCP)):
# Stop the ray autoscaler first to avoid the head node trying to
# re-launch the worker nodes, during the termination of the
# cluster.
@@ -3379,11 +3377,15 @@ def teardown_no_lock(self,
'already been terminated. It is fine to skip this.')
try:
if terminate:
provision_api.terminate_instances(repr(cloud), region,
cluster_name)
provision_api.terminate_instances(
repr(cloud),
cluster_name,
provider_config=config['provider'])
else:
provision_api.stop_instances(repr(cloud), region,
cluster_name)
provision_api.stop_instances(
repr(cloud),
cluster_name,
provider_config=config['provider'])
except Exception as e: # pylint: disable=broad-except
if purge:
logger.warning(
@@ -3485,39 +3487,6 @@ def teardown_no_lock(self,

# To avoid undefined local variables error.
stdout = stderr = ''
elif (terminate and
(prev_cluster_status == status_lib.ClusterStatus.STOPPED or
use_tpu_vm)):
# For TPU VMs, gcloud CLI is used for VM termination.
if isinstance(cloud, clouds.GCP):
zone = config['provider']['availability_zone']
# TODO(wei-lin): refactor by calling functions of node provider
# that uses Python API rather than CLI
if use_tpu_vm:
terminate_cmd = tpu_utils.terminate_tpu_vm_cluster_cmd(
cluster_name, zone, log_abs_path)
else:
query_cmd = (f'gcloud compute instances list --filter='
f'"(labels.ray-cluster-name={cluster_name})" '
f'--zones={zone} --format=value\\(name\\)')
# If there are no instances, exit with 0 rather than causing
# the delete command to fail.
terminate_cmd = (
f'VMS=$({query_cmd}) && [ -n "$VMS" ] && '
f'gcloud compute instances delete --zone={zone} --quiet'
' $VMS || echo "No instances to delete."')
else:
with ux_utils.print_exception_no_traceback():
raise ValueError(f'Unsupported cloud {cloud} for stopped '
f'cluster {cluster_name!r}.')
with log_utils.safe_rich_status(f'[bold cyan]Terminating '
f'[green]{cluster_name}'):
returncode, stdout, stderr = log_lib.run_with_log(
terminate_cmd,
log_abs_path,
shell=True,
stream_logs=False,
require_outputs=True)
else:
config['provider']['cache_stopped_nodes'] = not terminate
with tempfile.NamedTemporaryFile('w',
26 changes: 15 additions & 11 deletions sky/provision/__init__.py
@@ -3,7 +3,7 @@
This module provides a standard low-level interface that all
providers supported by SkyPilot need to follow.
"""
from typing import List, Optional
from typing import Any, Dict, List, Optional

import functools
import importlib
@@ -37,20 +37,24 @@ def _wrapper(*args, **kwargs):


@_route_to_cloud_impl
def stop_instances(provider_name: str,
region: str,
cluster_name: str,
included_instances: Optional[List[str]] = None,
excluded_instances: Optional[List[str]] = None) -> None:
def stop_instances(
provider_name: str,
cluster_name: str,
provider_config: Optional[Dict[str, Any]] = None,
included_instances: Optional[List[str]] = None,
excluded_instances: Optional[List[str]] = None,
) -> None:
"""Stop running instances."""
raise NotImplementedError


@_route_to_cloud_impl
def terminate_instances(provider_name: str,
region: str,
cluster_name: str,
included_instances: Optional[List[str]] = None,
excluded_instances: Optional[List[str]] = None) -> None:
def terminate_instances(
provider_name: str,
cluster_name: str,
provider_config: Optional[Dict[str, Any]] = None,
included_instances: Optional[List[str]] = None,
excluded_instances: Optional[List[str]] = None,
) -> None:
"""Terminate running or stopped instances."""
raise NotImplementedError
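
A note for reviewers on the interface change above: the `region` positional parameter is gone, and callers now pass the whole `provider` section of the cluster YAML as `provider_config`. Below is a minimal sketch of the caller side with placeholder values; the real call site is `teardown_no_lock` in the backend diff above, and the exact provider-name string is whatever `repr(cloud)` yields there.

```python
from sky import provision as provision_api  # aliased the way the backend refers to it

# Placeholder provider config mirroring the `provider` section of a GCP
# cluster YAML; in the backend it comes from
# common_utils.read_yaml(handle.cluster_yaml)['provider'].
provider_config = {
    'availability_zone': 'us-central1-a',  # placeholder zone
    'project_id': 'my-gcp-project',        # placeholder project
}

# 'GCP' is assumed here for illustration; the backend passes repr(cloud).
provision_api.stop_instances('GCP', 'my-cluster',
                             provider_config=provider_config)
```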
24 changes: 16 additions & 8 deletions sky/provision/aws/instance.py
@@ -27,11 +27,15 @@ def _filter_instances(ec2, filters: List[Dict[str, Any]],
return instances


def stop_instances(region: str,
cluster_name: str,
included_instances: Optional[List[str]] = None,
excluded_instances: Optional[List[str]] = None) -> None:
def stop_instances(
cluster_name: str,
provider_config: Optional[Dict[str, Any]] = None,
included_instances: Optional[List[str]] = None,
excluded_instances: Optional[List[str]] = None,
) -> None:
"""See sky/provision/__init__.py"""
assert provider_config is not None, (cluster_name, provider_config)
region = provider_config['region']
ec2 = aws.resource(
'ec2',
region_name=region,
@@ -57,11 +61,15 @@ def stop_instances(region: str,
# of most cloud implementations (including AWS).


def terminate_instances(region: str,
cluster_name: str,
included_instances: Optional[List[str]] = None,
excluded_instances: Optional[List[str]] = None) -> None:
def terminate_instances(
cluster_name: str,
provider_config: Optional[Dict[str, Any]] = None,
included_instances: Optional[List[str]] = None,
excluded_instances: Optional[List[str]] = None,
) -> None:
"""See sky/provision/__init__.py"""
assert provider_config is not None, (cluster_name, provider_config)
region = provider_config['region']
ec2 = aws.resource(
'ec2',
region_name=region,
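
Taken together with the GCP implementation added below, the keys each cloud now reads from `provider_config` in this PR are as follows (values are placeholders):

```python
aws_provider_config = {
    'region': 'us-east-1',                 # read by sky/provision/aws/instance.py
}
gcp_provider_config = {
    'availability_zone': 'us-central1-a',  # read by sky/provision/gcp/instance.py
    'project_id': 'my-gcp-project',
    '_has_tpus': False,                    # optional; enables the TPU VM handler
}
```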
3 changes: 3 additions & 0 deletions sky/provision/gcp/__init__.py
@@ -0,0 +1,3 @@
"""GCP provisioner for SkyPilot."""

from sky.provision.gcp.instance import stop_instances, terminate_instances
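
These top-level re-exports matter because the `@_route_to_cloud_impl` decorator in `sky/provision/__init__.py` (which imports `importlib`) presumably resolves the per-cloud module dynamically and looks the function up by name. The following is a rough, hypothetical sketch of that dispatch, for intuition only; the actual decorator body is not part of this diff.

```python
import importlib

def _call_cloud_impl(provider_name: str, func_name: str, *args, **kwargs):
    # Assumed lookup scheme: sky.provision.<provider> must expose the function
    # at module top level, which is exactly what this __init__.py provides.
    module = importlib.import_module(f'sky.provision.{provider_name.lower()}')
    return getattr(module, func_name)(*args, **kwargs)
```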
153 changes: 153 additions & 0 deletions sky/provision/gcp/instance.py
@@ -0,0 +1,153 @@
"""GCP instance provisioning."""
import collections
import time
from typing import Any, Callable, Dict, Iterable, List, Optional, Type

from sky import sky_logging
from sky.provision.gcp import instance_utils

logger = sky_logging.init_logger(__name__)

MAX_POLLS = 12
# Stopping instances can take several minutes, so we increase the timeout
MAX_POLLS_STOP = MAX_POLLS * 8
POLL_INTERVAL = 5

# Tag uniquely identifying all nodes of a cluster
TAG_RAY_CLUSTER_NAME = 'ray-cluster-name'


def _filter_instances(
handlers: Iterable[Type[instance_utils.GCPInstance]],
project_id: str,
zone: str,
label_filters: Dict[str, str],
status_filters_fn: Callable[[Type[instance_utils.GCPInstance]],
Optional[List[str]]],
included_instances: Optional[List[str]] = None,
excluded_instances: Optional[List[str]] = None,
) -> Dict[Type[instance_utils.GCPInstance], List[str]]:
"""Filter instances using all instance handlers."""
instances = set()
logger.debug(f'handlers: {handlers}')
for instance_handler in handlers:
instances |= set(
instance_handler.filter(project_id, zone, label_filters,
status_filters_fn(instance_handler),
included_instances, excluded_instances))
handler_to_instances = collections.defaultdict(list)
for instance in instances:
handler = instance_utils.instance_to_handler(instance)
handler_to_instances[handler].append(instance)
logger.debug(f'handler_to_instances: {handler_to_instances}')
return handler_to_instances


def _wait_for_operations(
handlers_to_operations: Dict[Type[instance_utils.GCPInstance], List[dict]],
project_id: str,
zone: str,
) -> None:
"""Poll for compute zone operation until finished."""
total_polls = 0
for handler, operations in handlers_to_operations.items():
for operation in operations:
logger.debug(
'wait_for_compute_zone_operation: '
f'Waiting for operation {operation["name"]} to finish...')
while total_polls < MAX_POLLS:
if handler.wait_for_operation(operation, project_id, zone):
break
time.sleep(POLL_INTERVAL)
total_polls += 1


def stop_instances(
cluster_name: str,
provider_config: Optional[Dict[str, Any]] = None,
included_instances: Optional[List[str]] = None,
excluded_instances: Optional[List[str]] = None,
) -> None:
assert provider_config is not None, cluster_name
zone = provider_config['availability_zone']
project_id = provider_config['project_id']
name_filter = {TAG_RAY_CLUSTER_NAME: cluster_name}

handlers: List[Type[instance_utils.GCPInstance]] = [
instance_utils.GCPComputeInstance
]
use_tpu_vms = provider_config.get('_has_tpus', False)
if use_tpu_vms:
handlers.append(instance_utils.GCPTPUVMInstance)

handler_to_instances = _filter_instances(
handlers,
project_id,
zone,
name_filter,
lambda handler: handler.NEED_TO_STOP_STATES,
included_instances,
excluded_instances,
)
all_instances = [
i for instances in handler_to_instances.values() for i in instances
]

operations = collections.defaultdict(list)
for handler, instances in handler_to_instances.items():
for instance in instances:
operations[handler].append(handler.stop(project_id, zone, instance))
_wait_for_operations(operations, project_id, zone)
# Check if the instance is actually stopped.
# GCP does not fully stop an instance even after
# the stop operation is finished.
for _ in range(MAX_POLLS_STOP):
handler_to_instances = _filter_instances(
handler_to_instances.keys(),
project_id,
zone,
name_filter,
lambda handler: handler.NON_STOPPED_STATES,
included_instances=all_instances,
)
if not handler_to_instances:
break
time.sleep(POLL_INTERVAL)
else:
raise RuntimeError(f'Maximum number of polls: '
f'{MAX_POLLS_STOP} reached. '
f'Instance {all_instances} is still not in '
'STOPPED status.')


def terminate_instances(
cluster_name: str,
provider_config: Optional[Dict[str, Any]] = None,
included_instances: Optional[List[str]] = None,
excluded_instances: Optional[List[str]] = None,
) -> None:
"""See sky/provision/__init__.py"""
assert provider_config is not None, cluster_name
zone = provider_config['availability_zone']
project_id = provider_config['project_id']
use_tpu_vms = provider_config.get('_has_tpus', False)

name_filter = {TAG_RAY_CLUSTER_NAME: cluster_name}
handlers: List[Type[instance_utils.GCPInstance]] = [
instance_utils.GCPComputeInstance
]
if use_tpu_vms:
handlers.append(instance_utils.GCPTPUVMInstance)

handler_to_instances = _filter_instances(handlers, project_id, zone,
name_filter, lambda _: None,
included_instances,
excluded_instances)
operations = collections.defaultdict(list)
for handler, instances in handler_to_instances.items():
for instance in instances:
operations[handler].append(
handler.terminate(project_id, zone, instance))
_wait_for_operations(operations, project_id, zone)
# We don't wait for the instances to be terminated, as it can take a long
# time (same as what we did in ray's node_provider).
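
One implementation detail in `stop_instances` worth highlighting: the verification loop that re-filters for `NON_STOPPED_STATES` uses Python's `for ... else`, so the `RuntimeError` fires only if all `MAX_POLLS_STOP` polls complete without a `break`. Here is a standalone sketch of the same idiom; the predicate and constants are illustrative.

```python
import time

MAX_POLLS_STOP = 96   # mirrors MAX_POLLS * 8 in the new module
POLL_INTERVAL = 5     # seconds

def wait_until_stopped(still_running) -> None:
    """Poll `still_running()` until it returns an empty collection."""
    for _ in range(MAX_POLLS_STOP):
        if not still_running():
            break                 # every instance reached STOPPED
        time.sleep(POLL_INTERVAL)
    else:
        # Runs only when the loop exhausted all polls without breaking.
        raise RuntimeError(f'Instances still not STOPPED after '
                           f'{MAX_POLLS_STOP} polls.')
```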