From 1e352bd711610dbcf283e793ba5fbfa1638f6127 Mon Sep 17 00:00:00 2001 From: rzlim08 <37033997+rzlim08@users.noreply.github.com> Date: Mon, 25 Sep 2023 10:05:37 -0700 Subject: [PATCH] [CZID-8390] Add sqs step notifications (#116) * add current sqs queue to sqs notification step * run black * terraform fmt * typing * change to using sns * add formatting * add options to turn off step notifications --- Dockerfile | 2 + main.tf | 1 + miniwdl-plugins/sns_notification/README.md | 1 + miniwdl-plugins/sns_notification/setup.py | 29 ++++ .../sns_notification/sns_notification.py | 80 ++++++++++ terraform/modules/swipe-sfn-batch-job/main.tf | 1 + .../modules/swipe-sfn-batch-job/variables.tf | 6 + terraform/modules/swipe-sfn/main.tf | 27 ++-- terraform/modules/swipe-sfn/notifications.tf | 7 +- terraform/modules/swipe-sfn/variables.tf | 7 + test/terraform/moto/main.tf | 4 +- test/test_wdl.py | 145 ++++++++++++------ variables.tf | 6 + version | 2 +- 14 files changed, 249 insertions(+), 69 deletions(-) create mode 100644 miniwdl-plugins/sns_notification/README.md create mode 100644 miniwdl-plugins/sns_notification/setup.py create mode 100644 miniwdl-plugins/sns_notification/sns_notification.py diff --git a/Dockerfile b/Dockerfile index 1e745324..88c8d99e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -46,6 +46,7 @@ RUN apt-get -q update && apt-get -q install -y \ # upgrade because of this issue https://github.com/chanzuckerberg/miniwdl/issues/607 in miniwdl RUN pip3 install importlib-metadata==4.13.0 RUN pip3 install miniwdl==${MINIWDL_VERSION} +RUN pip3 install urllib3==1.26.16 RUN curl -Ls https://github.com/chanzuckerberg/s3parcp/releases/download/v1.0.1/s3parcp_1.0.1_linux_amd64.tar.gz | tar -C /usr/bin -xz s3parcp @@ -62,6 +63,7 @@ ADD miniwdl-plugins miniwdl-plugins RUN pip install miniwdl-plugins/s3upload RUN pip install miniwdl-plugins/sfn_wdl RUN pip install miniwdl-plugins/s3parcp_download +RUN pip install miniwdl-plugins/sns_notification RUN cd /usr/bin; curl -O https://amazon-ecr-credential-helper-releases.s3.amazonaws.com/0.4.0/linux-amd64/docker-credential-ecr-login RUN chmod +x /usr/bin/docker-credential-ecr-login diff --git a/main.tf b/main.tf index f9e8687f..a114f275 100644 --- a/main.tf +++ b/main.tf @@ -64,6 +64,7 @@ module "sfn" { stage_vcpu_defaults = var.stage_vcpu_defaults extra_env_vars = var.extra_env_vars sqs_queues = var.sqs_queues + step_notifications = var.step_notifications call_cache = var.call_cache output_status_json_files = var.output_status_json_files tags = var.tags diff --git a/miniwdl-plugins/sns_notification/README.md b/miniwdl-plugins/sns_notification/README.md new file mode 100644 index 00000000..243062ab --- /dev/null +++ b/miniwdl-plugins/sns_notification/README.md @@ -0,0 +1 @@ +# sns_notifications \ No newline at end of file diff --git a/miniwdl-plugins/sns_notification/setup.py b/miniwdl-plugins/sns_notification/setup.py new file mode 100644 index 00000000..ce3e66ca --- /dev/null +++ b/miniwdl-plugins/sns_notification/setup.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 +from setuptools import setup +from os import path + +this_directory = path.abspath(path.dirname(__file__)) +with open(path.join(path.dirname(__file__), "README.md")) as f: + long_description = f.read() + +setup( + name="sns_notification", + version="0.0.1", + description="miniwdl plugin for notification of task completion to Amazon SQS", + url="https://github.com/chanzuckerberg/swipe", + project_urls={}, + long_description=long_description, + long_description_content_type="text/markdown", + author="", + py_modules=["sns_notification"], + python_requires=">=3.6", + setup_requires=["reentry"], + install_requires=["boto3"], + reentry_register=True, + entry_points={ + "miniwdl.plugin.task": ["sns_notification_task = sns_notification:task"], + "miniwdl.plugin.workflow": [ + "sns_notification_workflow = sns_notification:workflow" + ], + }, +) diff --git a/miniwdl-plugins/sns_notification/sns_notification.py b/miniwdl-plugins/sns_notification/sns_notification.py new file mode 100644 index 00000000..da0e9e91 --- /dev/null +++ b/miniwdl-plugins/sns_notification/sns_notification.py @@ -0,0 +1,80 @@ +""" +Send SNS notifications after each miniwdl step +""" + +import os +import json +from typing import Dict +from datetime import datetime +from WDL import values_to_json +from WDL._util import StructuredLogMessage as _ + +import boto3 + +sns_client = boto3.client("sns", endpoint_url=os.getenv("AWS_ENDPOINT_URL")) +topic_arn = os.getenv('STEP_NOTIFICATION_TOPIC_ARN') + + +def process_outputs(outputs: Dict): + """process outputs dict into string to be passed into SQS""" + # only stringify for now + return json.dumps(outputs) + + +def send_message(attr, body): + """send message to SNS""" + sns_resp = sns_client.publish( + TopicArn=topic_arn, + Message=body, + MessageAttributes=attr, + ) + return sns_resp + + +def task(cfg, logger, run_id, run_dir, task, **recv): + """ + on completion of any task sends a message to sns with the output files + """ + log = logger.getChild("sns_step_notification") + + # ignore inputs + recv = yield recv + # ignore command/runtime/container + recv = yield recv + + log.info(_("sending message to sns")) + + if topic_arn: + message_attributes = { + "WorkflowName": {"DataType": "String", "StringValue": run_id[0]}, + "TaskName": {"DataType": "String", "StringValue": run_id[-1]}, + "ExecutionId": { + "DataType": "String", + "StringValue": "execution_id_to_be_passed_in", + }, + } + + outputs = process_outputs(values_to_json(recv["outputs"])) + message_body = { + "version": "0", + "id": "0", + "detail-type": "Step Functions Execution Step Notification", + "source": "aws.batch", + "account": "", + "time": datetime.now().strftime("%Y-%m-%dT%H:%M:%SZ"), + "resources": [], + "detail": outputs, + } + send_message(message_attributes, json.dumps(message_body)) + + yield recv + + +def workflow(cfg, logger, run_id, run_dir, workflow, **recv): + log = logger.getChild("sns_step_notification") + + # ignore inputs + recv = yield recv + + log.info(_("ignores workflow calls")) + yield recv diff --git a/terraform/modules/swipe-sfn-batch-job/main.tf b/terraform/modules/swipe-sfn-batch-job/main.tf index 182e8960..82cfc17f 100644 --- a/terraform/modules/swipe-sfn-batch-job/main.tf +++ b/terraform/modules/swipe-sfn-batch-job/main.tf @@ -34,6 +34,7 @@ locals { "MINIWDL__DOWNLOAD_CACHE__DISABLE_PATTERNS" = "[\"s3://swipe-samples-*/*\"]", "DOWNLOAD_CACHE_MAX_GB" = "500", "WDL_PASSTHRU_ENVVARS" = join(" ", [for k, v in var.extra_env_vars : k]), + "STEP_NOTIFICATION_TOPIC_ARN" = var.sfn_notification_topic_arn, "OUTPUT_STATUS_JSON_FILES" = tostring(var.output_status_json_files) }) container_env_vars = { "environment" : [for k in sort(keys(local.batch_env_vars)) : { "name" : k, "value" : local.batch_env_vars[k] }] } diff --git a/terraform/modules/swipe-sfn-batch-job/variables.tf b/terraform/modules/swipe-sfn-batch-job/variables.tf index 858084cf..4c52b5d3 100644 --- a/terraform/modules/swipe-sfn-batch-job/variables.tf +++ b/terraform/modules/swipe-sfn-batch-job/variables.tf @@ -70,3 +70,9 @@ variable "docker_network" { type = string default = "" } + +variable "sfn_notification_topic_arn" { + description = "ARN of notification sns topic" + type = string +} + diff --git a/terraform/modules/swipe-sfn/main.tf b/terraform/modules/swipe-sfn/main.tf index 7135ba87..7c2f291c 100644 --- a/terraform/modules/swipe-sfn/main.tf +++ b/terraform/modules/swipe-sfn/main.tf @@ -32,19 +32,20 @@ resource "aws_iam_role_policy_attachment" "swipe_sfn_service" { } module "batch_job" { - source = "../swipe-sfn-batch-job" - app_name = var.app_name - batch_job_docker_image = var.batch_job_docker_image - batch_job_timeout_seconds = var.batch_job_timeout_seconds - miniwdl_dir = var.miniwdl_dir - workspace_s3_prefixes = var.workspace_s3_prefixes - wdl_workflow_s3_prefix = var.wdl_workflow_s3_prefix - job_policy_arns = var.job_policy_arns - extra_env_vars = var.extra_env_vars - docker_network = var.docker_network - call_cache = var.call_cache - output_status_json_files = var.output_status_json_files - tags = var.tags + source = "../swipe-sfn-batch-job" + app_name = var.app_name + batch_job_docker_image = var.batch_job_docker_image + batch_job_timeout_seconds = var.batch_job_timeout_seconds + miniwdl_dir = var.miniwdl_dir + workspace_s3_prefixes = var.workspace_s3_prefixes + wdl_workflow_s3_prefix = var.wdl_workflow_s3_prefix + job_policy_arns = var.job_policy_arns + extra_env_vars = var.extra_env_vars + docker_network = var.docker_network + call_cache = var.call_cache + output_status_json_files = var.output_status_json_files + sfn_notification_topic_arn = length(var.sqs_queues) > 0 && var.step_notifications ? aws_sns_topic.sfn_notifications_topic[0].arn : "" + tags = var.tags } module "sfn_io_helper" { diff --git a/terraform/modules/swipe-sfn/notifications.tf b/terraform/modules/swipe-sfn/notifications.tf index c0a04eaf..c7025b47 100644 --- a/terraform/modules/swipe-sfn/notifications.tf +++ b/terraform/modules/swipe-sfn/notifications.tf @@ -64,9 +64,10 @@ data "aws_iam_policy_document" "sfn_notifications_topic_policy_document" { resource "aws_sns_topic_subscription" "sfn_notifications_sqs_target" { for_each = var.sqs_queues - topic_arn = aws_sns_topic.sfn_notifications_topic[0].arn - protocol = "sqs" - endpoint = aws_sqs_queue.sfn_notifications_queue[each.key].arn + topic_arn = aws_sns_topic.sfn_notifications_topic[0].arn + protocol = "sqs" + endpoint = aws_sqs_queue.sfn_notifications_queue[each.key].arn + raw_message_delivery = true } resource "aws_sqs_queue" "sfn_notifications_queue" { diff --git a/terraform/modules/swipe-sfn/variables.tf b/terraform/modules/swipe-sfn/variables.tf index c137aea1..8f58e206 100644 --- a/terraform/modules/swipe-sfn/variables.tf +++ b/terraform/modules/swipe-sfn/variables.tf @@ -116,3 +116,10 @@ variable "metrics_schedule" { type = string default = "rate(1 minute)" } + +variable "step_notifications" { + description = "Boolean to determine whether or not to use send step notifications with SNS" + type = bool + default = false +} + diff --git a/test/terraform/moto/main.tf b/test/terraform/moto/main.tf index 6b974dd0..d6a6a4be 100644 --- a/test/terraform/moto/main.tf +++ b/test/terraform/moto/main.tf @@ -32,7 +32,7 @@ module "swipetest" { "Two" : { "spot" : 12800, "on_demand" : 256000 }, } - workspace_s3_prefixes = ["swipe-test"] - + workspace_s3_prefixes = ["swipe-test"] output_status_json_files = true + step_notifications = true } diff --git a/test/test_wdl.py b/test/test_wdl.py index 419a7331..7648a8b1 100644 --- a/test/test_wdl.py +++ b/test/test_wdl.py @@ -241,7 +241,7 @@ class TestSFNWDL(unittest.TestCase): def setUp(self) -> None: - self.logger = logging.getLogger('test-wdl') + self.logger = logging.getLogger("test-wdl") self.s3 = boto3.resource("s3", endpoint_url="http://localhost:9000") self.s3_client = boto3.client("s3", endpoint_url="http://localhost:9000") @@ -258,10 +258,16 @@ def setUp(self) -> None: self.wdl_two_obj = self.test_bucket.Object("test-two-v1.0.0.wdl") self.wdl_two_obj.put(Body=test_two_wdl.encode()) self.wdl_obj_temp = self.test_bucket.Object("test-temp-v1.0.0.wdl") - self.wdl_obj_temp.put(Body=test_wdl_temp.replace("swipe_test", "temp_test").encode()) + self.wdl_obj_temp.put( + Body=test_wdl_temp.replace("swipe_test", "temp_test").encode() + ) with NamedTemporaryFile(suffix=".wdl.zip") as f: - Zip.build(load(join(dirname(realpath(__file__)), 'multi_wdl/run.wdl')), f.name, self.logger) + Zip.build( + load(join(dirname(realpath(__file__)), "multi_wdl/run.wdl")), + f.name, + self.logger, + ) self.wdl_zip_object = self.test_bucket.Object("test-v1.0.0.wdl.zip") self.wdl_zip_object.upload_file(f.name) @@ -291,13 +297,31 @@ def tearDown(self) -> None: ) self.test_bucket.delete() + def retrieve_message(self, url: str) -> Dict: + """Retrieve a single SQS message and delete it from queue""" + resp = self.sqs.receive_message( + QueueUrl=url, + MaxNumberOfMessages=1, + ) + # If no messages, just return + if not resp.get("Messages", None): + return {} + + message = resp["Messages"][0] + receipt_handle = message["ReceiptHandle"] + self.sqs.delete_message( + QueueUrl=url, + ReceiptHandle=receipt_handle, + ) + return json.loads(message["Body"]) + def _wait_sfn( self, sfn_input: Dict, sfn_arn: str, n_stages: int = 1, - expect_success: bool = True - ) -> Tuple[str, Dict, List[Dict]]: + expect_success: bool = True, + ) -> Tuple[str, Dict, List[Dict], List[Dict]]: execution_name = "swipe-test-{}".format(int(time.time())) res = self.sfn.start_execution( stateMachineArn=sfn_arn, name=execution_name, input=json.dumps(sfn_input) @@ -311,7 +335,10 @@ def _wait_sfn( print("printing execution history", file=sys.stderr) seen_events = set() - for event in sorted(self.sfn.get_execution_history(executionArn=arn)["events"], key=lambda x: x["id"]): + for event in sorted( + self.sfn.get_execution_history(executionArn=arn)["events"], + key=lambda x: x["id"], + ): if event["id"] not in seen_events: details = {} for key in event.keys(): @@ -325,36 +352,46 @@ def _wait_sfn( details.get("name", ""), json.loads(details.get("parameters", "{}")).get("FunctionName", ""), file=sys.stderr, - ) + ) if "taskSubmittedEventDetails" in event: - if event.get("taskSubmittedEventDetails", {}).get("resourceType") == "batch": - job_id = json.loads(event["taskSubmittedEventDetails"]["output"])["JobId"] + if ( + event.get("taskSubmittedEventDetails", {}).get("resourceType") + == "batch" + ): + job_id = json.loads( + event["taskSubmittedEventDetails"]["output"] + )["JobId"] print(f"Batch job ID {job_id}", file=sys.stderr) job_desc = self.batch.describe_jobs(jobs=[job_id])["jobs"][0] try: - log_group_name = job_desc["container"]["logConfiguration"]["options"]["awslogs-group"] + log_group_name = job_desc["container"]["logConfiguration"][ + "options" + ]["awslogs-group"] except KeyError: log_group_name = "/aws/batch/job" response = self.logs.get_log_events( logGroupName=log_group_name, - logStreamName=job_desc["container"]["logStreamName"] + logStreamName=job_desc["container"]["logStreamName"], ) for log_event in response["events"]: print(log_event["message"], file=sys.stderr) seen_events.add(event["id"]) - resp = self.sqs.receive_message( - QueueUrl=self.state_change_queue_url, - MaxNumberOfMessages=n_stages, - ) - print(resp) - messages = resp["Messages"] + status_notification = [] + step_notification = [] + while message := self.retrieve_message(self.state_change_queue_url): + if message["source"] == "aws.batch": + step_notification.append(message) + elif message["source"] == "aws.states": + status_notification.append(message) if expect_success: self.assertEqual(description["status"], "SUCCEEDED", description) else: self.assertEqual(description["status"], "FAILED", description) - return arn, description, messages + + self.assertEqual(len(status_notification), n_stages) + return arn, description, status_notification, step_notification def test_simple_sfn_wdl_workflow(self): output_prefix = "out-1" @@ -369,23 +406,27 @@ def test_simple_sfn_wdl_workflow(self): }, } - arn, description, messages = self._wait_sfn(sfn_input, self.single_sfn_arn) + arn, description, messages, step_messages = self._wait_sfn( + sfn_input, self.single_sfn_arn + ) output = json.loads(description["output"]) - self.assertEqual(output["Result"], { - "swipe_test.out_world": f"s3://{self.input_obj.bucket_name}/{output_prefix}/test-1/out_world.txt", - "swipe_test.out_goodbye": f"s3://{self.input_obj.bucket_name}/{output_prefix}/test-1/out_goodbye.txt", - "swipe_test.out_farewell": f"s3://{self.input_obj.bucket_name}/{output_prefix}/test-1/out_farewell.txt", - }) + self.assertEqual( + output["Result"], + { + "swipe_test.out_world": f"s3://{self.input_obj.bucket_name}/{output_prefix}/test-1/out_world.txt", + "swipe_test.out_goodbye": f"s3://{self.input_obj.bucket_name}/{output_prefix}/test-1/out_goodbye.txt", + "swipe_test.out_farewell": f"s3://{self.input_obj.bucket_name}/{output_prefix}/test-1/out_farewell.txt", + }, + ) outputs_obj = self.test_bucket.Object(f"{output_prefix}/test-1/out_world.txt") output_text = outputs_obj.get()["Body"].read().decode() self.assertEqual(output_text, "hello\nworld\n") - self.assertEqual(json.loads(messages[0]["Body"])["detail"]["executionArn"], arn) - self.assertEqual( - json.loads(messages[0]["Body"])["detail"]["lastCompletedStage"], "run" - ) + self.assertEqual(messages[0]["detail"]["executionArn"], arn) + self.assertEqual(messages[0]["detail"]["lastCompletedStage"], "run") + self.assertEqual(len(step_messages), 3) # 3 steps to the inputs def test_https_inputs(self): output_prefix = "out-https-1" @@ -415,9 +456,12 @@ def test_failing_wdl_workflow(self): }, } - arn, description, messages = self._wait_sfn(sfn_input, self.single_sfn_arn, expect_success=False) - errorType = (self.sfn.get_execution_history(executionArn=arn)["events"] - [-1]["executionFailedEventDetails"]["error"]) + arn, description, messages, _ = self._wait_sfn( + sfn_input, self.single_sfn_arn, expect_success=False + ) + errorType = self.sfn.get_execution_history(executionArn=arn)["events"][-1][ + "executionFailedEventDetails" + ]["error"] self.assertTrue(errorType in ["UncaughtError", "RunFailed"]) def test_temp_tag(self): @@ -437,8 +481,7 @@ def test_temp_tag(self): # test temporary tag is there for intermediate file temporary_tagset = self.s3_client.get_object_tagging( - Bucket="swipe-test", - Key=f"{output_prefix}/test-temp-1/temporary.txt" + Bucket="swipe-test", Key=f"{output_prefix}/test-temp-1/temporary.txt" ).get("TagSet", []) self.assertEqual(len(temporary_tagset), 1) self.assertEqual(temporary_tagset[0].get("Key"), "intermediate_output") @@ -446,8 +489,7 @@ def test_temp_tag(self): # test temporary tag got removed for output file output_tagset = self.s3_client.get_object_tagging( - Bucket="swipe-test", - Key=f"{output_prefix}/test-temp-1/out_world.txt" + Bucket="swipe-test", Key=f"{output_prefix}/test-temp-1/out_world.txt" ).get("TagSet", []) self.assertEqual(len(output_tagset), 0) @@ -469,7 +511,7 @@ def test_staged_sfn_wdl_workflow(self): }, } - _, _, messages = self._wait_sfn(sfn_input, self.stage_sfn_arn, 2) + _, _, messages, _ = self._wait_sfn(sfn_input, self.stage_sfn_arn, 2) outputs_obj = self.test_bucket.Object( f"{output_prefix}/test-1/happy_message.txt" @@ -477,12 +519,8 @@ def test_staged_sfn_wdl_workflow(self): output_text = outputs_obj.get()["Body"].read().decode() self.assertEqual(output_text, "hello\nworld\n:)\n") - self.assertEqual( - json.loads(messages[0]["Body"])["detail"]["lastCompletedStage"], "one" - ) - self.assertEqual( - json.loads(messages[1]["Body"])["detail"]["lastCompletedStage"], "two" - ) + self.assertEqual(messages[0]["detail"]["lastCompletedStage"], "one") + self.assertEqual(messages[1]["detail"]["lastCompletedStage"], "two") def test_call_cache(self): output_prefix = "out-3" @@ -514,24 +552,28 @@ def test_call_cache(self): # clear cache to simulate getting cut off the step before this one objects = self.s3_client.list_objects_v2( - Bucket=self.test_bucket.name, - Prefix=f"{output_prefix}/test-1/cache/add_farewell/", + Bucket=self.test_bucket.name, + Prefix=f"{output_prefix}/test-1/cache/add_farewell/", )["Contents"] self.test_bucket.Object(objects[0]["Key"]).delete() objects = self.s3_client.list_objects_v2( - Bucket=self.test_bucket.name, - Prefix=f"{output_prefix}/test-1/cache/swipe_test/", + Bucket=self.test_bucket.name, + Prefix=f"{output_prefix}/test-1/cache/swipe_test/", )["Contents"] self.test_bucket.Object(objects[0]["Key"]).delete() self.test_bucket.Object(out_json_path).delete() self._wait_sfn(sfn_input, self.single_sfn_arn) - outputs = json.loads(self.test_bucket.Object(out_json_path).get()["Body"].read().decode()) + outputs = json.loads( + self.test_bucket.Object(out_json_path).get()["Body"].read().decode() + ) for v in outputs.values(): self.assert_(v.startswith("s3://"), f"{v} does not start with 's3://'") - outputs_obj = self.test_bucket.Object(f"{output_prefix}/test-1/out_farewell.txt") + outputs_obj = self.test_bucket.Object( + f"{output_prefix}/test-1/out_farewell.txt" + ) output_text = outputs_obj.get()["Body"].read().decode() self.assertEqual(output_text, "cache_break\nfarewell\n") @@ -550,7 +592,7 @@ def test_zip_wdls(self): self._wait_sfn(sfn_input, self.single_sfn_arn) self.sqs.receive_message( - QueueUrl=self.state_change_queue_url, MaxNumberOfMessages=1 + QueueUrl=self.state_change_queue_url, MaxNumberOfMessages=1 ) outputs_obj = self.test_bucket.Object(f"{output_prefix}/test-1/out_bar.txt") @@ -575,7 +617,10 @@ def test_status_reporting(self): status_json = json.loads( self.test_bucket.Object( f"{output_prefix}/test-1/test_status2.json", - ).get()["Body"].read().decode(), + ) + .get()["Body"] + .read() + .decode(), ) self.assertEqual(status_json["add_world"]["status"], "uploaded") self.assertEqual(status_json["add_goodbye"]["status"], "uploaded") diff --git a/variables.tf b/variables.tf index 2e08455e..8b4d1856 100644 --- a/variables.tf +++ b/variables.tf @@ -194,3 +194,9 @@ variable "output_status_json_files" { type = bool default = false } + +variable "step_notifications" { + description = "Boolean to determine whether or not to use send step notifications with SNS" + type = bool + default = false +} diff --git a/version b/version index 2aca8c01..d6379b87 100644 --- a/version +++ b/version @@ -1 +1 @@ -v1.4.6 +v1.4.8