diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py index 8c39be02a7cff..afca0fc615ff0 100644 --- a/airflow/providers/amazon/aws/operators/batch.py +++ b/airflow/providers/amazon/aws/operators/batch.py @@ -47,6 +47,7 @@ ) from airflow.providers.amazon.aws.utils import trim_none_values, validate_execute_complete_event from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher +from airflow.utils.types import NOTSET if TYPE_CHECKING: from airflow.utils.context import Context @@ -480,16 +481,16 @@ def __init__( aws_conn_id: str | None = None, region_name: str | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + status_retries=NOTSET, **kwargs, ): - if "status_retries" in kwargs: + if status_retries is not NOTSET: warnings.warn( "The `status_retries` parameter is unused and should be removed. " "It'll be deleted in a future version.", AirflowProviderDeprecationWarning, stacklevel=2, ) - kwargs.pop("status_retries") # remove before calling super() to prevent unexpected arg error super().__init__(**kwargs) diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index b5874d98d412d..8fe6df5812169 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -40,6 +40,7 @@ from airflow.providers.amazon.aws.utils.mixins import aws_template_fields from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher from airflow.utils.helpers import prune_dict +from airflow.utils.types import NOTSET if TYPE_CHECKING: import boto3 @@ -257,19 +258,18 @@ def __init__( self, *, task_definition: str, + wait_for_completion=NOTSET, + waiter_delay=NOTSET, + waiter_max_attempts=NOTSET, **kwargs, ): - if "wait_for_completion" in kwargs or "waiter_delay" in kwargs or "waiter_max_attempts" in kwargs: + if any(arg is not NOTSET for arg in [wait_for_completion, waiter_delay, waiter_max_attempts]): warnings.warn( "'wait_for_completion' and waiter related params have no effect and are deprecated, " "please remove them.", AirflowProviderDeprecationWarning, stacklevel=2, ) - # remove args to not trigger Invalid arguments exception - kwargs.pop("wait_for_completion", None) - kwargs.pop("waiter_delay", None) - kwargs.pop("waiter_max_attempts", None) super().__init__(**kwargs) self.task_definition = task_definition @@ -311,19 +311,18 @@ def __init__( family: str, container_definitions: list[dict], register_task_kwargs: dict | None = None, + wait_for_completion=NOTSET, + waiter_delay=NOTSET, + waiter_max_attempts=NOTSET, **kwargs, ): - if "wait_for_completion" in kwargs or "waiter_delay" in kwargs or "waiter_max_attempts" in kwargs: + if any(arg is not NOTSET for arg in [wait_for_completion, waiter_delay, waiter_max_attempts]): warnings.warn( "'wait_for_completion' and waiter related params have no effect and are deprecated, " "please remove them.", AirflowProviderDeprecationWarning, stacklevel=2, ) - # remove args to not trigger Invalid arguments exception - kwargs.pop("wait_for_completion", None) - kwargs.pop("waiter_delay", None) - kwargs.pop("waiter_max_attempts", None) super().__init__(**kwargs) self.family = family diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py index 313d721b3ace9..2ac95578136ba 100644 --- a/tests/providers/amazon/aws/operators/test_batch.py +++ b/tests/providers/amazon/aws/operators/test_batch.py @@ -31,6 +31,7 @@ BatchCreateComputeEnvironmentTrigger, BatchJobTrigger, ) +from airflow.utils.task_instance_session import set_current_task_instance_session AWS_REGION = "eu-west-1" AWS_ACCESS_KEY_ID = "airflow_dummy_key" @@ -370,6 +371,8 @@ def test_monitor_job_with_logs( class TestBatchCreateComputeEnvironmentOperator: + warn_message = "The `status_retries` parameter is unused and should be removed" + @mock.patch.object(BatchClientHook, "client") def test_execute(self, mock_conn): environment_name = "environment_name" @@ -394,6 +397,36 @@ def test_execute(self, mock_conn): tags=tags, ) + def test_deprecation(self): + with pytest.warns(AirflowProviderDeprecationWarning, match=self.warn_message): + BatchCreateComputeEnvironmentOperator( + task_id="id", + compute_environment_name="environment_name", + environment_type="environment_type", + state="environment_state", + compute_resources={}, + status_retries="Huh?", + ) + + @pytest.mark.db_test + def test_partial_deprecation(self, dag_maker, session): + with dag_maker(dag_id="test_partial_deprecation_waiters_params_reg_ecs", session=session): + BatchCreateComputeEnvironmentOperator.partial( + task_id="id", + compute_environment_name="environment_name", + environment_type="environment_type", + state="environment_state", + status_retries="Huh?", + ).expand(compute_resources=[{}, {}]) + + dr = dag_maker.create_dagrun() + tis = dr.get_task_instances(session=session) + with set_current_task_instance_session(session=session): + for ti in tis: + with pytest.warns(AirflowProviderDeprecationWarning, match=self.warn_message): + ti.render_templates() + assert not hasattr(ti.task, "status_retries") + @mock.patch.object(BatchClientHook, "client") def test_defer(self, client_mock): client_mock.create_compute_environment.return_value = {"computeEnvironmentArn": "my_arn"} diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py index 37583bcdcad7f..136e776b7de25 100644 --- a/tests/providers/amazon/aws/operators/test_ecs.py +++ b/tests/providers/amazon/aws/operators/test_ecs.py @@ -858,6 +858,8 @@ def test_execute_without_waiter(self, patch_hook_waiters): class TestEcsDeregisterTaskDefinitionOperator(EcsBaseTestCase): + warn_message = "'wait_for_completion' and waiter related params have no effect" + def test_execute_immediate_delete(self): """Test if task definition deleted during initial request.""" op = EcsDeregisterTaskDefinitionOperator( @@ -872,11 +874,50 @@ def test_execute_immediate_delete(self): assert result == "foo-bar" def test_deprecation(self): - with pytest.warns(AirflowProviderDeprecationWarning): + with pytest.warns(AirflowProviderDeprecationWarning, match=self.warn_message): EcsDeregisterTaskDefinitionOperator(task_id="id", task_definition="def", wait_for_completion=True) + @pytest.mark.db_test + @pytest.mark.parametrize( + "wait_for_completion, waiter_delay, waiter_max_attempts", + [ + pytest.param(True, 10, 42, id="all-params"), + pytest.param(False, None, None, id="wait-for-completion-only"), + pytest.param(None, 10, None, id="waiter-delay-only"), + pytest.param(None, None, 42, id="waiter-max-attempts-delay-only"), + ], + ) + def test_partial_deprecation_waiters_params( + self, wait_for_completion, waiter_delay, waiter_max_attempts, dag_maker, session + ): + op_kwargs = {} + if wait_for_completion is not None: + op_kwargs["wait_for_completion"] = wait_for_completion + if waiter_delay is not None: + op_kwargs["waiter_delay"] = waiter_delay + if waiter_max_attempts is not None: + op_kwargs["waiter_max_attempts"] = waiter_max_attempts + + with dag_maker(dag_id="test_partial_deprecation_waiters_params_dereg_ecs", session=session): + EcsDeregisterTaskDefinitionOperator.partial( + task_id="fake-task-id", + **op_kwargs, + ).expand(task_definition=["foo", "bar"]) + + dr = dag_maker.create_dagrun() + tis = dr.get_task_instances(session=session) + with set_current_task_instance_session(session=session): + for ti in tis: + with pytest.warns(AirflowProviderDeprecationWarning, match=self.warn_message): + ti.render_templates() + assert not hasattr(ti.task, "wait_for_completion") + assert not hasattr(ti.task, "waiter_delay") + assert not hasattr(ti.task, "waiter_max_attempts") + class TestEcsRegisterTaskDefinitionOperator(EcsBaseTestCase): + warn_message = "'wait_for_completion' and waiter related params have no effect" + def test_execute_immediate_create(self): """Test if task definition created during initial request.""" mock_ti = mock.MagicMock(name="MockedTaskInstance") @@ -908,7 +949,45 @@ def test_execute_immediate_create(self): assert result == "foo-bar" def test_deprecation(self): - with pytest.warns(AirflowProviderDeprecationWarning): + with pytest.warns(AirflowProviderDeprecationWarning, match=self.warn_message): EcsRegisterTaskDefinitionOperator( task_id="id", wait_for_completion=True, **TASK_DEFINITION_CONFIG ) + + @pytest.mark.db_test + @pytest.mark.parametrize( + "wait_for_completion, waiter_delay, waiter_max_attempts", + [ + pytest.param(True, 10, 42, id="all-params"), + pytest.param(False, None, None, id="wait-for-completion-only"), + pytest.param(None, 10, None, id="waiter-delay-only"), + pytest.param(None, None, 42, id="waiter-max-attempts-delay-only"), + ], + ) + def test_partial_deprecation_waiters_params( + self, wait_for_completion, waiter_delay, waiter_max_attempts, dag_maker, session + ): + op_kwargs = {} + if wait_for_completion is not None: + op_kwargs["wait_for_completion"] = wait_for_completion + if waiter_delay is not None: + op_kwargs["waiter_delay"] = waiter_delay + if waiter_max_attempts is not None: + op_kwargs["waiter_max_attempts"] = waiter_max_attempts + + with dag_maker(dag_id="test_partial_deprecation_waiters_params_reg_ecs", session=session): + EcsRegisterTaskDefinitionOperator.partial( + task_id="fake-task-id", + family="family_name", + **op_kwargs, + ).expand(container_definitions=[{}, {}]) + + dr = dag_maker.create_dagrun() + tis = dr.get_task_instances(session=session) + with set_current_task_instance_session(session=session): + for ti in tis: + with pytest.warns(AirflowProviderDeprecationWarning, match=self.warn_message): + ti.render_templates() + assert not hasattr(ti.task, "wait_for_completion") + assert not hasattr(ti.task, "waiter_delay") + assert not hasattr(ti.task, "waiter_max_attempts")