From 7e0968f3bb34b5a058ae34480af03eb7b99d4064 Mon Sep 17 00:00:00 2001 From: yuqian90 Date: Fri, 21 Feb 2020 19:35:55 +0800 Subject: [PATCH] [AIRFLOW-5391] Do not re-run skipped tasks when they are cleared (#7276) If a task is skipped by BranchPythonOperator, BaseBranchOperator or ShortCircuitOperator and the user then clears the skipped task later, it'll execute. This is probably not the right behaviour. This commit changes that so it will be skipped again. This can be ignored by running the task again with "Ignore Task Deps" override. (cherry picked from commit 1cdab56a6192f69962506b7ff632c986c84eb10d) --- airflow/models/baseoperator.py | 2 + airflow/models/skipmixin.py | 112 +++++++++++---- airflow/ti_deps/dep_context.py | 27 +++- .../deps/not_previously_skipped_dep.py | 88 ++++++++++++ tests/jobs/test_scheduler_job.py | 39 +++++ tests/operators/test_python_operator.py | 119 +++++++++++++++- .../deps/test_not_previously_skipped_dep.py | 133 ++++++++++++++++++ tests/ti_deps/deps/test_trigger_rule_dep.py | 42 +++++- 8 files changed, 529 insertions(+), 33 deletions(-) create mode 100644 airflow/ti_deps/deps/not_previously_skipped_dep.py create mode 100644 tests/ti_deps/deps/test_not_previously_skipped_dep.py diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 52037c56a756c..266ad64d08a0a 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -45,6 +45,7 @@ from airflow.models.taskinstance import TaskInstance, clear_task_instances from airflow.models.xcom import XCOM_RETURN_KEY from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep +from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep from airflow.utils import timezone @@ -575,6 +576,7 @@ def deps(self): NotInRetryPeriodDep(), PrevDagrunDep(), TriggerRuleDep(), + NotPreviouslySkippedDep(), } @property diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py index 57341d8fb4eab..3b4531f68c7d4 100644 --- a/airflow/models/skipmixin.py +++ b/airflow/models/skipmixin.py @@ -19,28 +19,28 @@ from airflow.models.taskinstance import TaskInstance from airflow.utils import timezone -from airflow.utils.db import provide_session +from airflow.utils.db import create_session, provide_session from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import State import six -from typing import Union, Iterable, Set +from typing import Set + +# The key used by SkipMixin to store XCom data. +XCOM_SKIPMIXIN_KEY = "skipmixin_key" + +# The dictionary key used to denote task IDs that are skipped +XCOM_SKIPMIXIN_SKIPPED = "skipped" + +# The dictionary key used to denote task IDs that are followed +XCOM_SKIPMIXIN_FOLLOWED = "followed" class SkipMixin(LoggingMixin): - @provide_session - def skip(self, dag_run, execution_date, tasks, session=None): + def _set_state_to_skipped(self, dag_run, execution_date, tasks, session): """ - Sets tasks instances to skipped from the same dag run. - - :param dag_run: the DagRun for which to set the tasks to skipped - :param execution_date: execution_date - :param tasks: tasks to skip (not task_ids) - :param session: db session to use + Used internally to set state of task instances to skipped from the same dag run. """ - if not tasks: - return - task_ids = [d.task_id for d in tasks] now = timezone.utcnow() @@ -48,12 +48,15 @@ def skip(self, dag_run, execution_date, tasks, session=None): session.query(TaskInstance).filter( TaskInstance.dag_id == dag_run.dag_id, TaskInstance.execution_date == dag_run.execution_date, - TaskInstance.task_id.in_(task_ids) - ).update({TaskInstance.state: State.SKIPPED, - TaskInstance.start_date: now, - TaskInstance.end_date: now}, - synchronize_session=False) - session.commit() + TaskInstance.task_id.in_(task_ids), + ).update( + { + TaskInstance.state: State.SKIPPED, + TaskInstance.start_date: now, + TaskInstance.end_date: now, + }, + synchronize_session=False, + ) else: assert execution_date is not None, "Execution date is None and no dag run" @@ -66,14 +69,56 @@ def skip(self, dag_run, execution_date, tasks, session=None): ti.end_date = now session.merge(ti) - session.commit() + @provide_session + def skip( + self, dag_run, execution_date, tasks, session=None, + ): + """ + Sets tasks instances to skipped from the same dag run. + + If this instance has a `task_id` attribute, store the list of skipped task IDs to XCom + so that NotPreviouslySkippedDep knows these tasks should be skipped when they + are cleared. - def skip_all_except(self, ti, branch_task_ids): - # type: (TaskInstance, Union[str, Iterable[str]]) -> None + :param dag_run: the DagRun for which to set the tasks to skipped + :param execution_date: execution_date + :param tasks: tasks to skip (not task_ids) + :param session: db session to use + """ + if not tasks: + return + + self._set_state_to_skipped(dag_run, execution_date, tasks, session) + session.commit() + + # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available. + try: + task_id = self.task_id + except AttributeError: + task_id = None + + if task_id is not None: + from airflow.models.xcom import XCom + + XCom.set( + key=XCOM_SKIPMIXIN_KEY, + value={XCOM_SKIPMIXIN_SKIPPED: [d.task_id for d in tasks]}, + task_id=task_id, + dag_id=dag_run.dag_id, + execution_date=dag_run.execution_date, + session=session + ) + + def skip_all_except( + self, ti, branch_task_ids + ): """ This method implements the logic for a branching operator; given a single task ID or list of task IDs to follow, this skips all other tasks immediately downstream of this operator. + + branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or + newly added tasks should be skipped when they are cleared. """ self.log.info("Following branch %s", branch_task_ids) if isinstance(branch_task_ids, six.string_types): @@ -90,13 +135,22 @@ def skip_all_except(self, ti, branch_task_ids): # is also a downstream task of the branch task, we exclude it from skipping. branch_downstream_task_ids = set() # type: Set[str] for b in branch_task_ids: - branch_downstream_task_ids.update(dag. - get_task(b). - get_flat_relative_ids(upstream=False)) + branch_downstream_task_ids.update( + dag.get_task(b).get_flat_relative_ids(upstream=False) + ) - skip_tasks = [t for t in downstream_tasks - if t.task_id not in branch_task_ids and - t.task_id not in branch_downstream_task_ids] + skip_tasks = [ + t + for t in downstream_tasks + if t.task_id not in branch_task_ids + and t.task_id not in branch_downstream_task_ids + ] self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks]) - self.skip(dag_run, ti.execution_date, skip_tasks) + with create_session() as session: + self._set_state_to_skipped( + dag_run, ti.execution_date, skip_tasks, session=session + ) + ti.xcom_push( + key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: branch_task_ids} + ) diff --git a/airflow/ti_deps/dep_context.py b/airflow/ti_deps/dep_context.py index c5d999ae33c17..74307e4b45d2c 100644 --- a/airflow/ti_deps/dep_context.py +++ b/airflow/ti_deps/dep_context.py @@ -66,6 +66,8 @@ class DepContext(object): :type ignore_task_deps: bool :param ignore_ti_state: Ignore the task instance's previous failure/success :type ignore_ti_state: bool + :param finished_tasks: A list of all the finished tasks of this run + :type finished_tasks: list[airflow.models.TaskInstance] """ def __init__( self, @@ -76,7 +78,8 @@ def __init__( ignore_in_retry_period=False, ignore_in_reschedule_period=False, ignore_task_deps=False, - ignore_ti_state=False): + ignore_ti_state=False, + finished_tasks=None): self.deps = deps or set() self.flag_upstream_failed = flag_upstream_failed self.ignore_all_deps = ignore_all_deps @@ -85,6 +88,28 @@ def __init__( self.ignore_in_reschedule_period = ignore_in_reschedule_period self.ignore_task_deps = ignore_task_deps self.ignore_ti_state = ignore_ti_state + self.finished_tasks = finished_tasks + + def ensure_finished_tasks(self, dag, execution_date, session): + """ + This method makes sure finished_tasks is populated if it's currently None. + This is for the strange feature of running tasks without dag_run. + + :param dag: The DAG for which to find finished tasks + :type dag: airflow.models.DAG + :param execution_date: The execution_date to look for + :param session: Database session to use + :return: A list of all the finished tasks of this DAG and execution_date + :rtype: list[airflow.models.TaskInstance] + """ + if self.finished_tasks is None: + self.finished_tasks = dag.get_task_instances( + start_date=execution_date, + end_date=execution_date, + state=State.finished() + [State.UPSTREAM_FAILED], + session=session, + ) + return self.finished_tasks # In order to be able to get queued a task must have one of these states diff --git a/airflow/ti_deps/deps/not_previously_skipped_dep.py b/airflow/ti_deps/deps/not_previously_skipped_dep.py new file mode 100644 index 0000000000000..34ff6acff39c4 --- /dev/null +++ b/airflow/ti_deps/deps/not_previously_skipped_dep.py @@ -0,0 +1,88 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.ti_deps.deps.base_ti_dep import BaseTIDep + + +class NotPreviouslySkippedDep(BaseTIDep): + """ + Determines if any of the task's direct upstream relatives have decided this task should + be skipped. + """ + + NAME = "Not Previously Skipped" + IGNORABLE = True + IS_TASK_DEP = True + + def _get_dep_statuses( + self, ti, session, dep_context + ): # pylint: disable=signature-differs + from airflow.models.skipmixin import ( + SkipMixin, + XCOM_SKIPMIXIN_KEY, + XCOM_SKIPMIXIN_SKIPPED, + XCOM_SKIPMIXIN_FOLLOWED, + ) + from airflow.utils.state import State + + upstream = ti.task.get_direct_relatives(upstream=True) + + finished_tasks = dep_context.ensure_finished_tasks( + ti.task.dag, ti.execution_date, session + ) + + finished_task_ids = {t.task_id for t in finished_tasks} + + for parent in upstream: + if isinstance(parent, SkipMixin): + if parent.task_id not in finished_task_ids: + # This can happen if the parent task has not yet run. + continue + + prev_result = ti.xcom_pull( + task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY + ) + + if prev_result is None: + # This can happen if the parent task has not yet run. + continue + + should_skip = False + if ( + XCOM_SKIPMIXIN_FOLLOWED in prev_result + and ti.task_id not in prev_result[XCOM_SKIPMIXIN_FOLLOWED] + ): + # Skip any tasks that are not in "followed" + should_skip = True + elif ( + XCOM_SKIPMIXIN_SKIPPED in prev_result + and ti.task_id in prev_result[XCOM_SKIPMIXIN_SKIPPED] + ): + # Skip any tasks that are in "skipped" + should_skip = True + + if should_skip: + # If the parent SkipMixin has run, and the XCom result stored indicates this + # ti should be skipped, set ti.state to SKIPPED and fail the rule so that the + # ti does not execute. + ti.set_state(State.SKIPPED, session) + yield self._failing_status( + reason="Skipping because of previous XCom result from parent task {}" + .format(parent.task_id) + ) + return diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 48f70a9f4753b..161e4796397a1 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -3011,3 +3011,42 @@ def test_should_mark_dummy_task_as_success(self): self.assertIsNone(start_date) self.assertIsNone(end_date) self.assertIsNone(duration) + + +def test_task_with_upstream_skip_process_task_instances(): + """ + Test if _process_task_instances puts a task instance into SKIPPED state if any of its + upstream tasks are skipped according to TriggerRuleDep. + """ + with DAG( + dag_id='test_task_with_upstream_skip_dag', + start_date=DEFAULT_DATE, + schedule_interval=None + ) as dag: + dummy1 = DummyOperator(task_id='dummy1') + dummy2 = DummyOperator(task_id="dummy2") + dummy3 = DummyOperator(task_id="dummy3") + [dummy1, dummy2] >> dummy3 + + dag_file_processor = SchedulerJob(dag_ids=[], log=mock.MagicMock()) + dag.clear() + dr = dag.create_dagrun(run_id="manual__{}".format(DEFAULT_DATE.isoformat()), + state=State.RUNNING, + execution_date=DEFAULT_DATE) + assert dr is not None + + with create_session() as session: + tis = {ti.task_id: ti for ti in dr.get_task_instances(session=session)} + # Set dummy1 to skipped and dummy2 to success. dummy3 remains as none. + tis[dummy1.task_id].state = State.SKIPPED + tis[dummy2.task_id].state = State.SUCCESS + assert tis[dummy3.task_id].state == State.NONE + + dag_file_processor._process_task_instances(dag, task_instances_list=Mock()) + + with create_session() as session: + tis = {ti.task_id: ti for ti in dr.get_task_instances(session=session)} + assert tis[dummy1.task_id].state == State.SKIPPED + assert tis[dummy2.task_id].state == State.SUCCESS + # dummy3 should be skipped because dummy1 is skipped. + assert tis[dummy3.task_id].state == State.SKIPPED diff --git a/tests/operators/test_python_operator.py b/tests/operators/test_python_operator.py index 6f3dfe2703df2..a92213add8a34 100644 --- a/tests/operators/test_python_operator.py +++ b/tests/operators/test_python_operator.py @@ -32,6 +32,7 @@ from airflow.exceptions import AirflowException from airflow.models import TaskInstance as TI, DAG, DagRun +from airflow.models.taskinstance import clear_task_instances from airflow.operators.dummy_operator import DummyOperator from airflow.operators.python_operator import PythonOperator, BranchPythonOperator from airflow.operators.python_operator import ShortCircuitOperator @@ -491,7 +492,7 @@ def test_with_skip_in_branch_downstream_dependencies(self): elif ti.task_id == 'branch_2': self.assertEqual(ti.state, State.NONE) else: - raise Exception + raise def test_with_skip_in_branch_downstream_dependencies2(self): self.branch_op = BranchPythonOperator(task_id='make_choice', @@ -520,7 +521,63 @@ def test_with_skip_in_branch_downstream_dependencies2(self): elif ti.task_id == 'branch_2': self.assertEqual(ti.state, State.NONE) else: - raise Exception + raise + + def test_clear_skipped_downstream_task(self): + """ + After a downstream task is skipped by BranchPythonOperator, clearing the skipped task + should not cause it to be executed. + """ + branch_op = BranchPythonOperator(task_id='make_choice', + dag=self.dag, + python_callable=lambda: 'branch_1') + branches = [self.branch_1, self.branch_2] + branch_op >> branches + self.dag.clear() + + dr = self.dag.create_dagrun( + run_id="manual__", + start_date=timezone.utcnow(), + execution_date=DEFAULT_DATE, + state=State.RUNNING + ) + + branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + for task in branches: + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + tis = dr.get_task_instances() + for ti in tis: + if ti.task_id == 'make_choice': + self.assertEqual(ti.state, State.SUCCESS) + elif ti.task_id == 'branch_1': + self.assertEqual(ti.state, State.SUCCESS) + elif ti.task_id == 'branch_2': + self.assertEqual(ti.state, State.SKIPPED) + else: + raise + + children_tis = [ti for ti in tis if ti.task_id in branch_op.get_direct_relative_ids()] + + # Clear the children tasks. + with create_session() as session: + clear_task_instances(children_tis, session=session, dag=self.dag) + + # Run the cleared tasks again. + for task in branches: + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + # Check if the states are correct after children tasks are cleared. + for ti in dr.get_task_instances(): + if ti.task_id == 'make_choice': + self.assertEqual(ti.state, State.SUCCESS) + elif ti.task_id == 'branch_1': + self.assertEqual(ti.state, State.SUCCESS) + elif ti.task_id == 'branch_2': + self.assertEqual(ti.state, State.SKIPPED) + else: + raise class ShortCircuitOperatorTest(unittest.TestCase): @@ -660,3 +717,61 @@ def test_with_dag_run(self): self.assertEqual(ti.state, State.NONE) else: raise + + def test_clear_skipped_downstream_task(self): + """ + After a downstream task is skipped by ShortCircuitOperator, clearing the skipped task + should not cause it to be executed. + """ + dag = DAG('shortcircuit_clear_skipped_downstream_task', + default_args={ + 'owner': 'airflow', + 'start_date': DEFAULT_DATE + }, + schedule_interval=INTERVAL) + short_op = ShortCircuitOperator(task_id='make_choice', + dag=dag, + python_callable=lambda: False) + downstream = DummyOperator(task_id='downstream', dag=dag) + + short_op >> downstream + + dag.clear() + + dr = dag.create_dagrun( + run_id="manual__", + start_date=timezone.utcnow(), + execution_date=DEFAULT_DATE, + state=State.RUNNING + ) + + short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + downstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + tis = dr.get_task_instances() + + for ti in tis: + if ti.task_id == 'make_choice': + self.assertEqual(ti.state, State.SUCCESS) + elif ti.task_id == 'downstream': + self.assertEqual(ti.state, State.SKIPPED) + else: + raise + + # Clear downstream + with create_session() as session: + clear_task_instances([t for t in tis if t.task_id == "downstream"], + session=session, + dag=dag) + + # Run downstream again + downstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + # Check if the states are correct. + for ti in dr.get_task_instances(): + if ti.task_id == 'make_choice': + self.assertEqual(ti.state, State.SUCCESS) + elif ti.task_id == 'downstream': + self.assertEqual(ti.state, State.SKIPPED) + else: + raise diff --git a/tests/ti_deps/deps/test_not_previously_skipped_dep.py b/tests/ti_deps/deps/test_not_previously_skipped_dep.py new file mode 100644 index 0000000000000..30da9cf2bc5cd --- /dev/null +++ b/tests/ti_deps/deps/test_not_previously_skipped_dep.py @@ -0,0 +1,133 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pendulum + +from airflow.models import DAG, TaskInstance +from airflow.operators.dummy_operator import DummyOperator +from airflow.operators.python_operator import BranchPythonOperator +from airflow.ti_deps.dep_context import DepContext +from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep +from airflow.utils.db import create_session +from airflow.utils.state import State + + +def test_no_parent(): + """ + A simple DAG with a single task. NotPreviouslySkippedDep is met. + """ + start_date = pendulum.datetime(2020, 1, 1) + dag = DAG("test_test_no_parent_dag", schedule_interval=None, start_date=start_date) + op1 = DummyOperator(task_id="op1", dag=dag) + + ti1 = TaskInstance(op1, start_date) + + with create_session() as session: + dep = NotPreviouslySkippedDep() + assert len(list(dep.get_dep_statuses(ti1, session, DepContext()))) == 0 + assert dep.is_met(ti1, session) + assert ti1.state != State.SKIPPED + + +def test_no_skipmixin_parent(): + """ + A simple DAG with no branching. Both op1 and op2 are DummyOperator. NotPreviouslySkippedDep is met. + """ + start_date = pendulum.datetime(2020, 1, 1) + dag = DAG( + "test_no_skipmixin_parent_dag", schedule_interval=None, start_date=start_date + ) + op1 = DummyOperator(task_id="op1", dag=dag) + op2 = DummyOperator(task_id="op2", dag=dag) + op1 >> op2 + + ti2 = TaskInstance(op2, start_date) + + with create_session() as session: + dep = NotPreviouslySkippedDep() + assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 0 + assert dep.is_met(ti2, session) + assert ti2.state != State.SKIPPED + + +def test_parent_follow_branch(): + """ + A simple DAG with a BranchPythonOperator that follows op2. NotPreviouslySkippedDep is met. + """ + start_date = pendulum.datetime(2020, 1, 1) + dag = DAG( + "test_parent_follow_branch_dag", schedule_interval=None, start_date=start_date + ) + op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op2", dag=dag) + op2 = DummyOperator(task_id="op2", dag=dag) + op1 >> op2 + + TaskInstance(op1, start_date).run() + ti2 = TaskInstance(op2, start_date) + + with create_session() as session: + dep = NotPreviouslySkippedDep() + assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 0 + assert dep.is_met(ti2, session) + assert ti2.state != State.SKIPPED + + +def test_parent_skip_branch(): + """ + A simple DAG with a BranchPythonOperator that does not follow op2. NotPreviouslySkippedDep is not met. + """ + start_date = pendulum.datetime(2020, 1, 1) + dag = DAG( + "test_parent_skip_branch_dag", schedule_interval=None, start_date=start_date + ) + op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op3", dag=dag) + op2 = DummyOperator(task_id="op2", dag=dag) + op3 = DummyOperator(task_id="op3", dag=dag) + op1 >> [op2, op3] + + TaskInstance(op1, start_date).run() + ti2 = TaskInstance(op2, start_date) + + with create_session() as session: + dep = NotPreviouslySkippedDep() + assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 1 + assert not dep.is_met(ti2, session) + assert ti2.state == State.SKIPPED + + +def test_parent_not_executed(): + """ + A simple DAG with a BranchPythonOperator that does not follow op2. Parent task is not yet + executed (no xcom data). NotPreviouslySkippedDep is met (no decision). + """ + start_date = pendulum.datetime(2020, 1, 1) + dag = DAG( + "test_parent_not_executed_dag", schedule_interval=None, start_date=start_date + ) + op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op3", dag=dag) + op2 = DummyOperator(task_id="op2", dag=dag) + op3 = DummyOperator(task_id="op3", dag=dag) + op1 >> [op2, op3] + + ti2 = TaskInstance(op2, start_date) + + with create_session() as session: + dep = NotPreviouslySkippedDep() + assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 0 + assert dep.is_met(ti2, session) + assert ti2.state == State.NONE diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py index 45514f6ca20fb..bbd8e1374bfcc 100644 --- a/tests/ti_deps/deps/test_trigger_rule_dep.py +++ b/tests/ti_deps/deps/test_trigger_rule_dep.py @@ -19,13 +19,13 @@ import unittest from datetime import datetime +from unittest.mock import Mock from airflow.models import BaseOperator, TaskInstance from airflow.utils.trigger_rule import TriggerRule from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep from airflow.utils.db import create_session from airflow.utils.state import State -from tests.compat import Mock class TriggerRuleDepTest(unittest.TestCase): @@ -165,6 +165,46 @@ def test_all_success_tr_failure(self): self.assertEqual(len(dep_statuses), 1) self.assertFalse(dep_statuses[0].passed) + def test_all_success_tr_skip(self): + """ + All-success trigger rule fails when some upstream tasks are skipped. + """ + ti = self._get_task_instance(TriggerRule.ALL_SUCCESS, + upstream_task_ids=["FakeTaskID", + "OtherFakeTaskID"]) + dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=1, + skipped=1, + failed=0, + upstream_failed=0, + done=2, + flag_upstream_failed=False, + session="Fake Session")) + self.assertEqual(len(dep_statuses), 1) + self.assertFalse(dep_statuses[0].passed) + + def test_all_success_tr_skip_flag_upstream(self): + """ + All-success trigger rule fails when some upstream tasks are skipped. The state of the ti + should be set to SKIPPED when flag_upstream_failed is True. + """ + ti = self._get_task_instance(TriggerRule.ALL_SUCCESS, + upstream_task_ids=["FakeTaskID", + "OtherFakeTaskID"]) + dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=1, + skipped=1, + failed=0, + upstream_failed=0, + done=2, + flag_upstream_failed=True, + session=Mock())) + self.assertEqual(len(dep_statuses), 1) + self.assertFalse(dep_statuses[0].passed) + self.assertEqual(ti.state, State.SKIPPED) + def test_none_failed_tr_success(self): """ All success including skip trigger rule success