Skip to content

Commit

Permalink
notify of stage changes (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
morsecodist authored Jan 27, 2022
1 parent 493e07a commit 69ef772
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ SHELL=/bin/bash -o pipefail
deploy-mock:
rm terraform.tfstate || true
aws ssm put-parameter --name /mock-aws/service/ecs/optimized-ami/amazon-linux-2/recommended/image_id --value ami-12345678 --type String --endpoint-url http://localhost:9000
cp test/mock.tf .; unset TF_CLI_ARGS_init; terraform init; TF_VAR_mock=true TF_VAR_app_name=swipe-test TF_VAR_batch_ec2_instance_types='["optimal"]' terraform apply --auto-approve
cp test/mock.tf .; unset TF_CLI_ARGS_init; terraform init; TF_VAR_mock=true TF_VAR_app_name=swipe-test TF_VAR_batch_ec2_instance_types='["optimal"]' TF_VAR_sqs_queues='{"notifications":{"dead_letter": false}}' terraform apply --auto-approve

lint:
flake8 .
Expand Down
4 changes: 4 additions & 0 deletions terraform/modules/sfn-io-helper-lambdas/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def preprocess_input(sfn_data, _):

def process_stage_output(sfn_data, _):
assert sfn_data["CurrentState"].endswith("ReadOutput")
stage_io.broadcast_stage_complete(
sfn_data["ExecutionId"],
sfn_data["CurrentState"][:-len("ReadOutput")],
)
sfn_state = stage_io.read_state_from_s3(sfn_state=sfn_data["Input"], current_state=sfn_data["CurrentState"])
stage_io.link_outputs(sfn_state)
sfn_state = stage_io.trim_batch_job_details(sfn_state=sfn_state)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
batch = boto3.client("batch", endpoint_url=os.getenv("AWS_ENDPOINT_URL"))
stepfunctions = boto3.client("stepfunctions", endpoint_url=os.getenv("AWS_ENDPOINT_URL"))
cloudwatch = boto3.client("cloudwatch", endpoint_url=os.getenv("AWS_ENDPOINT_URL"))
sqs = boto3.client("sqs", endpoint_url=os.getenv("AWS_ENDPOINT_URL"))


def s3_object(uri):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import json
import logging
from typing import List
from datetime import datetime
from uuid import uuid4

from botocore import xform_name

from . import s3_object
from . import s3_object, sqs

logger = logging.getLogger()

Expand Down Expand Up @@ -121,3 +123,51 @@ def preprocess_sfn_input(sfn_state, aws_region, aws_account_id, state_machine_na
link_outputs(sfn_state)

return sfn_state


def broadcast_stage_complete(execution_id: str, stage: str):
    """Tell every configured SQS queue that a pipeline stage has completed.

    The message body is laid out like a "Step Functions Execution Status
    Change" event (source ``aws.states``, status ``RUNNING``) with one extra
    ``detail`` field, ``lastCompletedStage``, holding ``stage`` converted
    from CamelCase to snake_case.

    Queue URLs come from the comma-separated ``SQS_QUEUE_URLS`` environment
    variable. When it is unset, or holds no non-empty URL, nothing is sent.

    Args:
        execution_id: the execution ARN, eight colon-separated fields:
            ``arn:aws:states:<region>:<account>:execution:<state machine>:<name>``.
        stage: CamelCase name of the stage that just finished, e.g. ``"Run"``.

    Raises:
        ValueError: if ``execution_id`` does not have exactly eight
            colon-separated fields.
    """
    # Imported locally because the module itself only imports ``datetime``
    # from the ``datetime`` package.
    from datetime import timezone

    # An empty variable (e.g. a join over zero queues) would split to [""];
    # drop blanks so SQS is never called with an empty QueueUrl.
    raw_urls = os.environ.get("SQS_QUEUE_URLS", "")
    queue_urls = [url.strip() for url in raw_urls.split(",") if url.strip()]
    if not queue_urls:
        return

    arn_fields = execution_id.split(":")
    if len(arn_fields) != 8:
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        raise ValueError(
            f"expected an execution ARN with 8 colon-separated fields, got {execution_id!r}"
        )
    _, _, _, aws_region, aws_account_id, _, state_machine_name, execution_name = arn_fields

    state_machine_arn = f"arn:aws:states:{aws_region}:{aws_account_id}:stateMachine:{state_machine_name}"
    # ``execution_id`` already is the full execution ARN; re-prefixing it
    # would yield a doubled, malformed ARN.
    execution_arn = execution_id

    body = json.dumps({
        "version": "0",
        "id": str(uuid4()),
        "detail-type": "Step Functions Execution Status Change",
        "source": "aws.states",
        "account": aws_account_id,
        # The trailing "Z" denotes UTC, so the clock must be read in UTC too.
        "time": datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ'),
        "region": aws_region,
        "resources": [execution_arn],
        "detail": {
            "executionArn": execution_arn,
            "stateMachineArn": state_machine_arn,
            "name": execution_name,
            "status": "RUNNING",
            # CamelCase -> snake_case: insert "_" before every capital
            # letter that is not at the start, then lowercase.
            "lastCompletedStage": re.sub(r'(?<!^)(?=[A-Z])', '_', stage).lower(),
            # We don't set this because it isn't used yet and we don't have this
            # field in the lambda, but it is part of the schema for these
            # messages so we may need to add it.
            # "startDate": 1551225271984,
            "stopDate": None,
            "input": "{}",
            "inputDetails": {
                "included": None
            },
            "output": None,
            "outputDetails": None
        }
    })

    for queue_url in queue_urls:
        sqs.send_message(
            QueueUrl=queue_url,
            MessageBody=body,
        )
9 changes: 8 additions & 1 deletion terraform/modules/sfn-io-helper-lambdas/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,13 @@ resource "aws_iam_role_policy" "iam_role_policy" {
"logs:PutLogEvents"
],
Resource : "arn:aws:logs:*:*:*"
},
{
Effect : "Allow",
Action : [
"sqs:SendMessage",
],
Resource : var.sfn_notification_queue_arns
}
])
})
Expand All @@ -130,6 +137,7 @@ resource "aws_lambda_function" "lambda" {
variables = merge({
APP_NAME = var.app_name
AWS_ENDPOINT_URL = var.mock ? "http://awsnet:5000" : null
SQS_QUEUE_URLS = join(",", var.sfn_notification_queue_urls)
}, {
for stage, defaults in var.stage_memory_defaults : "${stage}SPOTMemoryDefault" => "${defaults.spot}"
}, {
Expand All @@ -139,7 +147,6 @@ resource "aws_lambda_function" "lambda" {
}, {
for stage, defaults in var.stage_vcpu_defaults : "${stage}EC2VcpuDefault" => "${defaults.on_demand}"
},

)
}
}
Expand Down
9 changes: 9 additions & 0 deletions terraform/modules/sfn-io-helper-lambdas/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,12 @@ variable "aws_account_id" {
type = string
}

# Used as the Resource of the "sqs:SendMessage" statement in this module's
# aws_iam_role_policy, i.e. the queues the helper lambdas may send to.
variable "sfn_notification_queue_arns" {
description = "ARNs of notification SQS queues"
type = list(string)
}

# Joined with "," and exposed to the lambda as the SQS_QUEUE_URLS environment
# variable; an empty list therefore yields an empty string.
variable "sfn_notification_queue_urls" {
description = "URLs of notification SQS queues"
type = list(string)
}
24 changes: 13 additions & 11 deletions terraform/modules/swipe-sfn/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,19 @@ module "batch_job" {
}

module "sfn_io_helper" {
source = "../sfn-io-helper-lambdas"
app_name = var.app_name
mock = var.mock
aws_region = data.aws_region.current.name
aws_account_id = data.aws_caller_identity.current.account_id
batch_queue_arns = [var.batch_spot_job_queue_arn, var.batch_on_demand_job_queue_arn]
workspace_s3_prefix = var.workspace_s3_prefix
wdl_workflow_s3_prefix = var.wdl_workflow_s3_prefix
stage_memory_defaults = var.stage_memory_defaults
stage_vcpu_defaults = var.stage_vcpu_defaults
tags = var.tags
source = "../sfn-io-helper-lambdas"
app_name = var.app_name
mock = var.mock
aws_region = data.aws_region.current.name
aws_account_id = data.aws_caller_identity.current.account_id
batch_queue_arns = [var.batch_spot_job_queue_arn, var.batch_on_demand_job_queue_arn]
workspace_s3_prefix = var.workspace_s3_prefix
wdl_workflow_s3_prefix = var.wdl_workflow_s3_prefix
stage_memory_defaults = var.stage_memory_defaults
stage_vcpu_defaults = var.stage_vcpu_defaults
sfn_notification_queue_arns = [for name, queue in aws_sqs_queue.sfn_notifications_queue : queue.arn]
sfn_notification_queue_urls = [for name, queue in aws_sqs_queue.sfn_notifications_queue : queue.url]
tags = var.tags
}

resource "aws_sfn_state_machine" "swipe_single_wdl" {
Expand Down
12 changes: 8 additions & 4 deletions test/test_wdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def setUp(self) -> None:
self.sfn = boto3.client("stepfunctions", endpoint_url="http://localhost:8083")
self.test_bucket = self.s3.create_bucket(Bucket="swipe-test")
self.lamb = boto3.client("lambda", endpoint_url="http://localhost:9000")
self.sqs = boto3.client("sqs", endpoint_url="http://localhost:9000")

def test_simple_sfn_wdl_workflow(self):
wdl_obj = self.test_bucket.Object("test-v1.0.0.wdl")
Expand All @@ -83,8 +84,6 @@ def test_simple_sfn_wdl_workflow(self):
input=json.dumps(sfn_input))

arn = res["executionArn"]
assert res

start = time.time()
description = self.sfn.describe_execution(executionArn=arn)
while description["status"] == "RUNNING" and time.time() < start + 2 * 60:
Expand All @@ -94,10 +93,15 @@ def test_simple_sfn_wdl_workflow(self):
for event in self.sfn.get_execution_history(executionArn=arn)["events"]:
print(event, file=sys.stderr)

assert description["status"] == "SUCCEEDED", description
self.assertEqual(description["status"], "SUCCEEDED")
outputs_obj = self.test_bucket.Object(f"{output_prefix}/test-1/out.txt")
output_text = outputs_obj.get()['Body'].read().decode()
assert output_text == "hello\nworld\n", output_text
self.assertEqual(output_text, "hello\nworld\n")

res = self.sqs.list_queues()
queue_url = res["QueueUrls"][0]
res = self.sqs.receive_message(QueueUrl=queue_url)
self.assertEqual(json.loads(res["Messages"][0]["Body"])["detail"]["lastCompletedStage"], "run")


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
v0.13.1-beta
v0.14.0-beta

0 comments on commit 69ef772

Please sign in to comment.