From d4503674b46277b96b8df3e584c0001cc4796324 Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Tue, 23 Jul 2024 23:24:56 +0100 Subject: [PATCH 1/2] fix static checks --- .../amazon/aws/hooks/step_function.py | 18 ++++++++++ .../amazon/aws/operators/step_function.py | 14 ++++++-- .../amazon/aws/hooks/test_step_function.py | 28 +++++++++++++++ .../aws/operators/test_step_function.py | 36 +++++++++++++++++-- 4 files changed, 92 insertions(+), 4 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/step_function.py b/airflow/providers/amazon/aws/hooks/step_function.py index 2d760ee891031..826b3047d3ea6 100644 --- a/airflow/providers/amazon/aws/hooks/step_function.py +++ b/airflow/providers/amazon/aws/hooks/step_function.py @@ -18,6 +18,7 @@ import json +from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -43,6 +44,7 @@ def start_execution( state_machine_arn: str, name: str | None = None, state_machine_input: dict | str | None = None, + is_redrive_execution: bool = False, ) -> str: """ Start Execution of the State Machine. @@ -51,10 +53,26 @@ def start_execution( - :external+boto3:py:meth:`SFN.Client.start_execution` :param state_machine_arn: AWS Step Function State Machine ARN. + :param is_redrive_execution: Restarts unsuccessful executions of Standard workflows that did not + complete successfully in the last 14 days. :param name: The name of the execution. :param state_machine_input: JSON data input to pass to the State Machine. :return: Execution ARN. 
""" + if is_redrive_execution: + if not name: + raise AirflowException( + "Execution name is required to start RedriveExecution for %s.", state_machine_arn + ) + elements = state_machine_arn.split(":stateMachine:") + execution_arn = f"{elements[0]}:execution:{elements[1]}:{name}" + self.conn.redrive_execution(executionArn=execution_arn) + self.log.info( + "Successfully started RedriveExecution for Step Function State Machine: %s.", + state_machine_arn, + ) + return execution_arn + execution_args = {"stateMachineArn": state_machine_arn} if name is not None: execution_args["name"] = name diff --git a/airflow/providers/amazon/aws/operators/step_function.py b/airflow/providers/amazon/aws/operators/step_function.py index c6c2f71e1bcf6..e17aeeeae26fb 100644 --- a/airflow/providers/amazon/aws/operators/step_function.py +++ b/airflow/providers/amazon/aws/operators/step_function.py @@ -48,6 +48,8 @@ class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]): :param state_machine_arn: ARN of the Step Function State Machine :param name: The name of the execution. + :param is_redrive_execution: Restarts unsuccessful executions of Standard workflows that did not + complete successfully in the last 14 days. :param state_machine_input: JSON data input to pass to the State Machine :param aws_conn_id: The Airflow connection used for AWS credentials. If this is None or empty then the default boto3 behaviour is used. 
If @@ -73,7 +75,9 @@ class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]): """ aws_hook_class = StepFunctionHook - template_fields: Sequence[str] = aws_template_fields("state_machine_arn", "name", "input") + template_fields: Sequence[str] = aws_template_fields( + "state_machine_arn", "name", "input", "is_redrive_execution" + ) ui_color = "#f9c915" operator_extra_links = (StateMachineDetailsLink(), StateMachineExecutionsDetailsLink()) @@ -82,6 +86,7 @@ def __init__( *, state_machine_arn: str, name: str | None = None, + is_redrive_execution: bool = False, state_machine_input: dict | str | None = None, waiter_max_attempts: int = 30, waiter_delay: int = 60, @@ -91,6 +96,7 @@ def __init__( super().__init__(**kwargs) self.state_machine_arn = state_machine_arn self.name = name + self.is_redrive_execution = is_redrive_execution self.input = state_machine_input self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts @@ -105,7 +111,11 @@ def execute(self, context: Context): state_machine_arn=self.state_machine_arn, ) - if not (execution_arn := self.hook.start_execution(self.state_machine_arn, self.name, self.input)): + if not ( + execution_arn := self.hook.start_execution( + self.state_machine_arn, self.name, self.input, self.is_redrive_execution + ) + ): raise AirflowException(f"Failed to start State Machine execution for: {self.state_machine_arn}") StateMachineExecutionsDetailsLink.persist( diff --git a/tests/providers/amazon/aws/hooks/test_step_function.py b/tests/providers/amazon/aws/hooks/test_step_function.py index 393d27715bc82..de56c3e061f31 100644 --- a/tests/providers/amazon/aws/hooks/test_step_function.py +++ b/tests/providers/amazon/aws/hooks/test_step_function.py @@ -17,8 +17,13 @@ # under the License. 
from __future__ import annotations +from datetime import datetime +from unittest import mock + +import pytest from moto import mock_aws +from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook @@ -42,6 +47,29 @@ def test_start_execution(self): assert execution_arn is not None + @mock.patch.object(StepFunctionHook, "conn") + def test_redrive_execution(self, mock_conn): + mock_conn.redrive_execution.return_value = {"redriveDate": datetime(2024, 1, 1)} + StepFunctionHook().start_execution( + state_machine_arn="arn:aws:states:us-east-1:123456789012:stateMachine:test-state-machine", + name="random-123", + is_redrive_execution=True, + ) + + mock_conn.redrive_execution.assert_called_once_with( + executionArn="arn:aws:states:us-east-1:123456789012:execution:test-state-machine:random-123" + ) + + @mock.patch.object(StepFunctionHook, "conn") + def test_redrive_execution_without_name_should_fail(self, mock_conn): + mock_conn.redrive_execution.return_value = {"redriveDate": datetime(2024, 1, 1)} + + with pytest.raises(AirflowException, match="Execution name is required to start RedriveExecution"): + StepFunctionHook().start_execution( + state_machine_arn="arn:aws:states:us-east-1:123456789012:stateMachine:test-state-machine", + is_redrive_execution=True, + ) + def test_describe_execution(self): hook = StepFunctionHook(aws_conn_id="aws_default", region_name="us-east-1") state_machine = hook.get_conn().create_state_machine( diff --git a/tests/providers/amazon/aws/operators/test_step_function.py b/tests/providers/amazon/aws/operators/test_step_function.py index 904205418e0c0..e8ab5c85a6b8f 100644 --- a/tests/providers/amazon/aws/operators/test_step_function.py +++ b/tests/providers/amazon/aws/operators/test_step_function.py @@ -159,7 +159,7 @@ def test_execute(self, mocked_hook, mocked_context): aws_conn_id=None, ) assert op.execute(mocked_context) == hook_response - 
mocked_hook.start_execution.assert_called_once_with(STATE_MACHINE_ARN, NAME, INPUT) + mocked_hook.start_execution.assert_called_once_with(STATE_MACHINE_ARN, NAME, INPUT, False) self.mocked_details_link.assert_called_once_with( aws_partition=mock.ANY, context=mock.ANY, @@ -189,7 +189,7 @@ def test_step_function_start_execution_deferrable(self, mocked_hook): ) with pytest.raises(TaskDeferred): operator.execute(None) - mocked_hook.start_execution.assert_called_once_with(STATE_MACHINE_ARN, NAME, INPUT) + mocked_hook.start_execution.assert_called_once_with(STATE_MACHINE_ARN, NAME, INPUT, False) @mock.patch.object(StepFunctionStartExecutionOperator, "hook") @pytest.mark.parametrize("execution_arn", [pytest.param(None, id="none"), pytest.param("", id="empty")]) @@ -200,3 +200,35 @@ def test_step_function_no_execution_arn_returns(self, mocked_hook, execution_arn ) with pytest.raises(AirflowException, match="Failed to start State Machine execution"): op.execute({}) + + @mock.patch.object(StepFunctionStartExecutionOperator, "hook") + def test_start_redrive_execution(self, mocked_hook, mocked_context): + hook_response = ( + "arn:aws:states:us-east-1:123456789012:execution:" + "pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934" + ) + mocked_hook.start_execution.return_value = hook_response + op = StepFunctionStartExecutionOperator( + task_id=self.TASK_ID, + state_machine_arn=STATE_MACHINE_ARN, + name=NAME, + is_redrive_execution=True, + state_machine_input=None, + aws_conn_id=None, + ) + assert op.execute(mocked_context) == hook_response + mocked_hook.start_execution.assert_called_once_with(STATE_MACHINE_ARN, NAME, None, True) + self.mocked_details_link.assert_called_once_with( + aws_partition=mock.ANY, + context=mock.ANY, + operator=mock.ANY, + region_name=mock.ANY, + state_machine_arn=STATE_MACHINE_ARN, + ) + self.mocked_executions_details_link.assert_called_once_with( + aws_partition=mock.ANY, + context=mock.ANY, + operator=mock.ANY, + region_name=mock.ANY, + 
execution_arn=EXECUTION_ARN, + ) From 17406926ae250de1e526795fcbdb668ffdf3101d Mon Sep 17 00:00:00 2001 From: Pavan Kumar Date: Wed, 24 Jul 2024 16:49:41 +0100 Subject: [PATCH 2/2] fix static checks --- airflow/providers/amazon/aws/hooks/step_function.py | 4 ++-- tests/providers/amazon/aws/hooks/test_step_function.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/step_function.py b/airflow/providers/amazon/aws/hooks/step_function.py index 826b3047d3ea6..48da7cb1150a8 100644 --- a/airflow/providers/amazon/aws/hooks/step_function.py +++ b/airflow/providers/amazon/aws/hooks/step_function.py @@ -18,7 +18,7 @@ import json -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowFailException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -61,7 +61,7 @@ def start_execution( """ if is_redrive_execution: if not name: - raise AirflowException( + raise AirflowFailException( "Execution name is required to start RedriveExecution for %s.", state_machine_arn ) elements = state_machine_arn.split(":stateMachine:") diff --git a/tests/providers/amazon/aws/hooks/test_step_function.py b/tests/providers/amazon/aws/hooks/test_step_function.py index de56c3e061f31..ce66447da68ec 100644 --- a/tests/providers/amazon/aws/hooks/test_step_function.py +++ b/tests/providers/amazon/aws/hooks/test_step_function.py @@ -23,7 +23,7 @@ import pytest from moto import mock_aws -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowFailException from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook @@ -64,7 +64,9 @@ def test_redrive_execution(self, mock_conn): def test_redrive_execution_without_name_should_fail(self, mock_conn): mock_conn.redrive_execution.return_value = {"redriveDate": datetime(2024, 1, 1)} - with pytest.raises(AirflowException, match="Execution name is required to start RedriveExecution"): + with pytest.raises( + 
AirflowFailException, match="Execution name is required to start RedriveExecution" + ): StepFunctionHook().start_execution( state_machine_arn="arn:aws:states:us-east-1:123456789012:stateMachine:test-state-machine", is_redrive_execution=True,