Skip to content

Commit

Permalink
[autoscaler] Support cache_stopped_nodes on Azure (#21747)
Browse files Browse the repository at this point in the history
* basic reuse functionality without valid node filtering

* Filtering, logging, and formatting for cache_stopped_nodes on Azure

* Updated formatter version
  • Loading branch information
mraheja authored Jan 28, 2022
1 parent 570f677 commit fe1bf02
Showing 1 changed file with 90 additions and 42 deletions.
132 changes: 90 additions & 42 deletions python/ray/autoscaler/_private/_azure/node_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from azure.mgmt.resource.resources.models import DeploymentMode

from ray.autoscaler.node_provider import NodeProvider
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME
from ray.autoscaler.tags import (TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME,
TAG_RAY_NODE_KIND, TAG_RAY_LAUNCH_CONFIG,
TAG_RAY_USER_NODE_TYPE)
from ray.autoscaler._private._azure.config import (bootstrap_azure,
get_azure_sdk_function)

Expand Down Expand Up @@ -50,6 +52,8 @@ class AzureNodeProvider(NodeProvider):
def __init__(self, provider_config, cluster_name):
NodeProvider.__init__(self, provider_config, cluster_name)
subscription_id = provider_config["subscription_id"]
self.cache_stopped_nodes = provider_config.get("cache_stopped_nodes",
True)
credential = DefaultAzureCredential(
exclude_shared_token_cache_credential=True)
self.compute_client = ComputeManagementClient(credential,
Expand Down Expand Up @@ -114,6 +118,13 @@ def _extract_metadata(self, vm):

return metadata

def stopped_nodes(self, tag_filters):
    """Return a list of stopped node ids filtered by the specified tags dict.

    A node counts as stopped when its cached status string starts with
    "deallocat" (prefix match on the Azure power state).
    """
    matching = self._get_filtered_nodes(tag_filters=tag_filters)
    stopped = []
    for node_id, node in matching.items():
        if node["status"].startswith("deallocat"):
            stopped.append(node_id)
    return stopped

def non_terminated_nodes(self, tag_filters):
"""Return a list of node ids filtered by the specified tags dict.
Expand Down Expand Up @@ -161,8 +172,34 @@ def internal_ip(self, node_id):
return ip

def create_node(self, node_config, tags, count):
    """Create ``count`` nodes, reusing cached stopped VMs when enabled.

    When ``cache_stopped_nodes`` is set in the provider config, stopped
    (deallocated) VMs whose identity tags match the requested tags are
    restarted and re-tagged instead of provisioning new VMs. Any
    remaining count is provisioned via ``_create_node``.

    Args:
        node_config: Provider-specific config for the nodes to create.
        tags: Tags to apply to the new (or reused) nodes.
        count: Number of nodes to create.
    """
    resource_group = self.provider_config["resource_group"]

    if self.cache_stopped_nodes:
        # Only reuse nodes that agree on every identity tag present in
        # the request, so a cached node from a different node type or
        # launch config is not resurrected with the wrong setup.
        VALIDITY_TAGS = [
            TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_KIND, TAG_RAY_LAUNCH_CONFIG,
            TAG_RAY_USER_NODE_TYPE
        ]
        filters = {tag: tags[tag] for tag in VALIDITY_TAGS if tag in tags}
        reuse_nodes = self.stopped_nodes(filters)[:count]
        # Only log and fetch the SDK start function when there is
        # actually something to reuse; the original logged a misleading
        # "Reusing nodes []." message on every call.
        if reuse_nodes:
            logger.info(
                f"Reusing nodes {list(reuse_nodes)}. "
                "To disable reuse, set `cache_stopped_nodes: False` "
                "under `provider` in the cluster configuration.")
            start = get_azure_sdk_function(
                client=self.compute_client.virtual_machines,
                function_name="start")
            for node_id in reuse_nodes:
                # The VM must be fully running before re-tagging it.
                start(
                    resource_group_name=resource_group,
                    vm_name=node_id).wait()
                self.set_node_tags(node_id, tags)
            count -= len(reuse_nodes)

    if count:
        self._create_node(node_config, tags, count)

def _create_node(self, node_config, tags, count):
"""Creates a number of nodes within the namespace."""
# TODO: restart deallocated nodes if possible
resource_group = self.provider_config["resource_group"]

# load the template file
Expand Down Expand Up @@ -235,57 +272,68 @@ def terminate_node(self, node_id):
# node no longer exists
return

# TODO: deallocate instead of delete to allow possible reuse
# self.compute_client.virtual_machines.deallocate(
# resource_group_name=resource_group,
# vm_name=node_id)

# gather disks to delete later
vm = self.compute_client.virtual_machines.get(
resource_group_name=resource_group, vm_name=node_id)
disks = {d.name for d in vm.storage_profile.data_disks}
disks.add(vm.storage_profile.os_disk.name)

try:
# delete machine, must wait for this to complete
delete = get_azure_sdk_function(
client=self.compute_client.virtual_machines,
function_name="delete")
delete(resource_group_name=resource_group, vm_name=node_id).wait()
except Exception as e:
logger.warning("Failed to delete VM: {}".format(e))

try:
# delete nic
delete = get_azure_sdk_function(
client=self.network_client.network_interfaces,
function_name="delete")
delete(
resource_group_name=resource_group,
network_interface_name=metadata["nic_name"])
except Exception as e:
logger.warning("Failed to delete nic: {}".format(e))
if self.cache_stopped_nodes:
try:
# stop machine and leave all resources
logger.info(f"Stopping instance {node_id}"
"(to fully terminate instead, "
"set `cache_stopped_nodes: False` "
"under `provider` in the cluster configuration)")
stop = get_azure_sdk_function(
client=self.compute_client.virtual_machines,
function_name="deallocate")
stop(resource_group_name=resource_group, vm_name=node_id)
except Exception as e:
logger.warning("Failed to stop VM: {}".format(e))
else:
vm = self.compute_client.virtual_machines.get(
resource_group_name=resource_group, vm_name=node_id)
disks = {d.name for d in vm.storage_profile.data_disks}
disks.add(vm.storage_profile.os_disk.name)

# delete ip address
if "public_ip_name" in metadata:
try:
# delete machine, must wait for this to complete
delete = get_azure_sdk_function(
client=self.network_client.public_ip_addresses,
client=self.compute_client.virtual_machines,
function_name="delete")
delete(
resource_group_name=resource_group,
public_ip_address_name=metadata["public_ip_name"])
vm_name=node_id).wait()
except Exception as e:
logger.warning("Failed to delete public ip: {}".format(e))
logger.warning("Failed to delete VM: {}".format(e))

# delete disks
for disk in disks:
try:
# delete nic
delete = get_azure_sdk_function(
client=self.compute_client.disks, function_name="delete")
delete(resource_group_name=resource_group, disk_name=disk)
client=self.network_client.network_interfaces,
function_name="delete")
delete(
resource_group_name=resource_group,
network_interface_name=metadata["nic_name"])
except Exception as e:
logger.warning("Failed to delete disk: {}".format(e))
logger.warning("Failed to delete nic: {}".format(e))

# delete ip address
if "public_ip_name" in metadata:
try:
delete = get_azure_sdk_function(
client=self.network_client.public_ip_addresses,
function_name="delete")
delete(
resource_group_name=resource_group,
public_ip_address_name=metadata["public_ip_name"])
except Exception as e:
logger.warning("Failed to delete public ip: {}".format(e))

# delete disks
for disk in disks:
try:
delete = get_azure_sdk_function(
client=self.compute_client.disks,
function_name="delete")
delete(resource_group_name=resource_group, disk_name=disk)
except Exception as e:
logger.warning("Failed to delete disk: {}".format(e))

def _get_node(self, node_id):
self._get_filtered_nodes({}) # Side effect: updates cache
Expand Down

0 comments on commit fe1bf02

Please sign in to comment.