Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RedriveExecution support to StepFunctionStartExecutionOperator #40976

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions airflow/providers/amazon/aws/hooks/step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import json

from airflow.exceptions import AirflowFailException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook


Expand All @@ -43,6 +44,7 @@ def start_execution(
state_machine_arn: str,
name: str | None = None,
state_machine_input: dict | str | None = None,
is_redrive_execution: bool = False,
) -> str:
"""
Start Execution of the State Machine.
Expand All @@ -51,10 +53,26 @@ def start_execution(
- :external+boto3:py:meth:`SFN.Client.start_execution`
:param state_machine_arn: AWS Step Function State Machine ARN.
:param is_redrive_execution: Restarts unsuccessful executions of Standard workflows that did not
complete successfully in the last 14 days.
:param name: The name of the execution.
:param state_machine_input: JSON data input to pass to the State Machine.
:return: Execution ARN.
"""
if is_redrive_execution:
if not name:
raise AirflowFailException(
"Execution name is required to start RedriveExecution for %s.", state_machine_arn
)
elements = state_machine_arn.split(":stateMachine:")
execution_arn = f"{elements[0]}:execution:{elements[1]}:{name}"
self.conn.redrive_execution(executionArn=execution_arn)
self.log.info(
"Successfully started RedriveExecution for Step Function State Machine: %s.",
state_machine_arn,
)
return execution_arn

execution_args = {"stateMachineArn": state_machine_arn}
if name is not None:
execution_args["name"] = name
Expand Down
14 changes: 12 additions & 2 deletions airflow/providers/amazon/aws/operators/step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]):
:param state_machine_arn: ARN of the Step Function State Machine
:param name: The name of the execution.
:param is_redrive_execution: Restarts unsuccessful executions of Standard workflows that did not
complete successfully in the last 14 days.
:param state_machine_input: JSON data input to pass to the State Machine
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
Expand All @@ -73,7 +75,9 @@ class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]):
"""

aws_hook_class = StepFunctionHook
template_fields: Sequence[str] = aws_template_fields("state_machine_arn", "name", "input")
template_fields: Sequence[str] = aws_template_fields(
"state_machine_arn", "name", "input", "is_redrive_execution"
)
ui_color = "#f9c915"
operator_extra_links = (StateMachineDetailsLink(), StateMachineExecutionsDetailsLink())

Expand All @@ -82,6 +86,7 @@ def __init__(
*,
state_machine_arn: str,
name: str | None = None,
is_redrive_execution: bool = False,
state_machine_input: dict | str | None = None,
waiter_max_attempts: int = 30,
waiter_delay: int = 60,
Expand All @@ -91,6 +96,7 @@ def __init__(
super().__init__(**kwargs)
self.state_machine_arn = state_machine_arn
self.name = name
self.is_redrive_execution = is_redrive_execution
self.input = state_machine_input
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
Expand All @@ -105,7 +111,11 @@ def execute(self, context: Context):
state_machine_arn=self.state_machine_arn,
)

if not (execution_arn := self.hook.start_execution(self.state_machine_arn, self.name, self.input)):
if not (
execution_arn := self.hook.start_execution(
self.state_machine_arn, self.name, self.input, self.is_redrive_execution
)
):
raise AirflowException(f"Failed to start State Machine execution for: {self.state_machine_arn}")

StateMachineExecutionsDetailsLink.persist(
Expand Down
30 changes: 30 additions & 0 deletions tests/providers/amazon/aws/hooks/test_step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@
# under the License.
from __future__ import annotations

from datetime import datetime
from unittest import mock

import pytest
from moto import mock_aws

from airflow.exceptions import AirflowFailException
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook


Expand All @@ -42,6 +47,31 @@ def test_start_execution(self):

assert execution_arn is not None

@mock.patch.object(StepFunctionHook, "conn")
def test_redrive_execution(self, mock_conn):
    """Redrive with a name: the execution ARN is derived from the state machine ARN
    plus the execution name, and passed to the redrive_execution API call."""
    machine_arn = "arn:aws:states:us-east-1:123456789012:stateMachine:test-state-machine"
    mock_conn.redrive_execution.return_value = {"redriveDate": datetime(2024, 1, 1)}

    StepFunctionHook().start_execution(
        state_machine_arn=machine_arn,
        name="random-123",
        is_redrive_execution=True,
    )

    expected_execution_arn = (
        "arn:aws:states:us-east-1:123456789012:execution:test-state-machine:random-123"
    )
    mock_conn.redrive_execution.assert_called_once_with(executionArn=expected_execution_arn)

@mock.patch.object(StepFunctionHook, "conn")
def test_redrive_execution_without_name_should_fail(self, mock_conn):
    """Redrive without an execution name must raise AirflowFailException —
    the execution ARN cannot be derived without the name."""
    mock_conn.redrive_execution.return_value = {"redriveDate": datetime(2024, 1, 1)}
    hook = StepFunctionHook()

    with pytest.raises(
        AirflowFailException, match="Execution name is required to start RedriveExecution"
    ):
        hook.start_execution(
            state_machine_arn="arn:aws:states:us-east-1:123456789012:stateMachine:test-state-machine",
            is_redrive_execution=True,
        )

def test_describe_execution(self):
hook = StepFunctionHook(aws_conn_id="aws_default", region_name="us-east-1")
state_machine = hook.get_conn().create_state_machine(
Expand Down
36 changes: 34 additions & 2 deletions tests/providers/amazon/aws/operators/test_step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def test_execute(self, mocked_hook, mocked_context):
aws_conn_id=None,
)
assert op.execute(mocked_context) == hook_response
mocked_hook.start_execution.assert_called_once_with(STATE_MACHINE_ARN, NAME, INPUT)
mocked_hook.start_execution.assert_called_once_with(STATE_MACHINE_ARN, NAME, INPUT, False)
self.mocked_details_link.assert_called_once_with(
aws_partition=mock.ANY,
context=mock.ANY,
Expand Down Expand Up @@ -189,7 +189,7 @@ def test_step_function_start_execution_deferrable(self, mocked_hook):
)
with pytest.raises(TaskDeferred):
operator.execute(None)
mocked_hook.start_execution.assert_called_once_with(STATE_MACHINE_ARN, NAME, INPUT)
mocked_hook.start_execution.assert_called_once_with(STATE_MACHINE_ARN, NAME, INPUT, False)

@mock.patch.object(StepFunctionStartExecutionOperator, "hook")
@pytest.mark.parametrize("execution_arn", [pytest.param(None, id="none"), pytest.param("", id="empty")])
Expand All @@ -200,3 +200,35 @@ def test_step_function_no_execution_arn_returns(self, mocked_hook, execution_arn
)
with pytest.raises(AirflowException, match="Failed to start State Machine execution"):
op.execute({})

@mock.patch.object(StepFunctionStartExecutionOperator, "hook")
def test_start_redrive_execution(self, mocked_hook, mocked_context):
    """is_redrive_execution=True is forwarded (positionally, last) to the hook's
    start_execution, and both extra-link persist hooks fire once."""
    hook_response = (
        "arn:aws:states:us-east-1:123456789012:execution:"
        "pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934"
    )
    mocked_hook.start_execution.return_value = hook_response

    operator = StepFunctionStartExecutionOperator(
        task_id=self.TASK_ID,
        state_machine_arn=STATE_MACHINE_ARN,
        name=NAME,
        is_redrive_execution=True,
        state_machine_input=None,
        aws_conn_id=None,
    )
    assert operator.execute(mocked_context) == hook_response
    mocked_hook.start_execution.assert_called_once_with(STATE_MACHINE_ARN, NAME, None, True)

    # Both operator extra links persist with the usual ANY-matched plumbing args.
    link_kwargs = {
        "aws_partition": mock.ANY,
        "context": mock.ANY,
        "operator": mock.ANY,
        "region_name": mock.ANY,
    }
    self.mocked_details_link.assert_called_once_with(
        state_machine_arn=STATE_MACHINE_ARN, **link_kwargs
    )
    self.mocked_executions_details_link.assert_called_once_with(
        execution_arn=EXECUTION_ARN, **link_kwargs
    )