Skip to content

Commit

Permalink
Fix set deprecated amazon operators arguments in MappedOperator (#3…
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis authored Mar 22, 2024
1 parent 0aee681 commit c893cb3
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 14 deletions.
5 changes: 3 additions & 2 deletions airflow/providers/amazon/aws/operators/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
)
from airflow.providers.amazon.aws.utils import trim_none_values, validate_execute_complete_event
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -480,16 +481,16 @@ def __init__(
aws_conn_id: str | None = None,
region_name: str | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
status_retries=NOTSET,
**kwargs,
):
if "status_retries" in kwargs:
if status_retries is not NOTSET:
warnings.warn(
"The `status_retries` parameter is unused and should be removed. "
"It'll be deleted in a future version.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
kwargs.pop("status_retries") # remove before calling super() to prevent unexpected arg error

super().__init__(**kwargs)

Expand Down
19 changes: 9 additions & 10 deletions airflow/providers/amazon/aws/operators/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
from airflow.utils.helpers import prune_dict
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
import boto3
Expand Down Expand Up @@ -257,19 +258,18 @@ def __init__(
self,
*,
task_definition: str,
wait_for_completion=NOTSET,
waiter_delay=NOTSET,
waiter_max_attempts=NOTSET,
**kwargs,
):
if "wait_for_completion" in kwargs or "waiter_delay" in kwargs or "waiter_max_attempts" in kwargs:
if any(arg is not NOTSET for arg in [wait_for_completion, waiter_delay, waiter_max_attempts]):
warnings.warn(
"'wait_for_completion' and waiter related params have no effect and are deprecated, "
"please remove them.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
# remove args to not trigger Invalid arguments exception
kwargs.pop("wait_for_completion", None)
kwargs.pop("waiter_delay", None)
kwargs.pop("waiter_max_attempts", None)

super().__init__(**kwargs)
self.task_definition = task_definition
Expand Down Expand Up @@ -311,19 +311,18 @@ def __init__(
family: str,
container_definitions: list[dict],
register_task_kwargs: dict | None = None,
wait_for_completion=NOTSET,
waiter_delay=NOTSET,
waiter_max_attempts=NOTSET,
**kwargs,
):
if "wait_for_completion" in kwargs or "waiter_delay" in kwargs or "waiter_max_attempts" in kwargs:
if any(arg is not NOTSET for arg in [wait_for_completion, waiter_delay, waiter_max_attempts]):
warnings.warn(
"'wait_for_completion' and waiter related params have no effect and are deprecated, "
"please remove them.",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
# remove args to not trigger Invalid arguments exception
kwargs.pop("wait_for_completion", None)
kwargs.pop("waiter_delay", None)
kwargs.pop("waiter_max_attempts", None)

super().__init__(**kwargs)
self.family = family
Expand Down
33 changes: 33 additions & 0 deletions tests/providers/amazon/aws/operators/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
BatchCreateComputeEnvironmentTrigger,
BatchJobTrigger,
)
from airflow.utils.task_instance_session import set_current_task_instance_session

AWS_REGION = "eu-west-1"
AWS_ACCESS_KEY_ID = "airflow_dummy_key"
Expand Down Expand Up @@ -370,6 +371,8 @@ def test_monitor_job_with_logs(


class TestBatchCreateComputeEnvironmentOperator:
warn_message = "The `status_retries` parameter is unused and should be removed"

@mock.patch.object(BatchClientHook, "client")
def test_execute(self, mock_conn):
environment_name = "environment_name"
Expand All @@ -394,6 +397,36 @@ def test_execute(self, mock_conn):
tags=tags,
)

def test_deprecation(self):
with pytest.warns(AirflowProviderDeprecationWarning, match=self.warn_message):
BatchCreateComputeEnvironmentOperator(
task_id="id",
compute_environment_name="environment_name",
environment_type="environment_type",
state="environment_state",
compute_resources={},
status_retries="Huh?",
)

@pytest.mark.db_test
def test_partial_deprecation(self, dag_maker, session):
with dag_maker(dag_id="test_partial_deprecation_waiters_params_reg_ecs", session=session):
BatchCreateComputeEnvironmentOperator.partial(
task_id="id",
compute_environment_name="environment_name",
environment_type="environment_type",
state="environment_state",
status_retries="Huh?",
).expand(compute_resources=[{}, {}])

dr = dag_maker.create_dagrun()
tis = dr.get_task_instances(session=session)
with set_current_task_instance_session(session=session):
for ti in tis:
with pytest.warns(AirflowProviderDeprecationWarning, match=self.warn_message):
ti.render_templates()
assert not hasattr(ti.task, "status_retries")

@mock.patch.object(BatchClientHook, "client")
def test_defer(self, client_mock):
client_mock.create_compute_environment.return_value = {"computeEnvironmentArn": "my_arn"}
Expand Down
83 changes: 81 additions & 2 deletions tests/providers/amazon/aws/operators/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,8 @@ def test_execute_without_waiter(self, patch_hook_waiters):


class TestEcsDeregisterTaskDefinitionOperator(EcsBaseTestCase):
warn_message = "'wait_for_completion' and waiter related params have no effect"

def test_execute_immediate_delete(self):
"""Test if task definition deleted during initial request."""
op = EcsDeregisterTaskDefinitionOperator(
Expand All @@ -872,11 +874,50 @@ def test_execute_immediate_delete(self):
assert result == "foo-bar"

def test_deprecation(self):
with pytest.warns(AirflowProviderDeprecationWarning):
with pytest.warns(AirflowProviderDeprecationWarning, match=self.warn_message):
EcsDeregisterTaskDefinitionOperator(task_id="id", task_definition="def", wait_for_completion=True)

@pytest.mark.db_test
@pytest.mark.parametrize(
"wait_for_completion, waiter_delay, waiter_max_attempts",
[
pytest.param(True, 10, 42, id="all-params"),
pytest.param(False, None, None, id="wait-for-completion-only"),
pytest.param(None, 10, None, id="waiter-delay-only"),
pytest.param(None, None, 42, id="waiter-max-attempts-delay-only"),
],
)
def test_partial_deprecation_waiters_params(
self, wait_for_completion, waiter_delay, waiter_max_attempts, dag_maker, session
):
op_kwargs = {}
if wait_for_completion is not None:
op_kwargs["wait_for_completion"] = wait_for_completion
if waiter_delay is not None:
op_kwargs["waiter_delay"] = waiter_delay
if waiter_max_attempts is not None:
op_kwargs["waiter_max_attempts"] = waiter_max_attempts

with dag_maker(dag_id="test_partial_deprecation_waiters_params_dereg_ecs", session=session):
EcsDeregisterTaskDefinitionOperator.partial(
task_id="fake-task-id",
**op_kwargs,
).expand(task_definition=["foo", "bar"])

dr = dag_maker.create_dagrun()
tis = dr.get_task_instances(session=session)
with set_current_task_instance_session(session=session):
for ti in tis:
with pytest.warns(AirflowProviderDeprecationWarning, match=self.warn_message):
ti.render_templates()
assert not hasattr(ti.task, "wait_for_completion")
assert not hasattr(ti.task, "waiter_delay")
assert not hasattr(ti.task, "waiter_max_attempts")


class TestEcsRegisterTaskDefinitionOperator(EcsBaseTestCase):
warn_message = "'wait_for_completion' and waiter related params have no effect"

def test_execute_immediate_create(self):
"""Test if task definition created during initial request."""
mock_ti = mock.MagicMock(name="MockedTaskInstance")
Expand Down Expand Up @@ -908,7 +949,45 @@ def test_execute_immediate_create(self):
assert result == "foo-bar"

def test_deprecation(self):
with pytest.warns(AirflowProviderDeprecationWarning):
with pytest.warns(AirflowProviderDeprecationWarning, match=self.warn_message):
EcsRegisterTaskDefinitionOperator(
task_id="id", wait_for_completion=True, **TASK_DEFINITION_CONFIG
)

@pytest.mark.db_test
@pytest.mark.parametrize(
"wait_for_completion, waiter_delay, waiter_max_attempts",
[
pytest.param(True, 10, 42, id="all-params"),
pytest.param(False, None, None, id="wait-for-completion-only"),
pytest.param(None, 10, None, id="waiter-delay-only"),
pytest.param(None, None, 42, id="waiter-max-attempts-delay-only"),
],
)
def test_partial_deprecation_waiters_params(
self, wait_for_completion, waiter_delay, waiter_max_attempts, dag_maker, session
):
op_kwargs = {}
if wait_for_completion is not None:
op_kwargs["wait_for_completion"] = wait_for_completion
if waiter_delay is not None:
op_kwargs["waiter_delay"] = waiter_delay
if waiter_max_attempts is not None:
op_kwargs["waiter_max_attempts"] = waiter_max_attempts

with dag_maker(dag_id="test_partial_deprecation_waiters_params_reg_ecs", session=session):
EcsRegisterTaskDefinitionOperator.partial(
task_id="fake-task-id",
family="family_name",
**op_kwargs,
).expand(container_definitions=[{}, {}])

dr = dag_maker.create_dagrun()
tis = dr.get_task_instances(session=session)
with set_current_task_instance_session(session=session):
for ti in tis:
with pytest.warns(AirflowProviderDeprecationWarning, match=self.warn_message):
ti.render_templates()
assert not hasattr(ti.task, "wait_for_completion")
assert not hasattr(ti.task, "waiter_delay")
assert not hasattr(ti.task, "waiter_max_attempts")

0 comments on commit c893cb3

Please sign in to comment.