diff --git a/Makefile b/Makefile index 021584cc..4a3c6461 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ SHELL=/bin/bash -o pipefail deploy-mock: rm terraform.tfstate || true aws ssm put-parameter --name /mock-aws/service/ecs/optimized-ami/amazon-linux-2/recommended/image_id --value ami-12345678 --type String --endpoint-url http://localhost:9000 - cp test/mock.tf .; unset TF_CLI_ARGS_init; terraform init; TF_VAR_mock=true TF_VAR_app_name=swipe-test TF_VAR_batch_ec2_instance_types='["optimal"]' terraform apply --auto-approve + cp test/mock.tf .; unset TF_CLI_ARGS_init; terraform init; TF_VAR_mock=true TF_VAR_app_name=swipe-test TF_VAR_batch_ec2_instance_types='["optimal"]' TF_VAR_sqs_queues='{"notifications":{"dead_letter": false}}' terraform apply --auto-approve lint: flake8 . diff --git a/terraform/modules/sfn-io-helper-lambdas/app/app.py b/terraform/modules/sfn-io-helper-lambdas/app/app.py index ad7a1148..05095d56 100644 --- a/terraform/modules/sfn-io-helper-lambdas/app/app.py +++ b/terraform/modules/sfn-io-helper-lambdas/app/app.py @@ -50,6 +50,10 @@ def preprocess_input(sfn_data, _): def process_stage_output(sfn_data, _): assert sfn_data["CurrentState"].endswith("ReadOutput") + stage_io.broadcast_stage_complete( + sfn_data["ExecutionId"], + sfn_data["CurrentState"][:-len("ReadOutput")], + ) sfn_state = stage_io.read_state_from_s3(sfn_state=sfn_data["Input"], current_state=sfn_data["CurrentState"]) stage_io.link_outputs(sfn_state) sfn_state = stage_io.trim_batch_job_details(sfn_state=sfn_state) diff --git a/terraform/modules/sfn-io-helper-lambdas/app/sfn_io_helper/__init__.py b/terraform/modules/sfn-io-helper-lambdas/app/sfn_io_helper/__init__.py index bb1d4baf..07fe6331 100644 --- a/terraform/modules/sfn-io-helper-lambdas/app/sfn_io_helper/__init__.py +++ b/terraform/modules/sfn-io-helper-lambdas/app/sfn_io_helper/__init__.py @@ -6,6 +6,7 @@ batch = boto3.client("batch", endpoint_url=os.getenv("AWS_ENDPOINT_URL")) stepfunctions = boto3.client("stepfunctions", endpoint_url=os.getenv("AWS_ENDPOINT_URL")) cloudwatch = boto3.client("cloudwatch", endpoint_url=os.getenv("AWS_ENDPOINT_URL")) +sqs = boto3.client("sqs", endpoint_url=os.getenv("AWS_ENDPOINT_URL")) def s3_object(uri): diff --git a/terraform/modules/sfn-io-helper-lambdas/app/sfn_io_helper/stage_io.py b/terraform/modules/sfn-io-helper-lambdas/app/sfn_io_helper/stage_io.py index 87c1e11c..c9334486 100644 --- a/terraform/modules/sfn-io-helper-lambdas/app/sfn_io_helper/stage_io.py +++ b/terraform/modules/sfn-io-helper-lambdas/app/sfn_io_helper/stage_io.py @@ -3,10 +3,12 @@ import json import logging from typing import List +from datetime import datetime +from uuid import uuid4 from botocore import xform_name -from . import s3_object +from . import s3_object, sqs logger = logging.getLogger() @@ -121,3 +123,51 @@ def preprocess_sfn_input(sfn_state, aws_region, aws_account_id, state_machine_na link_outputs(sfn_state) return sfn_state + + +def broadcast_stage_complete(execution_id: str, stage: str): + if "SQS_QUEUE_URLS" not in os.environ: + return + + sqs_queue_urls = os.environ["SQS_QUEUE_URLS"].split(",") + + assert len(execution_id.split(":")) == 8 + _, _, _, aws_region, aws_account_id, _, state_machine_name, execution_name = execution_id.split(":") + + state_machine_arn = f"arn:aws:states:{aws_region}:{aws_account_id}:stateMachine:{state_machine_name}" + execution_arn = f"arn:aws:states:{aws_region}:{aws_account_id}:execution:{execution_id}" + + body = json.dumps({ + "version": "0", + "id": str(uuid4()), + "detail-type": "Step Functions Execution Status Change", + "source": "aws.states", + "account": aws_account_id, + "time": datetime.now().strftime('%Y-%m-%dT%H:%M:%SZ'), + "region": aws_region, + "resources": [execution_arn], + "detail": { + "executionArn": execution_arn, + "stateMachineArn": state_machine_arn, + "name": execution_name, + "status": "RUNNING", + "lastCompletedStage": re.sub(r'(? "${defaults.spot}" }, { @@ -139,7 +147,6 @@ resource "aws_lambda_function" "lambda" { }, { for stage, defaults in var.stage_vcpu_defaults : "${stage}EC2VcpuDefault" => "${defaults.on_demand}" }, - ) } } diff --git a/terraform/modules/sfn-io-helper-lambdas/variables.tf b/terraform/modules/sfn-io-helper-lambdas/variables.tf index 0329ee9d..64078c22 100644 --- a/terraform/modules/sfn-io-helper-lambdas/variables.tf +++ b/terraform/modules/sfn-io-helper-lambdas/variables.tf @@ -58,3 +58,12 @@ variable "aws_account_id" { type = string } +variable "sfn_notification_queue_arns" { + description = "ARNs of notification SQS queues" + type = list(string) +} + +variable "sfn_notification_queue_urls" { + description = "URLs of notification SQS queues" + type = list(string) +} diff --git a/terraform/modules/swipe-sfn/main.tf b/terraform/modules/swipe-sfn/main.tf index 87f8154e..1bb4c811 100644 --- a/terraform/modules/swipe-sfn/main.tf +++ b/terraform/modules/swipe-sfn/main.tf @@ -45,17 +45,19 @@ module "batch_job" { } module "sfn_io_helper" { - source = "../sfn-io-helper-lambdas" - app_name = var.app_name - mock = var.mock - aws_region = data.aws_region.current.name - aws_account_id = data.aws_caller_identity.current.account_id - batch_queue_arns = [var.batch_spot_job_queue_arn, var.batch_on_demand_job_queue_arn] - workspace_s3_prefix = var.workspace_s3_prefix - wdl_workflow_s3_prefix = var.wdl_workflow_s3_prefix - stage_memory_defaults = var.stage_memory_defaults - stage_vcpu_defaults = var.stage_vcpu_defaults - tags = var.tags + source = "../sfn-io-helper-lambdas" + app_name = var.app_name + mock = var.mock + aws_region = data.aws_region.current.name + aws_account_id = data.aws_caller_identity.current.account_id + batch_queue_arns = [var.batch_spot_job_queue_arn, var.batch_on_demand_job_queue_arn] + workspace_s3_prefix = var.workspace_s3_prefix + wdl_workflow_s3_prefix = var.wdl_workflow_s3_prefix + stage_memory_defaults = var.stage_memory_defaults + stage_vcpu_defaults = var.stage_vcpu_defaults + sfn_notification_queue_arns = [for name, queue in aws_sqs_queue.sfn_notifications_queue : queue.arn] + sfn_notification_queue_urls = [for name, queue in aws_sqs_queue.sfn_notifications_queue : queue.url] + tags = var.tags } resource "aws_sfn_state_machine" "swipe_single_wdl" { diff --git a/test/test_wdl.py b/test/test_wdl.py index 41a37a9a..c472f3bb 100644 --- a/test/test_wdl.py +++ b/test/test_wdl.py @@ -58,6 +58,7 @@ def setUp(self) -> None: self.sfn = boto3.client("stepfunctions", endpoint_url="http://localhost:8083") self.test_bucket = self.s3.create_bucket(Bucket="swipe-test") self.lamb = boto3.client("lambda", endpoint_url="http://localhost:9000") + self.sqs = boto3.client("sqs", endpoint_url="http://localhost:9000") def test_simple_sfn_wdl_workflow(self): wdl_obj = self.test_bucket.Object("test-v1.0.0.wdl") @@ -83,8 +84,6 @@ def test_simple_sfn_wdl_workflow(self): input=json.dumps(sfn_input)) arn = res["executionArn"] - assert res - start = time.time() description = self.sfn.describe_execution(executionArn=arn) while description["status"] == "RUNNING" and time.time() < start + 2 * 60: @@ -94,10 +93,15 @@ def test_simple_sfn_wdl_workflow(self): for event in self.sfn.get_execution_history(executionArn=arn)["events"]: print(event, file=sys.stderr) - assert description["status"] == "SUCCEEDED", description + self.assertEqual(description["status"], "SUCCEEDED") outputs_obj = self.test_bucket.Object(f"{output_prefix}/test-1/out.txt") output_text = outputs_obj.get()['Body'].read().decode() - assert output_text == "hello\nworld\n", output_text + self.assertEqual(output_text, "hello\nworld\n") + + res = self.sqs.list_queues() + queue_url = res["QueueUrls"][0] + res = self.sqs.receive_message(QueueUrl=queue_url) + self.assertEqual(json.loads(res["Messages"][0]["Body"])["detail"]["lastCompletedStage"], "run") if __name__ == "__main__": diff --git a/version b/version index b094f206..a60b32a0 100644 --- a/version +++ b/version @@ -1 +1 @@ -v0.13.1-beta +v0.14.0-beta