diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 78f0a0d2715e5..76dba4538bf7b 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -1568,3 +1568,62 @@ def my_teardown(val): "tg_2.my_work": "skipped", } assert states == expected + + def test_skip_one_mapped_task_from_task_group_with_generator(self, dag_maker): + with dag_maker() as dag: + + @task + def make_list(): + return [1, 2, 3] + + @task + def double(n): + if n == 2: + raise AirflowSkipException() + return n * 2 + + @task + def last(n): + ... + + @task_group + def group(n: int) -> None: + last(double(n)) + + group.expand(n=make_list()) + + dr = dag.test() + states = self.get_states(dr) + expected = { + "group.double": {0: "success", 1: "skipped", 2: "success"}, + "group.last": {0: "success", 1: "skipped", 2: "success"}, + "make_list": "success", + } + assert states == expected + + def test_skip_one_mapped_task_from_task_group(self, dag_maker): + with dag_maker() as dag: + + @task + def double(n): + if n == 2: + raise AirflowSkipException() + return n * 2 + + @task + def last(n): + ... + + @task_group + def group(n: int) -> None: + last(double(n)) + + group.expand(n=[1, 2, 3]) + + dr = dag.test() + states = self.get_states(dr) + expected = { + "group.double": {0: "success", 1: "skipped", 2: "success"}, + "group.last": {0: "success", 1: "skipped", 2: "success"}, + } + assert states == expected