Skip to content

Commit

Permalink
[autoscaler] Consolidate CloudWatch agent/dashboard/alarm support; Ad…
Browse files Browse the repository at this point in the history
…d unit tests for AWS autoscaler CloudWatch integration (ray-project#22070)

This PR mainly adds two improvements:

1. Previous PRs introduced support for three CloudWatch config types: Agent, Dashboard, and Alarm. This PR generalizes the logic shared by all three config types by using the enum CloudwatchConfigType.
2. It adds unit tests to ensure the correctness of the Ray autoscaler's CloudWatch integration behavior.

Signed-off-by: Huaiwei Sun <[email protected]>
  • Loading branch information
Zyiqin-Miranda authored and scottsun94 committed Aug 9, 2022
1 parent dd16822 commit 4bc9c75
Show file tree
Hide file tree
Showing 5 changed files with 876 additions and 68 deletions.
161 changes: 94 additions & 67 deletions python/ray/autoscaler/_private/aws/cloudwatch/cloudwatch_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import logging
import os
import time
from typing import Any, Dict, List, Tuple, Union
from enum import Enum
from typing import Any, Callable, Dict, List, Union

import botocore

Expand All @@ -21,6 +22,12 @@
CLOUDWATCH_CONFIG_HASH_TAG_BASE = "cloudwatch-config-hash"


class CloudwatchConfigType(str, Enum):
    """CloudWatch config file types supported by ``CloudwatchHelper``.

    The ``str`` mixin makes every member compare equal to its plain string
    value, so a member's ``.value`` (or the member itself) can be used
    wherever a config-type string such as ``"agent"`` is expected.
    """

    # Unified CloudWatch Agent config.
    AGENT = "agent"
    # CloudWatch dashboard config.
    DASHBOARD = "dashboard"
    # CloudWatch metric alarm config.
    ALARM = "alarm"


class CloudwatchHelper:
def __init__(
self, provider_config: Dict[str, Any], node_id: str, cluster_name: str
Expand All @@ -34,19 +41,34 @@ def __init__(
self.ssm_client = client_cache("ssm", region)
cloudwatch_resource = resource_cache("cloudwatch", region)
self.cloudwatch_client = cloudwatch_resource.meta.client
self.CLOUDWATCH_CONFIG_TYPE_TO_CONFIG_VARIABLE_REPLACE_FUNC: Dict[
str, Callable
] = {
CloudwatchConfigType.AGENT.value: self._replace_cwa_config_vars,
CloudwatchConfigType.DASHBOARD.value: self._replace_dashboard_config_vars,
CloudwatchConfigType.ALARM.value: self._load_config_file,
}
self.CLOUDWATCH_CONFIG_TYPE_TO_UPDATE_FUNC_HEAD_NODE: Dict[str, Callable] = {
CloudwatchConfigType.AGENT.value: self._restart_cloudwatch_agent,
CloudwatchConfigType.DASHBOARD.value: self._put_cloudwatch_dashboard,
CloudwatchConfigType.ALARM.value: self._put_cloudwatch_alarm,
}
self.CLOUDWATCH_CONFIG_TYPE_TO_UPDATE_FUNC_WORKER_NODE: Dict[str, Callable] = {
CloudwatchConfigType.AGENT.value: self._restart_cloudwatch_agent,
CloudwatchConfigType.ALARM.value: self._put_cloudwatch_alarm,
}

def update_from_config(self, is_head_node: bool) -> None:
    """Discover and apply CloudWatch config updates as required.

    Iterates over every ``CloudwatchConfigType`` member and, for each type
    that ``cloudwatch_config_exists`` reports for ``self.provider_config``,
    delegates to ``_update_cloudwatch_config``.

    Args:
        is_head_node: whether this node is the head node.
    """
    # A single enum-driven loop replaces the per-type calls, so each config
    # type is processed exactly once and with the (config_type, is_head_node)
    # argument order that _update_cloudwatch_config expects.
    for config_type in CloudwatchConfigType:
        if CloudwatchHelper.cloudwatch_config_exists(
            self.provider_config, config_type.value
        ):
            self._update_cloudwatch_config(config_type.value, is_head_node)

def _ec2_health_check_waiter(self, node_id: str) -> None:
# wait for all EC2 instance checks to complete
Expand All @@ -66,14 +88,10 @@ def _ec2_health_check_waiter(self, node_id: str) -> None:
)
raise e

def _update_cloudwatch_config(self, is_head_node: bool, config_type: str) -> None:
"""Update remote CloudWatch configs at Parameter Store,
update hash tag value on node and perform associated operations
at CloudWatch console if local CloudWatch configs change.
Args:
is_head_node: whether this node is the head node.
config_type: CloudWatch config file type.
def _update_cloudwatch_config(self, config_type: str, is_head_node: bool) -> None:
"""
check whether update operations are needed in
cloudwatch related configs
"""
cwa_installed = self._setup_cwa()
param_name = self._get_ssm_param_name(config_type)
Expand All @@ -84,30 +102,30 @@ def _update_cloudwatch_config(self, is_head_node: bool, config_type: str) -> Non
)
cur_cw_config_hash = self._sha1_hash_file(config_type)
ssm_cw_config_hash = self._sha1_hash_json(cw_config_ssm)
# check if user updated Unified Cloudwatch Agent config file.
# check if user updated cloudwatch related config files.
# if so, perform corresponding actions.
if cur_cw_config_hash != ssm_cw_config_hash:
logger.info(
"Cloudwatch {} config file has changed.".format(config_type)
)
self._upload_config_to_ssm_and_set_hash_tag(config_type)
if config_type == "agent":
self._restart_cloudwatch_agent()
elif config_type == "dashboard":
self._put_cloudwatch_dashboard()
elif config_type == "alarm":
self._put_cloudwatch_alarm()
self.CLOUDWATCH_CONFIG_TYPE_TO_UPDATE_FUNC_HEAD_NODE.get(
config_type
)()
else:
head_node_hash = self._get_head_node_config_hash(config_type)
cur_node_hash = self._get_cur_node_config_hash(config_type)
if head_node_hash != cur_node_hash:
logger.info(
"Cloudwatch {} config file has changed.".format(config_type)
)
if config_type == "agent":
self._restart_cloudwatch_agent()
if config_type == "alarm":
self._put_cloudwatch_alarm()
update_func = (
self.CLOUDWATCH_CONFIG_TYPE_TO_UPDATE_FUNC_WORKER_NODE.get(
config_type
)
)
if update_func:
update_func()
self._update_cloudwatch_hash_tag_value(
self.node_id, head_node_hash, config_type
)
Expand All @@ -120,7 +138,9 @@ def _put_cloudwatch_dashboard(self) -> Dict[str, Any]:
dashboard_name_cluster = dashboard_config.get("name", self.cluster_name)
dashboard_name = self.cluster_name + "-" + dashboard_name_cluster

widgets = self._replace_dashboard_config_variables()
widgets = self._replace_dashboard_config_vars(
CloudwatchConfigType.DASHBOARD.value
)

response = self.cloudwatch_client.put_dashboard(
DashboardName=dashboard_name, DashboardBody=json.dumps({"widgets": widgets})
Expand All @@ -144,7 +164,7 @@ def _put_cloudwatch_dashboard(self) -> Dict[str, Any]:

def _put_cloudwatch_alarm(self) -> None:
"""put CloudWatch metric alarms read from config"""
param_name = self._get_ssm_param_name("alarm")
param_name = self._get_ssm_param_name(CloudwatchConfigType.ALARM.value)
data = json.loads(self._get_ssm_param(param_name))
for item in data:
item_out = copy.deepcopy(item)
Expand All @@ -158,7 +178,7 @@ def _put_cloudwatch_alarm(self) -> None:
logger.info("Successfully put alarms to CloudWatch console")

def _send_command_to_node(
self, document_name: str, parameters: List[str], node_id: str
self, document_name: str, parameters: Dict[str, List[str]], node_id: str
) -> Dict[str, Any]:
"""send SSM command to the given nodes"""
logger.debug(
Expand All @@ -177,10 +197,10 @@ def _send_command_to_node(
def _ssm_command_waiter(
self,
document_name: str,
parameters: List[str],
parameters: Dict[str, List[str]],
node_id: str,
retry_failed: bool = True,
) -> bool:
) -> Dict[str, Any]:
"""wait for SSM command to complete on all cluster nodes"""

# This waiter differs from the built-in SSM.Waiter by
Expand All @@ -192,7 +212,9 @@ def _ssm_command_waiter(
command_id = response["Command"]["CommandId"]

cloudwatch_config = self.provider_config["cloudwatch"]
agent_retryer_config = cloudwatch_config.get("agent").get("retryer", {})
agent_retryer_config = cloudwatch_config.get(
CloudwatchConfigType.AGENT.value
).get("retryer", {})
max_attempts = agent_retryer_config.get("max_attempts", 120)
delay_seconds = agent_retryer_config.get("delay_seconds", 30)
num_attempts = 0
Expand Down Expand Up @@ -283,26 +305,32 @@ def _replace_config_variables(

def _replace_all_config_variables(
self,
collection: Union[dict, list],
collection: Union[Dict[str, Any], str],
node_id: str,
cluster_name: str,
region: str,
) -> Tuple[(Union[dict, list], int)]:
) -> Union[str, Dict[str, Any]]:
"""
Replace known config variable occurrences in the input collection.
The input collection must be either a dict or list.
Returns a tuple consisting of the output collection and the number of
modified strings in the collection (which is not necessarily equal to
the number of variables replaced).
"""

for key in collection:
if type(collection) is dict:
value = collection.get(key)
index_key = key
elif type(collection) is list:
value = key
index_key = collection.index(key)
else:
raise ValueError(
f"Can't replace CloudWatch config variables "
f"in unsupported collection type: {type(collection)}."
f"Please check your CloudWatch JSON config files."
)
if type(value) is str:
collection[index_key] = self._replace_config_variables(
value, node_id, cluster_name, region
Expand Down Expand Up @@ -344,8 +372,8 @@ def _set_cloudwatch_ssm_config_param(
return self._get_default_empty_config_file_hash()
else:
logger.info(
"Failed to fetch CloudWatch {} config from SSM "
"parameter store.".format(config_type)
"Failed to fetch Unified CloudWatch Agent config from SSM "
"parameter store."
)
logger.error(e)
raise e
Expand All @@ -368,31 +396,25 @@ def _get_ssm_param(self, parameter_name: str) -> str:

def _sha1_hash_json(self, value: str) -> str:
"""calculate the json string sha1 hash"""
hash = hashlib.new("sha1")
sha1_hash = hashlib.new("sha1")
binary_value = value.encode("ascii")
hash.update(binary_value)
sha1_res = hash.hexdigest()
sha1_hash.update(binary_value)
sha1_res = sha1_hash.hexdigest()
return sha1_res

def _sha1_hash_file(self, config_type: str) -> str:
"""calculate the config file sha1 hash"""
if config_type == "agent":
config = self._replace_cwa_config_variables()
if config_type == "dashboard":
config = self._replace_dashboard_config_variables()
if config_type == "alarm":
config = self._load_config_file("alarm")
config = self.CLOUDWATCH_CONFIG_TYPE_TO_CONFIG_VARIABLE_REPLACE_FUNC.get(
config_type
)(config_type)
value = json.dumps(config)
sha1_res = self._sha1_hash_json(value)
return sha1_res

def _upload_config_to_ssm_and_set_hash_tag(self, config_type: str):
if config_type == "agent":
data = self._replace_cwa_config_variables()
if config_type == "dashboard":
data = self._replace_dashboard_config_variables()
if config_type == "alarm":
data = self._load_config_file("alarm")
data = self.CLOUDWATCH_CONFIG_TYPE_TO_CONFIG_VARIABLE_REPLACE_FUNC.get(
config_type
)(config_type)
sha1_hash_value = self._sha1_hash_file(config_type)
self._upload_config_to_ssm(data, config_type)
self._update_cloudwatch_hash_tag_value(
Expand All @@ -405,7 +427,7 @@ def _add_cwa_installed_tag(self, node_id: str) -> None:
Tags=[{"Key": CLOUDWATCH_AGENT_INSTALLED_TAG, "Value": "True"}],
)
logger.info(
"Successfully add Unified Cloudwatch Agent installed "
"Successfully add Unified CloudWatch Agent installed "
"tag on {}".format(node_id)
)

Expand Down Expand Up @@ -444,12 +466,12 @@ def _upload_config_to_ssm(self, param: Dict[str, Any], config_type: str):
param_name = self._get_ssm_param_name(config_type)
self._put_ssm_param(param, param_name)

def _replace_cwa_config_variables(self) -> Dict[str, Any]:
def _replace_cwa_config_vars(self, config_type: str) -> Dict[str, Any]:
"""
replace {instance_id}, {region}, {cluster_name}
variable occurrences in Unified Cloudwatch Agent config file
"""
cwa_config = self._load_config_file("agent")
cwa_config = self._load_config_file(config_type)
self._replace_all_config_variables(
cwa_config,
self.node_id,
Expand All @@ -458,11 +480,11 @@ def _replace_cwa_config_variables(self) -> Dict[str, Any]:
)
return cwa_config

def _replace_dashboard_config_variables(self) -> List[Dict[str, Any]]:
def _replace_dashboard_config_vars(self, config_type: str) -> List[str]:
"""
replace known variable occurrences in CloudWatch Dashboard config file
"""
data = self._load_config_file("dashboard")
data = self._load_config_file(config_type)
widgets = []
for item in data:
item_out = self._replace_all_config_variables(
Expand All @@ -471,16 +493,15 @@ def _replace_dashboard_config_variables(self) -> List[Dict[str, Any]]:
self.cluster_name,
self.provider_config["region"],
)
item_out = copy.deepcopy(item)
widgets.append(item_out)
return widgets

def _replace_alarm_config_variables(self) -> List[Dict[str, Any]]:
def _replace_alarm_config_vars(self, config_type: str) -> List[str]:
"""
replace {instance_id}, {region}, {cluster_name}
variable occurrences in cloudwatch alarm config file
"""
data = self._load_config_file("alarm")
data = self._load_config_file(config_type)
param_data = []
for item in data:
item_out = copy.deepcopy(item)
Expand All @@ -494,11 +515,11 @@ def _replace_alarm_config_variables(self) -> List[Dict[str, Any]]:
return param_data

def _restart_cloudwatch_agent(self) -> None:
"""restart Unified Cloudwatch Agent"""
cwa_param_name = self._get_ssm_param_name("agent")
"""restart Unified CloudWatch Agent"""
cwa_param_name = self._get_ssm_param_name(CloudwatchConfigType.AGENT.value)
logger.info(
"Restarting Unified Cloudwatch Agent package on {} node(s).".format(
(self.node_id)
"Restarting Unified CloudWatch Agent package on node {}.".format(
self.node_id
)
)
self._stop_cloudwatch_agent()
Expand Down Expand Up @@ -691,7 +712,9 @@ def resolve_instance_profile_name(
default ray instance profile name if cloudwatch config file
doesn't exist.
"""
cwa_cfg_exists = CloudwatchHelper.cloudwatch_config_exists(config, "agent")
cwa_cfg_exists = CloudwatchHelper.cloudwatch_config_exists(
config, CloudwatchConfigType.AGENT.value
)
return (
CLOUDWATCH_RAY_INSTANCE_PROFILE
if cwa_cfg_exists
Expand All @@ -712,7 +735,9 @@ def resolve_iam_role_name(
default cloudwatch iam role name if cloudwatch config file exists.
default ray iam role name if cloudwatch config file doesn't exist.
"""
cwa_cfg_exists = CloudwatchHelper.cloudwatch_config_exists(config, "agent")
cwa_cfg_exists = CloudwatchHelper.cloudwatch_config_exists(
config, CloudwatchConfigType.AGENT.value
)
return CLOUDWATCH_RAY_IAM_ROLE if cwa_cfg_exists else default_iam_role_name

@staticmethod
Expand All @@ -731,7 +756,9 @@ def resolve_policy_arns(
related operations if cloudwatch agent config is specifed in
cluster config file.
"""
cwa_cfg_exists = CloudwatchHelper.cloudwatch_config_exists(config, "agent")
cwa_cfg_exists = CloudwatchHelper.cloudwatch_config_exists(
config, CloudwatchConfigType.AGENT.value
)
if cwa_cfg_exists:
cloudwatch_managed_policy = {
"Version": "2012-10-17",
Expand Down
Loading

0 comments on commit 4bc9c75

Please sign in to comment.