[autoscaler] Speedups (ray-project#3720)
- NodeUpdater now resolves its IP in parallel (no longer in __init__)
- We use persistent SSH connections (ControlMaster, with a temp socket folder created only for Ray); see the sketch after this list
- hash_runtime_conf was performing a pointless hexlify step, wasting time on large files
- We use NodeUpdaterThreads and share the NodeProvider; NodeUpdaterProcess is removed
- AWSNodeProvider caches nodes more aggressively
- NodeProvider now has a shim batch terminate_nodes() call; AWSNodeProvider turns it into a single batched EC2 call, and the autoscaler uses it
- AWSNodeProvider batches EC2 tag updates (create_tags calls)
- Logging changes throughout to provide standardised timing information for profiling
- Pulled out a few unnecessary is_running calls (NodeUpdater will loop waiting for SSH anyway)
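
The persistent-SSH bullet above relies on OpenSSH connection multiplexing. A minimal sketch of the kind of option set involved (the helper name, socket directory, and timeout values are illustrative assumptions, not the autoscaler's exact ones):

```python
import os
import tempfile


def ssh_options(private_key, control_dir=None):
    """Build ssh arguments that reuse one multiplexed connection per host."""
    # Hypothetical helper; values below are illustrative, not Ray's exact ones.
    if control_dir is None:
        control_dir = os.path.join(tempfile.gettempdir(), "ray_ssh_sockets")
    os.makedirs(control_dir, exist_ok=True)
    return [
        "-i", private_key,
        "-o", "ControlMaster=auto",  # start or reuse a master connection
        "-o", "ControlPath={}/%C".format(control_dir),  # one socket per host
        "-o", "ControlPersist=10s",  # keep the master alive between commands
        "-o", "StrictHostKeyChecking=no",
    ]


# Usage: every command after the first to the same host skips the handshake:
#   subprocess.check_call(["ssh"] + ssh_options("ray.pem")
#                         + ["ubuntu@1.2.3.4", "uptime"])
```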

## Related issue number
Issue ray-project#3599
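
The batch terminate_nodes() shim mentioned in the list above lives in the base NodeProvider, which is not rendered below; conceptually it just falls back to per-node termination, and AWSNodeProvider overrides it with a single batched EC2 request. A sketch, not the verbatim implementation:

```python
class NodeProvider(object):
    def terminate_node(self, node_id):
        raise NotImplementedError

    def terminate_nodes(self, node_ids):
        # Default shim: providers without a batch API fall back to
        # terminating nodes one at a time.
        for node_id in node_ids:
            self.terminate_node(node_id)
```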
ls-daniel authored and richardliaw committed Feb 1, 2019
1 parent ff3c6af commit 315edab
Showing 13 changed files with 543 additions and 342 deletions.
178 changes: 97 additions & 81 deletions python/ray/autoscaler/autoscaler.py

Large diffs are not rendered by default.
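
hash_runtime_conf lives in this unrendered file; the change described in the commit message simply stops hexlifying file contents before hashing them, since raw bytes hash just as well and hexlify doubles the data pushed through the hasher. A sketch under that reading (the helper name is illustrative):

```python
import hashlib


def hash_file_contents(paths):
    hasher = hashlib.sha1()
    for path in sorted(paths):
        with open(path, "rb") as f:
            # Previously the contents were hexlified before hashing, which
            # doubles the bytes fed to the hasher on large files.
            hasher.update(f.read())
    return hasher.hexdigest()
```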

41 changes: 26 additions & 15 deletions python/ray/autoscaler/aws/config.py
@@ -4,16 +4,18 @@

from distutils.version import StrictVersion
import json
import logging
import os
import time
import logging

import boto3
from botocore.config import Config
import botocore

from ray.ray_constants import BOTO_MAX_RETRIES

logger = logging.getLogger(__name__)

RAY = "ray-autoscaler"
DEFAULT_RAY_INSTANCE_PROFILE = RAY + "-v1"
DEFAULT_RAY_IAM_ROLE = RAY + "-v1"
@@ -34,7 +36,6 @@ def key_pair(i, region):

# Suppress excessive connection dropped logs from boto
logging.getLogger("botocore").setLevel(logging.WARNING)
logger = logging.getLogger(__name__)


def bootstrap_aws(config):
@@ -62,8 +63,9 @@ def _configure_iam_role(config):
profile = _get_instance_profile(DEFAULT_RAY_INSTANCE_PROFILE, config)

if profile is None:
logger.info("Creating new instance profile {}".format(
DEFAULT_RAY_INSTANCE_PROFILE))
logger.info("_configure_iam_role: "
"Creating new instance profile {}".format(
DEFAULT_RAY_INSTANCE_PROFILE))
client = _client("iam", config)
client.create_instance_profile(
InstanceProfileName=DEFAULT_RAY_INSTANCE_PROFILE)
@@ -75,7 +77,8 @@ def _configure_iam_role(config):
if not profile.roles:
role = _get_role(DEFAULT_RAY_IAM_ROLE, config)
if role is None:
logger.info("Creating new role {}".format(DEFAULT_RAY_IAM_ROLE))
logger.info("_configure_iam_role: "
"Creating new role {}".format(DEFAULT_RAY_IAM_ROLE))
iam = _resource("iam", config)
iam.create_role(
RoleName=DEFAULT_RAY_IAM_ROLE,
@@ -99,8 +102,9 @@ def _configure_iam_role(config):
profile.add_role(RoleName=role.name)
time.sleep(15) # wait for propagation

logger.info("Role not specified for head node, using {}".format(
profile.arn))
logger.info("_configure_iam_role: "
"Role not specified for head node, using {}".format(
profile.arn))
config["head_node"]["IamInstanceProfile"] = {"Arn": profile.arn}

return config
@@ -126,7 +130,8 @@ def _configure_key_pair(config):

# We can safely create a new key.
if not key and not os.path.exists(key_path):
logger.info("Creating new key pair {}".format(key_name))
logger.info("_configure_key_pair: "
"Creating new key pair {}".format(key_name))
key = ec2.create_key_pair(KeyName=key_name)
with open(key_path, "w") as f:
f.write(key.key_material)
@@ -142,7 +147,8 @@ def _configure_key_pair(config):
assert os.path.exists(key_path), \
"Private key file {} not found for {}".format(key_path, key_name)

logger.info("KeyName not specified for nodes, using {}".format(key_name))
logger.info("_configure_key_pair: "
"KeyName not specified for nodes, using {}".format(key_name))

config["auth"]["ssh_private_key"] = key_path
config["head_node"]["KeyName"] = key_name
@@ -174,19 +180,21 @@ def _configure_subnet(config):
"No usable subnets matching availability zone {} "
"found. Choose a different availability zone or try "
"manually creating an instance in your specified region "
"to populate the list of subnets and trying this again."
.format(config["provider"]["availability_zone"]))
"to populate the list of subnets and trying this again.".
format(config["provider"]["availability_zone"]))

subnet_ids = [s.subnet_id for s in subnets]
subnet_descr = [(s.subnet_id, s.availability_zone) for s in subnets]
if "SubnetIds" not in config["head_node"]:
config["head_node"]["SubnetIds"] = subnet_ids
logger.info("SubnetIds not specified for head node,"
" using {}".format(subnet_descr))
logger.info("_configure_subnet: "
"SubnetIds not specified for head node, using {}".format(
subnet_descr))

if "SubnetIds" not in config["worker_nodes"]:
config["worker_nodes"]["SubnetIds"] = subnet_ids
logger.info("SubnetId not specified for workers,"
logger.info("_configure_subnet: "
"SubnetId not specified for workers,"
" using {}".format(subnet_descr))

return config
@@ -202,7 +210,8 @@ def _configure_security_group(config):
security_group = _get_security_group(config, vpc_id, group_name)

if security_group is None:
logger.info("Creating new security group {}".format(group_name))
logger.info("_configure_security_group: "
"Creating new security group {}".format(group_name))
client = _client("ec2", config)
client.create_security_group(
Description="Auto-created security group for Ray workers",
@@ -230,12 +239,14 @@ def _configure_security_group(config):

if "SecurityGroupIds" not in config["head_node"]:
logger.info(
"_configure_security_group: "
"SecurityGroupIds not specified for head node, using {}".format(
security_group.group_name))
config["head_node"]["SecurityGroupIds"] = [security_group.id]

if "SecurityGroupIds" not in config["worker_nodes"]:
logger.info(
"_configure_security_group: "
"SecurityGroupIds not specified for workers, using {}".format(
security_group.group_name))
config["worker_nodes"]["SecurityGroupIds"] = [security_group.id]
135 changes: 96 additions & 39 deletions python/ray/autoscaler/aws/node_provider.py
@@ -3,13 +3,16 @@
from __future__ import print_function

import random
import threading
from collections import defaultdict

import boto3
from botocore.config import Config

from ray.autoscaler.node_provider import NodeProvider
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME
from ray.ray_constants import BOTO_MAX_RETRIES
from ray.autoscaler.log_timer import LogTimer


def to_aws_format(tags):
@@ -40,15 +43,59 @@ def __init__(self, provider_config, cluster_name):
# Try availability zones round-robin, starting from random offset
self.subnet_idx = random.randint(0, 100)

self.tag_cache = {} # Tags that we believe to actually be on EC2.
self.tag_cache_pending = {} # Tags that we will soon upload.
self.tag_cache_lock = threading.Lock()
self.tag_cache_update_event = threading.Event()
self.tag_cache_kill_event = threading.Event()
self.tag_update_thread = threading.Thread(
target=self._node_tag_update_loop)
self.tag_update_thread.start()

# Cache of node objects from the last nodes() call. This avoids
# excessive DescribeInstances requests.
self.cached_nodes = {}

# Cache of ip lookups. We assume IPs never change once assigned.
self.internal_ip_cache = {}
self.external_ip_cache = {}
def _node_tag_update_loop(self):
""" Update the AWS tags for a cluster periodically.
The purpose of this loop is to avoid excessive EC2 calls when a large
number of nodes are being launched simultaneously.
"""
while True:
self.tag_cache_update_event.wait()
self.tag_cache_update_event.clear()

batch_updates = defaultdict(list)

with self.tag_cache_lock:
for node_id, tags in self.tag_cache_pending.items():
for x in tags.items():
batch_updates[x].append(node_id)
self.tag_cache[node_id].update(tags)

self.tag_cache_pending = {}

for (k, v), node_ids in batch_updates.items():
m = "Set tag {}={} on {}".format(k, v, node_ids)
with LogTimer("AWSNodeProvider: {}".format(m)):
if k == TAG_RAY_NODE_NAME:
k = "Name"
self.ec2.meta.client.create_tags(
Resources=node_ids,
Tags=[{
"Key": k,
"Value": v
}],
)

self.tag_cache_kill_event.wait(timeout=5)
if self.tag_cache_kill_event.is_set():
return

def nodes(self, tag_filters):
# Note that these filters are acceptable because they are set on
# node initialization, and so can never be sitting in the cache.
tag_filters = to_aws_format(tag_filters)
filters = [
{
@@ -65,9 +112,19 @@ def nodes(self, tag_filters):
"Name": "tag:{}".format(k),
"Values": [v],
})
instances = list(self.ec2.instances.filter(Filters=filters))
self.cached_nodes = {i.id: i for i in instances}
return [i.id for i in instances]

nodes = list(self.ec2.instances.filter(Filters=filters))
# Populate the tag cache with initial information if necessary
for node in nodes:
if node.id in self.tag_cache:
continue

self.tag_cache[node.id] = from_aws_format(
{x["Key"]: x["Value"]
for x in node.tags})

self.cached_nodes = {node.id: node for node in nodes}
return [node.id for node in nodes]

def is_running(self, node_id):
node = self._node(node_id)
@@ -79,40 +136,25 @@ def is_terminated(self, node_id):
return state not in ["running", "pending"]

def node_tags(self, node_id):
node = self._node(node_id)
tags = {}
for tag in node.tags:
tags[tag["Key"]] = tag["Value"]
return from_aws_format(tags)
with self.tag_cache_lock:
d1 = self.tag_cache[node_id]
d2 = self.tag_cache_pending.get(node_id, {})
return dict(d1, **d2)

def external_ip(self, node_id):
if node_id in self.external_ip_cache:
return self.external_ip_cache[node_id]
node = self._node(node_id)
ip = node.public_ip_address
if ip:
self.external_ip_cache[node_id] = ip
return ip
return self._node(node_id).public_ip_address

def internal_ip(self, node_id):
if node_id in self.internal_ip_cache:
return self.internal_ip_cache[node_id]
node = self._node(node_id)
ip = node.private_ip_address
if ip:
self.internal_ip_cache[node_id] = ip
return ip
return self._node(node_id).private_ip_address

def set_node_tags(self, node_id, tags):
tags = to_aws_format(tags)
node = self._node(node_id)
tag_pairs = []
for k, v in tags.items():
tag_pairs.append({
"Key": k,
"Value": v,
})
node.create_tags(Tags=tag_pairs)
with self.tag_cache_lock:
try:
self.tag_cache_pending[node_id].update(tags)
except KeyError:
self.tag_cache_pending[node_id] = tags

self.tag_cache_update_event.set()

def create_node(self, node_config, tags, count):
tags = to_aws_format(tags)
@@ -166,9 +208,24 @@ def terminate_node(self, node_id):
node = self._node(node_id)
node.terminate()

self.tag_cache.pop(node_id, None)
self.tag_cache_pending.pop(node_id, None)

def terminate_nodes(self, node_ids):
self.ec2.meta.client.terminate_instances(InstanceIds=node_ids)

for node_id in node_ids:
self.tag_cache.pop(node_id, None)
self.tag_cache_pending.pop(node_id, None)

def _node(self, node_id):
if node_id in self.cached_nodes:
return self.cached_nodes[node_id]
matches = list(self.ec2.instances.filter(InstanceIds=[node_id]))
assert len(matches) == 1, "Invalid instance id {}".format(node_id)
return matches[0]
if node_id not in self.cached_nodes:
self.nodes({}) # Side effect: should cache it.

assert node_id in self.cached_nodes, "Invalid instance id {}".format(
node_id)
return self.cached_nodes[node_id]

def cleanup(self):
self.tag_cache_update_event.set()
self.tag_cache_kill_event.set()
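
The batched tag update above is wrapped in a LogTimer imported from ray.autoscaler.log_timer, one of the new files not rendered in this view. A minimal sketch of what such a context manager might look like (an assumption, not the actual implementation):

```python
import datetime
import logging

logger = logging.getLogger(__name__)


class LogTimer(object):
    """Context manager that logs how long the wrapped block took."""

    def __init__(self, message):
        self.message = message

    def __enter__(self):
        self._start = datetime.datetime.utcnow()

    def __exit__(self, *exc_info):
        elapsed = datetime.datetime.utcnow() - self._start
        logger.info("%s [LogTimer=%dms]", self.message,
                    int(elapsed.total_seconds() * 1000))
```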
