diff --git a/python/ray/autoscaler/_private/_azure/node_provider.py b/python/ray/autoscaler/_private/_azure/node_provider.py index 42f2e06f264b..484fb014ed61 100644 --- a/python/ray/autoscaler/_private/_azure/node_provider.py +++ b/python/ray/autoscaler/_private/_azure/node_provider.py @@ -11,7 +11,9 @@ from azure.mgmt.resource.resources.models import DeploymentMode from ray.autoscaler.node_provider import NodeProvider -from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME +from ray.autoscaler.tags import (TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME, + TAG_RAY_NODE_KIND, TAG_RAY_LAUNCH_CONFIG, + TAG_RAY_USER_NODE_TYPE) from ray.autoscaler._private._azure.config import (bootstrap_azure, get_azure_sdk_function) @@ -50,6 +52,8 @@ class AzureNodeProvider(NodeProvider): def __init__(self, provider_config, cluster_name): NodeProvider.__init__(self, provider_config, cluster_name) subscription_id = provider_config["subscription_id"] + self.cache_stopped_nodes = provider_config.get("cache_stopped_nodes", + True) credential = DefaultAzureCredential( exclude_shared_token_cache_credential=True) self.compute_client = ComputeManagementClient(credential, @@ -114,6 +118,13 @@ def _extract_metadata(self, vm): return metadata + def stopped_nodes(self, tag_filters): + """Return a list of stopped node ids filtered by the specified tags dict.""" + nodes = self._get_filtered_nodes(tag_filters=tag_filters) + return [ + k for k, v in nodes.items() if v["status"].startswith("deallocat") + ] + def non_terminated_nodes(self, tag_filters): """Return a list of node ids filtered by the specified tags dict. @@ -161,8 +172,34 @@ def internal_ip(self, node_id): return ip def create_node(self, node_config, tags, count): + resource_group = self.provider_config["resource_group"] + + if self.cache_stopped_nodes: + VALIDITY_TAGS = [ + TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_KIND, TAG_RAY_LAUNCH_CONFIG, + TAG_RAY_USER_NODE_TYPE + ] + filters = {tag: tags[tag] for tag in VALIDITY_TAGS if tag in tags} + reuse_nodes = self.stopped_nodes(filters)[:count] + logger.info( + f"Reusing nodes {list(reuse_nodes)}. " + "To disable reuse, set `cache_stopped_nodes: False` " + "under `provider` in the cluster configuration.", ) + start = get_azure_sdk_function( + client=self.compute_client.virtual_machines, + function_name="start") + for node_id in reuse_nodes: + start( + resource_group_name=resource_group, + vm_name=node_id).wait() + self.set_node_tags(node_id, tags) + count -= len(reuse_nodes) + + if count: + self._create_node(node_config, tags, count) + + def _create_node(self, node_config, tags, count): """Creates a number of nodes within the namespace.""" - # TODO: restart deallocated nodes if possible resource_group = self.provider_config["resource_group"] # load the template file @@ -235,57 +272,68 @@ def terminate_node(self, node_id): # node no longer exists return - # TODO: deallocate instead of delete to allow possible reuse - # self.compute_client.virtual_machines.deallocate( - # resource_group_name=resource_group, - # vm_name=node_id) - - # gather disks to delete later - vm = self.compute_client.virtual_machines.get( - resource_group_name=resource_group, vm_name=node_id) - disks = {d.name for d in vm.storage_profile.data_disks} - disks.add(vm.storage_profile.os_disk.name) - - try: - # delete machine, must wait for this to complete - delete = get_azure_sdk_function( - client=self.compute_client.virtual_machines, - function_name="delete") - delete(resource_group_name=resource_group, vm_name=node_id).wait() - except Exception as e: - logger.warning("Failed to delete VM: {}".format(e)) - - try: - # delete nic - delete = get_azure_sdk_function( - client=self.network_client.network_interfaces, - function_name="delete") - delete( - resource_group_name=resource_group, - network_interface_name=metadata["nic_name"]) - except Exception as e: - logger.warning("Failed to delete nic: {}".format(e)) + if self.cache_stopped_nodes: + try: + # stop machine and leave all resources + logger.info(f"Stopping instance {node_id}" + "(to fully terminate instead, " + "set `cache_stopped_nodes: False` " + "under `provider` in the cluster configuration)") + stop = get_azure_sdk_function( + client=self.compute_client.virtual_machines, + function_name="deallocate") + stop(resource_group_name=resource_group, vm_name=node_id) + except Exception as e: + logger.warning("Failed to stop VM: {}".format(e)) + else: + vm = self.compute_client.virtual_machines.get( + resource_group_name=resource_group, vm_name=node_id) + disks = {d.name for d in vm.storage_profile.data_disks} + disks.add(vm.storage_profile.os_disk.name) - # delete ip address - if "public_ip_name" in metadata: try: + # delete machine, must wait for this to complete delete = get_azure_sdk_function( - client=self.network_client.public_ip_addresses, + client=self.compute_client.virtual_machines, function_name="delete") delete( resource_group_name=resource_group, - public_ip_address_name=metadata["public_ip_name"]) + vm_name=node_id).wait() except Exception as e: - logger.warning("Failed to delete public ip: {}".format(e)) + logger.warning("Failed to delete VM: {}".format(e)) - # delete disks - for disk in disks: try: + # delete nic delete = get_azure_sdk_function( - client=self.compute_client.disks, function_name="delete") - delete(resource_group_name=resource_group, disk_name=disk) + client=self.network_client.network_interfaces, + function_name="delete") + delete( + resource_group_name=resource_group, + network_interface_name=metadata["nic_name"]) except Exception as e: - logger.warning("Failed to delete disk: {}".format(e)) + logger.warning("Failed to delete nic: {}".format(e)) + + # delete ip address + if "public_ip_name" in metadata: + try: + delete = get_azure_sdk_function( + client=self.network_client.public_ip_addresses, + function_name="delete") + delete( + resource_group_name=resource_group, + public_ip_address_name=metadata["public_ip_name"]) + except Exception as e: + logger.warning("Failed to delete public ip: {}".format(e)) + + # delete disks + for disk in disks: + try: + delete = get_azure_sdk_function( + client=self.compute_client.disks, + function_name="delete") + delete(resource_group_name=resource_group, disk_name=disk) + except Exception as e: + logger.warning("Failed to delete disk: {}".format(e)) def _get_node(self, node_id): self._get_filtered_nodes({}) # Side effect: updates cache