diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py index 6203b2a79b6d11..ca2a6100a27846 100644 --- a/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow/ti_deps/deps/trigger_rule_dep.py @@ -27,7 +27,6 @@ from airflow.models.taskinstance import PAST_DEPENDS_MET from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils.state import TaskInstanceState -from airflow.utils.task_group import MappedTaskGroup from airflow.utils.trigger_rule import TriggerRule as TR if TYPE_CHECKING: @@ -133,20 +132,6 @@ def _get_expanded_ti_count() -> int: """ return ti.task.get_mapped_ti_count(ti.run_id, session=session) - def _iter_expansion_dependencies() -> Iterator[str]: - from airflow.models.mappedoperator import MappedOperator - - if isinstance(ti.task, MappedOperator): - for op in ti.task.iter_mapped_dependencies(): - yield op.task_id - task_group = ti.task.task_group - if task_group and task_group.iter_mapped_task_groups(): - yield from ( - op.task_id - for tg in task_group.iter_mapped_task_groups() - for op in tg.iter_mapped_dependencies() - ) - @functools.lru_cache def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None: """Get the given task's map indexes relevant to the current ti. @@ -157,9 +142,6 @@ def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None: """ if TYPE_CHECKING: assert isinstance(ti.task.dag, DAG) - if isinstance(ti.task.task_group, MappedTaskGroup): - if upstream_id not in set(_iter_expansion_dependencies()): - return None try: expanded_ti_count = _get_expanded_ti_count() except (NotFullyPopulated, NotMapped): diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 5c2e23c1f9e30c..7244c55774840c 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -1305,8 +1305,8 @@ def file_transforms(filename): states = self.get_states(dr) expected = { "file_transforms.my_setup": {0: "success", 1: "failed", 2: "skipped"}, - "file_transforms.my_work": {2: "upstream_failed", 1: "upstream_failed", 0: "upstream_failed"}, - "file_transforms.my_teardown": {2: "success", 1: "success", 0: "success"}, + "file_transforms.my_work": {0: "success", 1: "upstream_failed", 2: "skipped"}, + "file_transforms.my_teardown": {0: "success", 1: "upstream_failed", 2: "skipped"}, } assert states == expected diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py index 1bc8808cb8b2ba..00cbcd449af3ef 100644 --- a/tests/ti_deps/deps/test_trigger_rule_dep.py +++ b/tests/ti_deps/deps/test_trigger_rule_dep.py @@ -1165,23 +1165,19 @@ def _one_scheduling_decision_iteration() -> dict[tuple[str, int], TaskInstance]: tis = _one_scheduling_decision_iteration() assert sorted(tis) == [("tg.t1", 0), ("tg.t1", 1), ("tg.t1", 2)] - # After running the first t1, the remaining t1 must be run before t2 is available. + # After running the first t1, the first t2 becomes immediately available. tis["tg.t1", 0].run() tis = _one_scheduling_decision_iteration() - assert sorted(tis) == [("tg.t1", 1), ("tg.t1", 2)] + assert sorted(tis) == [("tg.t1", 1), ("tg.t1", 2), ("tg.t2", 0)] - # After running all t1, t2 is available. - tis["tg.t1", 1].run() + # Similarly for the subsequent t2 instances. tis["tg.t1", 2].run() tis = _one_scheduling_decision_iteration() - assert sorted(tis) == [("tg.t2", 0), ("tg.t2", 1), ("tg.t2", 2)] - - # Similarly for t2 instances. They both have to complete before t3 is available - tis["tg.t2", 0].run() - tis = _one_scheduling_decision_iteration() - assert sorted(tis) == [("tg.t2", 1), ("tg.t2", 2)] + assert sorted(tis) == [("tg.t1", 1), ("tg.t2", 0), ("tg.t2", 2)] # But running t2 partially does not make t3 available. + tis["tg.t1", 1].run() + tis["tg.t2", 0].run() tis["tg.t2", 2].run() tis = _one_scheduling_decision_iteration() assert sorted(tis) == [("tg.t2", 1)] @@ -1411,34 +1407,3 @@ def w2(): (status,) = self.get_dep_statuses(dr, "w2", flag_upstream_failed=True, session=session) assert status.reason.startswith("All setup tasks must complete successfully") assert self.get_ti(dr, "w2").state == expected - - -def test_mapped_tasks_in_mapped_task_group_waits_for_upstreams_to_complete(dag_maker, session): - """Test that one failed trigger rule works well in mapped task group""" - with dag_maker() as dag: - - @dag.task - def t1(): - return [1, 2, 3] - - @task_group("tg1") - def tg1(a): - @dag.task() - def t2(a): - return a - - @dag.task(trigger_rule=TriggerRule.ONE_FAILED) - def t3(a): - return a - - t2(a) >> t3(a) - - t = t1() - tg1.expand(a=t) - - dr = dag_maker.create_dagrun() - ti = dr.get_task_instance(task_id="t1") - ti.run() - dr.task_instance_scheduling_decisions() - ti3 = dr.get_task_instance(task_id="tg1.t3") - assert not ti3.state