diff --git a/UPDATING.md b/UPDATING.md index 61734bb2d2e945..f82ba10d4a4fb2 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -25,6 +25,7 @@ assists users migrating to a new version. **Table of contents** +- [Airflow 1.10.12](#airflow-11012) - [Airflow 1.10.11](#airflow-11011) - [Airflow 1.10.10](#airflow-11010) - [Airflow 1.10.9](#airflow-1109) @@ -59,6 +60,12 @@ More tips can be found in the guide: https://developers.google.com/style/inclusive-documentation --> +## Airflow 1.10.12 + +### Clearing tasks skipped by SkipMixin will skip them + +Previously, when tasks skipped by SkipMixin (such as BranchPythonOperator, BaseBranchOperator and ShortCircuitOperator) were cleared, they would execute. Since 1.10.12, when such skipped tasks are cleared, +they will be skipped again by the newly introduced NotPreviouslySkippedDep. ## Airflow 1.10.11 diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 52037c56a756c3..266ad64d08a0a9 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -45,6 +45,7 @@ from airflow.models.taskinstance import TaskInstance, clear_task_instances from airflow.models.xcom import XCOM_RETURN_KEY from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep +from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep from airflow.utils import timezone @@ -575,6 +576,7 @@ def deps(self): NotInRetryPeriodDep(), PrevDagrunDep(), TriggerRuleDep(), + NotPreviouslySkippedDep(), } @property diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py index 57341d8fb4eabd..3b4531f68c7d4b 100644 --- a/airflow/models/skipmixin.py +++ b/airflow/models/skipmixin.py @@ -19,28 +19,28 @@ from airflow.models.taskinstance import TaskInstance from airflow.utils import timezone -from airflow.utils.db import provide_session +from airflow.utils.db 
import create_session, provide_session from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import State import six -from typing import Union, Iterable, Set +from typing import Set + +# The key used by SkipMixin to store XCom data. +XCOM_SKIPMIXIN_KEY = "skipmixin_key" + +# The dictionary key used to denote task IDs that are skipped +XCOM_SKIPMIXIN_SKIPPED = "skipped" + +# The dictionary key used to denote task IDs that are followed +XCOM_SKIPMIXIN_FOLLOWED = "followed" class SkipMixin(LoggingMixin): - @provide_session - def skip(self, dag_run, execution_date, tasks, session=None): + def _set_state_to_skipped(self, dag_run, execution_date, tasks, session): """ - Sets tasks instances to skipped from the same dag run. - - :param dag_run: the DagRun for which to set the tasks to skipped - :param execution_date: execution_date - :param tasks: tasks to skip (not task_ids) - :param session: db session to use + Used internally to set state of task instances to skipped from the same dag run. 
""" - if not tasks: - return - task_ids = [d.task_id for d in tasks] now = timezone.utcnow() @@ -48,12 +48,15 @@ def skip(self, dag_run, execution_date, tasks, session=None): session.query(TaskInstance).filter( TaskInstance.dag_id == dag_run.dag_id, TaskInstance.execution_date == dag_run.execution_date, - TaskInstance.task_id.in_(task_ids) - ).update({TaskInstance.state: State.SKIPPED, - TaskInstance.start_date: now, - TaskInstance.end_date: now}, - synchronize_session=False) - session.commit() + TaskInstance.task_id.in_(task_ids), + ).update( + { + TaskInstance.state: State.SKIPPED, + TaskInstance.start_date: now, + TaskInstance.end_date: now, + }, + synchronize_session=False, + ) else: assert execution_date is not None, "Execution date is None and no dag run" @@ -66,14 +69,56 @@ def skip(self, dag_run, execution_date, tasks, session=None): ti.end_date = now session.merge(ti) - session.commit() + @provide_session + def skip( + self, dag_run, execution_date, tasks, session=None, + ): + """ + Sets tasks instances to skipped from the same dag run. + + If this instance has a `task_id` attribute, store the list of skipped task IDs to XCom + so that NotPreviouslySkippedDep knows these tasks should be skipped when they + are cleared. - def skip_all_except(self, ti, branch_task_ids): - # type: (TaskInstance, Union[str, Iterable[str]]) -> None + :param dag_run: the DagRun for which to set the tasks to skipped + :param execution_date: execution_date + :param tasks: tasks to skip (not task_ids) + :param session: db session to use + """ + if not tasks: + return + + self._set_state_to_skipped(dag_run, execution_date, tasks, session) + session.commit() + + # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available. 
+ try: + task_id = self.task_id + except AttributeError: + task_id = None + + if task_id is not None: + from airflow.models.xcom import XCom + + XCom.set( + key=XCOM_SKIPMIXIN_KEY, + value={XCOM_SKIPMIXIN_SKIPPED: [d.task_id for d in tasks]}, + task_id=task_id, + dag_id=dag_run.dag_id, + execution_date=dag_run.execution_date, + session=session + ) + + def skip_all_except( + self, ti, branch_task_ids + ): """ This method implements the logic for a branching operator; given a single task ID or list of task IDs to follow, this skips all other tasks immediately downstream of this operator. + + branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or + newly added tasks should be skipped when they are cleared. """ self.log.info("Following branch %s", branch_task_ids) if isinstance(branch_task_ids, six.string_types): @@ -90,13 +135,22 @@ def skip_all_except(self, ti, branch_task_ids): # is also a downstream task of the branch task, we exclude it from skipping. branch_downstream_task_ids = set() # type: Set[str] for b in branch_task_ids: - branch_downstream_task_ids.update(dag. - get_task(b). 
- get_flat_relative_ids(upstream=False)) + branch_downstream_task_ids.update( + dag.get_task(b).get_flat_relative_ids(upstream=False) + ) - skip_tasks = [t for t in downstream_tasks - if t.task_id not in branch_task_ids and - t.task_id not in branch_downstream_task_ids] + skip_tasks = [ + t + for t in downstream_tasks + if t.task_id not in branch_task_ids + and t.task_id not in branch_downstream_task_ids + ] self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks]) - self.skip(dag_run, ti.execution_date, skip_tasks) + with create_session() as session: + self._set_state_to_skipped( + dag_run, ti.execution_date, skip_tasks, session=session + ) + ti.xcom_push( + key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: branch_task_ids} + ) diff --git a/airflow/ti_deps/dep_context.py b/airflow/ti_deps/dep_context.py index c5d999ae33c179..74307e4b45d2cf 100644 --- a/airflow/ti_deps/dep_context.py +++ b/airflow/ti_deps/dep_context.py @@ -66,6 +66,8 @@ class DepContext(object): :type ignore_task_deps: bool :param ignore_ti_state: Ignore the task instance's previous failure/success :type ignore_ti_state: bool + :param finished_tasks: A list of all the finished tasks of this run + :type finished_tasks: list[airflow.models.TaskInstance] """ def __init__( self, @@ -76,7 +78,8 @@ def __init__( ignore_in_retry_period=False, ignore_in_reschedule_period=False, ignore_task_deps=False, - ignore_ti_state=False): + ignore_ti_state=False, + finished_tasks=None): self.deps = deps or set() self.flag_upstream_failed = flag_upstream_failed self.ignore_all_deps = ignore_all_deps @@ -85,6 +88,28 @@ def __init__( self.ignore_in_reschedule_period = ignore_in_reschedule_period self.ignore_task_deps = ignore_task_deps self.ignore_ti_state = ignore_ti_state + self.finished_tasks = finished_tasks + + def ensure_finished_tasks(self, dag, execution_date, session): + """ + This method makes sure finished_tasks is populated if it's currently None. 
+ This is for the strange feature of running tasks without dag_run. + + :param dag: The DAG for which to find finished tasks + :type dag: airflow.models.DAG + :param execution_date: The execution_date to look for + :param session: Database session to use + :return: A list of all the finished tasks of this DAG and execution_date + :rtype: list[airflow.models.TaskInstance] + """ + if self.finished_tasks is None: + self.finished_tasks = dag.get_task_instances( + start_date=execution_date, + end_date=execution_date, + state=State.finished() + [State.UPSTREAM_FAILED], + session=session, + ) + return self.finished_tasks # In order to be able to get queued a task must have one of these states diff --git a/airflow/ti_deps/deps/not_previously_skipped_dep.py b/airflow/ti_deps/deps/not_previously_skipped_dep.py new file mode 100644 index 00000000000000..34ff6acff39c41 --- /dev/null +++ b/airflow/ti_deps/deps/not_previously_skipped_dep.py @@ -0,0 +1,88 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.ti_deps.deps.base_ti_dep import BaseTIDep + + +class NotPreviouslySkippedDep(BaseTIDep): + """ + Determines if any of the task's direct upstream relatives have decided this task should + be skipped. 
+ """ + + NAME = "Not Previously Skipped" + IGNORABLE = True + IS_TASK_DEP = True + + def _get_dep_statuses( + self, ti, session, dep_context + ): # pylint: disable=signature-differs + from airflow.models.skipmixin import ( + SkipMixin, + XCOM_SKIPMIXIN_KEY, + XCOM_SKIPMIXIN_SKIPPED, + XCOM_SKIPMIXIN_FOLLOWED, + ) + from airflow.utils.state import State + + upstream = ti.task.get_direct_relatives(upstream=True) + + finished_tasks = dep_context.ensure_finished_tasks( + ti.task.dag, ti.execution_date, session + ) + + finished_task_ids = {t.task_id for t in finished_tasks} + + for parent in upstream: + if isinstance(parent, SkipMixin): + if parent.task_id not in finished_task_ids: + # This can happen if the parent task has not yet run. + continue + + prev_result = ti.xcom_pull( + task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY + ) + + if prev_result is None: + # This can happen if the parent task has not yet run. + continue + + should_skip = False + if ( + XCOM_SKIPMIXIN_FOLLOWED in prev_result + and ti.task_id not in prev_result[XCOM_SKIPMIXIN_FOLLOWED] + ): + # Skip any tasks that are not in "followed" + should_skip = True + elif ( + XCOM_SKIPMIXIN_SKIPPED in prev_result + and ti.task_id in prev_result[XCOM_SKIPMIXIN_SKIPPED] + ): + # Skip any tasks that are in "skipped" + should_skip = True + + if should_skip: + # If the parent SkipMixin has run, and the XCom result stored indicates this + # ti should be skipped, set ti.state to SKIPPED and fail the rule so that the + # ti does not execute. 
+ ti.set_state(State.SKIPPED, session) + yield self._failing_status( + reason="Skipping because of previous XCom result from parent task {}" + .format(parent.task_id) + ) + return diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 48f70a9f4753b0..161e4796397a11 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -3011,3 +3011,42 @@ def test_should_mark_dummy_task_as_success(self): self.assertIsNone(start_date) self.assertIsNone(end_date) self.assertIsNone(duration) + + +def test_task_with_upstream_skip_process_task_instances(): + """ + Test if _process_task_instances puts a task instance into SKIPPED state if any of its + upstream tasks are skipped according to TriggerRuleDep. + """ + with DAG( + dag_id='test_task_with_upstream_skip_dag', + start_date=DEFAULT_DATE, + schedule_interval=None + ) as dag: + dummy1 = DummyOperator(task_id='dummy1') + dummy2 = DummyOperator(task_id="dummy2") + dummy3 = DummyOperator(task_id="dummy3") + [dummy1, dummy2] >> dummy3 + + dag_file_processor = SchedulerJob(dag_ids=[], log=mock.MagicMock()) + dag.clear() + dr = dag.create_dagrun(run_id="manual__{}".format(DEFAULT_DATE.isoformat()), + state=State.RUNNING, + execution_date=DEFAULT_DATE) + assert dr is not None + + with create_session() as session: + tis = {ti.task_id: ti for ti in dr.get_task_instances(session=session)} + # Set dummy1 to skipped and dummy2 to success. dummy3 remains as none. + tis[dummy1.task_id].state = State.SKIPPED + tis[dummy2.task_id].state = State.SUCCESS + assert tis[dummy3.task_id].state == State.NONE + + dag_file_processor._process_task_instances(dag, task_instances_list=Mock()) + + with create_session() as session: + tis = {ti.task_id: ti for ti in dr.get_task_instances(session=session)} + assert tis[dummy1.task_id].state == State.SKIPPED + assert tis[dummy2.task_id].state == State.SUCCESS + # dummy3 should be skipped because dummy1 is skipped. 
+ assert tis[dummy3.task_id].state == State.SKIPPED diff --git a/tests/operators/test_latest_only_operator.py b/tests/operators/test_latest_only_operator.py index 3edff8de8292d8..6f23f595c63b70 100644 --- a/tests/operators/test_latest_only_operator.py +++ b/tests/operators/test_latest_only_operator.py @@ -47,15 +47,40 @@ def get_task_instances(task_id): class LatestOnlyOperatorTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + from tests.compat import MagicMock + from airflow.jobs import SchedulerJob - def setUp(self): - super(LatestOnlyOperatorTest, self).setUp() - self.dag = DAG( + cls.dag = DAG( 'test_dag', default_args={ 'owner': 'airflow', 'start_date': DEFAULT_DATE}, schedule_interval=INTERVAL) + + cls.dag.create_dagrun( + run_id="manual__1", + execution_date=DEFAULT_DATE, + state=State.RUNNING + ) + + cls.dag.create_dagrun( + run_id="manual__2", + execution_date=timezone.datetime(2016, 1, 1, 12), + state=State.RUNNING + ) + + cls.dag.create_dagrun( + run_id="manual__3", + execution_date=END_DATE, + state=State.RUNNING + ) + + cls.dag_file_processor = SchedulerJob(dag_ids=[], log=MagicMock()) + + def setUp(self): + super(LatestOnlyOperatorTest, self).setUp() self.addCleanup(self.dag.clear) freezer = freeze_time(FROZEN_NOW) freezer.start() @@ -86,6 +111,7 @@ def test_skipping(self): downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE) latest_instances = get_task_instances('latest') + self.dag_file_processor._process_task_instances(self.dag, task_instances_list=latest_instances) exec_date_to_latest_state = { ti.execution_date: ti.state for ti in latest_instances} self.assertEqual({ @@ -95,6 +121,7 @@ def test_skipping(self): exec_date_to_latest_state) downstream_instances = get_task_instances('downstream') + self.dag_file_processor._process_task_instances(self.dag, task_instances_list=downstream_instances) exec_date_to_downstream_state = { ti.execution_date: ti.state for ti in downstream_instances} self.assertEqual({ @@ -104,6 +131,7 
@@ def test_skipping(self): exec_date_to_downstream_state) downstream_instances = get_task_instances('downstream_2') + self.dag_file_processor._process_task_instances(self.dag, task_instances_list=downstream_instances) exec_date_to_downstream_state = { ti.execution_date: ti.state for ti in downstream_instances} self.assertEqual({ @@ -126,32 +154,13 @@ def test_skipping_dagrun(self): downstream_task.set_upstream(latest_task) downstream_task2.set_upstream(downstream_task) - self.dag.create_dagrun( - run_id="manual__1", - start_date=timezone.utcnow(), - execution_date=DEFAULT_DATE, - state=State.RUNNING - ) - - self.dag.create_dagrun( - run_id="manual__2", - start_date=timezone.utcnow(), - execution_date=timezone.datetime(2016, 1, 1, 12), - state=State.RUNNING - ) - - self.dag.create_dagrun( - run_id="manual__3", - start_date=timezone.utcnow(), - execution_date=END_DATE, - state=State.RUNNING - ) - latest_task.run(start_date=DEFAULT_DATE, end_date=END_DATE) downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE) downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE) latest_instances = get_task_instances('latest') + self.dag_file_processor._process_task_instances(self.dag, task_instances_list=latest_instances) + exec_date_to_latest_state = { ti.execution_date: ti.state for ti in latest_instances} self.assertEqual({ @@ -161,6 +170,8 @@ def test_skipping_dagrun(self): exec_date_to_latest_state) downstream_instances = get_task_instances('downstream') + self.dag_file_processor._process_task_instances(self.dag, task_instances_list=downstream_instances) + exec_date_to_downstream_state = { ti.execution_date: ti.state for ti in downstream_instances} self.assertEqual({ @@ -170,6 +181,7 @@ def test_skipping_dagrun(self): exec_date_to_downstream_state) downstream_instances = get_task_instances('downstream_2') + self.dag_file_processor._process_task_instances(self.dag, task_instances_list=downstream_instances) exec_date_to_downstream_state = { ti.execution_date: 
ti.state for ti in downstream_instances} self.assertEqual({ diff --git a/tests/operators/test_python_operator.py b/tests/operators/test_python_operator.py index 6f3dfe2703df2d..a92213add8a343 100644 --- a/tests/operators/test_python_operator.py +++ b/tests/operators/test_python_operator.py @@ -32,6 +32,7 @@ from airflow.exceptions import AirflowException from airflow.models import TaskInstance as TI, DAG, DagRun +from airflow.models.taskinstance import clear_task_instances from airflow.operators.dummy_operator import DummyOperator from airflow.operators.python_operator import PythonOperator, BranchPythonOperator from airflow.operators.python_operator import ShortCircuitOperator @@ -491,7 +492,7 @@ def test_with_skip_in_branch_downstream_dependencies(self): elif ti.task_id == 'branch_2': self.assertEqual(ti.state, State.NONE) else: - raise Exception + raise def test_with_skip_in_branch_downstream_dependencies2(self): self.branch_op = BranchPythonOperator(task_id='make_choice', @@ -520,7 +521,63 @@ def test_with_skip_in_branch_downstream_dependencies2(self): elif ti.task_id == 'branch_2': self.assertEqual(ti.state, State.NONE) else: - raise Exception + raise + + def test_clear_skipped_downstream_task(self): + """ + After a downstream task is skipped by BranchPythonOperator, clearing the skipped task + should not cause it to be executed. 
+ """ + branch_op = BranchPythonOperator(task_id='make_choice', + dag=self.dag, + python_callable=lambda: 'branch_1') + branches = [self.branch_1, self.branch_2] + branch_op >> branches + self.dag.clear() + + dr = self.dag.create_dagrun( + run_id="manual__", + start_date=timezone.utcnow(), + execution_date=DEFAULT_DATE, + state=State.RUNNING + ) + + branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + for task in branches: + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + tis = dr.get_task_instances() + for ti in tis: + if ti.task_id == 'make_choice': + self.assertEqual(ti.state, State.SUCCESS) + elif ti.task_id == 'branch_1': + self.assertEqual(ti.state, State.SUCCESS) + elif ti.task_id == 'branch_2': + self.assertEqual(ti.state, State.SKIPPED) + else: + raise + + children_tis = [ti for ti in tis if ti.task_id in branch_op.get_direct_relative_ids()] + + # Clear the children tasks. + with create_session() as session: + clear_task_instances(children_tis, session=session, dag=self.dag) + + # Run the cleared tasks again. + for task in branches: + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + # Check if the states are correct after children tasks are cleared. + for ti in dr.get_task_instances(): + if ti.task_id == 'make_choice': + self.assertEqual(ti.state, State.SUCCESS) + elif ti.task_id == 'branch_1': + self.assertEqual(ti.state, State.SUCCESS) + elif ti.task_id == 'branch_2': + self.assertEqual(ti.state, State.SKIPPED) + else: + raise class ShortCircuitOperatorTest(unittest.TestCase): @@ -660,3 +717,61 @@ def test_with_dag_run(self): self.assertEqual(ti.state, State.NONE) else: raise + + def test_clear_skipped_downstream_task(self): + """ + After a downstream task is skipped by ShortCircuitOperator, clearing the skipped task + should not cause it to be executed. 
+ """ + dag = DAG('shortcircuit_clear_skipped_downstream_task', + default_args={ + 'owner': 'airflow', + 'start_date': DEFAULT_DATE + }, + schedule_interval=INTERVAL) + short_op = ShortCircuitOperator(task_id='make_choice', + dag=dag, + python_callable=lambda: False) + downstream = DummyOperator(task_id='downstream', dag=dag) + + short_op >> downstream + + dag.clear() + + dr = dag.create_dagrun( + run_id="manual__", + start_date=timezone.utcnow(), + execution_date=DEFAULT_DATE, + state=State.RUNNING + ) + + short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + downstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + tis = dr.get_task_instances() + + for ti in tis: + if ti.task_id == 'make_choice': + self.assertEqual(ti.state, State.SUCCESS) + elif ti.task_id == 'downstream': + self.assertEqual(ti.state, State.SKIPPED) + else: + raise + + # Clear downstream + with create_session() as session: + clear_task_instances([t for t in tis if t.task_id == "downstream"], + session=session, + dag=dag) + + # Run downstream again + downstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + # Check if the states are correct. + for ti in dr.get_task_instances(): + if ti.task_id == 'make_choice': + self.assertEqual(ti.state, State.SUCCESS) + elif ti.task_id == 'downstream': + self.assertEqual(ti.state, State.SKIPPED) + else: + raise diff --git a/tests/ti_deps/deps/test_not_previously_skipped_dep.py b/tests/ti_deps/deps/test_not_previously_skipped_dep.py new file mode 100644 index 00000000000000..30da9cf2bc5cd7 --- /dev/null +++ b/tests/ti_deps/deps/test_not_previously_skipped_dep.py @@ -0,0 +1,133 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pendulum + +from airflow.models import DAG, TaskInstance +from airflow.operators.dummy_operator import DummyOperator +from airflow.operators.python_operator import BranchPythonOperator +from airflow.ti_deps.dep_context import DepContext +from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep +from airflow.utils.db import create_session +from airflow.utils.state import State + + +def test_no_parent(): + """ + A simple DAG with a single task. NotPreviouslySkippedDep is met. + """ + start_date = pendulum.datetime(2020, 1, 1) + dag = DAG("test_test_no_parent_dag", schedule_interval=None, start_date=start_date) + op1 = DummyOperator(task_id="op1", dag=dag) + + ti1 = TaskInstance(op1, start_date) + + with create_session() as session: + dep = NotPreviouslySkippedDep() + assert len(list(dep.get_dep_statuses(ti1, session, DepContext()))) == 0 + assert dep.is_met(ti1, session) + assert ti1.state != State.SKIPPED + + +def test_no_skipmixin_parent(): + """ + A simple DAG with no branching. Both op1 and op2 are DummyOperator. NotPreviouslySkippedDep is met. 
+ """ + start_date = pendulum.datetime(2020, 1, 1) + dag = DAG( + "test_no_skipmixin_parent_dag", schedule_interval=None, start_date=start_date + ) + op1 = DummyOperator(task_id="op1", dag=dag) + op2 = DummyOperator(task_id="op2", dag=dag) + op1 >> op2 + + ti2 = TaskInstance(op2, start_date) + + with create_session() as session: + dep = NotPreviouslySkippedDep() + assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 0 + assert dep.is_met(ti2, session) + assert ti2.state != State.SKIPPED + + +def test_parent_follow_branch(): + """ + A simple DAG with a BranchPythonOperator that follows op2. NotPreviouslySkippedDep is met. + """ + start_date = pendulum.datetime(2020, 1, 1) + dag = DAG( + "test_parent_follow_branch_dag", schedule_interval=None, start_date=start_date + ) + op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op2", dag=dag) + op2 = DummyOperator(task_id="op2", dag=dag) + op1 >> op2 + + TaskInstance(op1, start_date).run() + ti2 = TaskInstance(op2, start_date) + + with create_session() as session: + dep = NotPreviouslySkippedDep() + assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 0 + assert dep.is_met(ti2, session) + assert ti2.state != State.SKIPPED + + +def test_parent_skip_branch(): + """ + A simple DAG with a BranchPythonOperator that does not follow op2. NotPreviouslySkippedDep is not met. 
+ """ + start_date = pendulum.datetime(2020, 1, 1) + dag = DAG( + "test_parent_skip_branch_dag", schedule_interval=None, start_date=start_date + ) + op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op3", dag=dag) + op2 = DummyOperator(task_id="op2", dag=dag) + op3 = DummyOperator(task_id="op3", dag=dag) + op1 >> [op2, op3] + + TaskInstance(op1, start_date).run() + ti2 = TaskInstance(op2, start_date) + + with create_session() as session: + dep = NotPreviouslySkippedDep() + assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 1 + assert not dep.is_met(ti2, session) + assert ti2.state == State.SKIPPED + + +def test_parent_not_executed(): + """ + A simple DAG with a BranchPythonOperator that does not follow op2. Parent task is not yet + executed (no xcom data). NotPreviouslySkippedDep is met (no decision). + """ + start_date = pendulum.datetime(2020, 1, 1) + dag = DAG( + "test_parent_not_executed_dag", schedule_interval=None, start_date=start_date + ) + op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op3", dag=dag) + op2 = DummyOperator(task_id="op2", dag=dag) + op3 = DummyOperator(task_id="op3", dag=dag) + op1 >> [op2, op3] + + ti2 = TaskInstance(op2, start_date) + + with create_session() as session: + dep = NotPreviouslySkippedDep() + assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 0 + assert dep.is_met(ti2, session) + assert ti2.state == State.NONE diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py index 45514f6ca20fb9..82550159785b0a 100644 --- a/tests/ti_deps/deps/test_trigger_rule_dep.py +++ b/tests/ti_deps/deps/test_trigger_rule_dep.py @@ -165,6 +165,46 @@ def test_all_success_tr_failure(self): self.assertEqual(len(dep_statuses), 1) self.assertFalse(dep_statuses[0].passed) + def test_all_success_tr_skip(self): + """ + All-success trigger rule fails when some upstream tasks are skipped. 
+ """ + ti = self._get_task_instance(TriggerRule.ALL_SUCCESS, + upstream_task_ids=["FakeTaskID", + "OtherFakeTaskID"]) + dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=1, + skipped=1, + failed=0, + upstream_failed=0, + done=2, + flag_upstream_failed=False, + session="Fake Session")) + self.assertEqual(len(dep_statuses), 1) + self.assertFalse(dep_statuses[0].passed) + + def test_all_success_tr_skip_flag_upstream(self): + """ + All-success trigger rule fails when some upstream tasks are skipped. The state of the ti + should be set to SKIPPED when flag_upstream_failed is True. + """ + ti = self._get_task_instance(TriggerRule.ALL_SUCCESS, + upstream_task_ids=["FakeTaskID", + "OtherFakeTaskID"]) + dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=1, + skipped=1, + failed=0, + upstream_failed=0, + done=2, + flag_upstream_failed=True, + session=Mock())) + self.assertEqual(len(dep_statuses), 1) + self.assertFalse(dep_statuses[0].passed) + self.assertEqual(ti.state, State.SKIPPED) + def test_none_failed_tr_success(self): """ All success including skip trigger rule success