diff --git a/UPDATING.md b/UPDATING.md index b52981ba428945..2aaac55300b2d7 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -217,9 +217,9 @@ The behavior of the `none_failed` trigger rule is documented as "all parents hav `upstream_failed`) i.e. all parents have succeeded or been skipped." As previously implemented, the actual behavior would skip if all parents of a task had also skipped. -This may break workflows that depend on the previous behavior. - If you really need the old behavior, you can have your workflow manually check the status of upstream tasks for non- - skipped tasks and respond appropriately. +### Add new trigger rule `none_failed_or_skipped` +The fix to `none_failed` trigger rule breaks workflows that depend on the previous behavior. + If you need the old behavior, you should change the tasks with `none_failed` trigger rule to `none_failed_or_skipped`. ### Standardize handling http exception in BigQuery diff --git a/airflow/example_dags/example_branch_operator.py b/airflow/example_dags/example_branch_operator.py index 61c9ac12d12103..0b37738360d358 100644 --- a/airflow/example_dags/example_branch_operator.py +++ b/airflow/example_dags/example_branch_operator.py @@ -53,7 +53,7 @@ join = DummyOperator( task_id='join', - trigger_rule='one_success', + trigger_rule='none_failed_or_skipped', dag=dag, ) diff --git a/airflow/example_dags/example_nested_branch_dag.py b/airflow/example_dags/example_nested_branch_dag.py new file mode 100644 index 00000000000000..90653a62196be6 --- /dev/null +++ b/airflow/example_dags/example_nested_branch_dag.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Example DAG demonstrating a workflow with nested branching. The join tasks are created with +``none_failed_or_skipped`` trigger rule such that they are skipped whenever their corresponding +``BranchPythonOperator`` are skipped. +""" + +from airflow.models import DAG +from airflow.operators.dummy_operator import DummyOperator +from airflow.operators.python import BranchPythonOperator +from airflow.utils.dates import days_ago + +with DAG(dag_id="example_nested_branch_dag", start_date=days_ago(2), schedule_interval="@daily") as dag: + branch_1 = BranchPythonOperator(task_id="branch_1", python_callable=lambda: "true_1") + join_1 = DummyOperator(task_id="join_1", trigger_rule="none_failed_or_skipped") + true_1 = DummyOperator(task_id="true_1") + false_1 = DummyOperator(task_id="false_1") + branch_2 = BranchPythonOperator(task_id="branch_2", python_callable=lambda: "true_2") + join_2 = DummyOperator(task_id="join_2", trigger_rule="none_failed_or_skipped") + true_2 = DummyOperator(task_id="true_2") + false_2 = DummyOperator(task_id="false_2") + false_3 = DummyOperator(task_id="false_3") + + branch_1 >> true_1 >> join_1 + branch_1 >> false_1 >> branch_2 >> [true_2, false_2] >> join_2 >> false_3 >> join_1 diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index e0e393c5adba2f..2e431ca21dcf9b 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -212,7 +212,7 @@ class derived from this one results in the creation of a task object, :param trigger_rule: defines the rule by which dependencies are applied for the task to get triggered. Options are: ``{ all_success | all_failed | all_done | one_success | - one_failed | none_failed | none_skipped | dummy}`` + one_failed | none_failed | none_failed_or_skipped | none_skipped | dummy}`` default is ``all_success``. Options can be set as string or using the constants defined in the static class ``airflow.utils.TriggerRule`` diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py index 403bfc05cc620e..a816dbf025dcff 100644 --- a/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow/ti_deps/deps/trigger_rule_dep.py @@ -143,6 +143,11 @@ def _evaluate_trigger_rule( # pylint: disable=too-many-branches elif trigger_rule == TR.NONE_FAILED: if upstream_failed or failed: ti.set_state(State.UPSTREAM_FAILED, session) + elif trigger_rule == TR.NONE_FAILED_OR_SKIPPED: + if upstream_failed or failed: + ti.set_state(State.UPSTREAM_FAILED, session) + elif skipped == upstream: + ti.set_state(State.SKIPPED, session) elif trigger_rule == TR.NONE_SKIPPED: if skipped: ti.set_state(State.SKIPPED, session) @@ -197,6 +202,15 @@ def _evaluate_trigger_rule( # pylint: disable=too-many-branches "upstream_tasks_state={2}, upstream_task_ids={3}" .format(trigger_rule, num_failures, upstream_tasks_state, task.upstream_task_ids)) + elif trigger_rule == TR.NONE_FAILED_OR_SKIPPED: + num_failures = upstream - successes - skipped + if num_failures > 0: + yield self._failing_status( + reason="Task's trigger rule '{0}' requires all upstream " + "tasks to have succeeded or been skipped, but found {1} non-success(es). " + "upstream_tasks_state={2}, upstream_task_ids={3}" + .format(trigger_rule, num_failures, upstream_tasks_state, + task.upstream_task_ids)) elif trigger_rule == TR.NONE_SKIPPED: if not upstream_done or (skipped > 0): yield self._failing_status( diff --git a/airflow/utils/trigger_rule.py b/airflow/utils/trigger_rule.py index 48d21c91acc528..87c12c1876db88 100644 --- a/airflow/utils/trigger_rule.py +++ b/airflow/utils/trigger_rule.py @@ -29,6 +29,7 @@ class TriggerRule: ONE_SUCCESS = 'one_success' ONE_FAILED = 'one_failed' NONE_FAILED = 'none_failed' + NONE_FAILED_OR_SKIPPED = 'none_failed_or_skipped' NONE_SKIPPED = 'none_skipped' DUMMY = 'dummy' diff --git a/docs/concepts.rst b/docs/concepts.rst index 04167f46359da7..8c5622ffaf41e0 100644 --- a/docs/concepts.rst +++ b/docs/concepts.rst @@ -874,6 +874,7 @@ while creating tasks: * ``one_failed``: fires as soon as at least one parent has failed, it does not wait for all parents to be done * ``one_success``: fires as soon as at least one parent succeeds, it does not wait for all parents to be done * ``none_failed``: all parents have not failed (``failed`` or ``upstream_failed``) i.e. all parents have succeeded or been skipped +* ``none_failed_or_skipped``: all parents have not failed (``failed`` or ``upstream_failed``) and at least one parent has succeeded. * ``none_skipped``: no parent is in a ``skipped`` state, i.e. all parents are in a ``success``, ``failed``, or ``upstream_failed`` state * ``dummy``: dependencies are just for show, trigger at will @@ -884,7 +885,7 @@ previous schedule for the task hasn't succeeded. One must be aware of the interaction between trigger rules and skipped tasks in schedule level. Skipped tasks will cascade through trigger rules ``all_success`` and ``all_failed`` but not ``all_done``, ``one_failed``, ``one_success``, -``none_failed``, ``none_skipped`` and ``dummy``. +``none_failed``, ``none_failed_or_skipped``, ``none_skipped`` and ``dummy``. For example, consider the following DAG: @@ -927,19 +928,19 @@ skipped tasks will cascade through ``all_success``. .. image:: img/branch_without_trigger.png -By setting ``trigger_rule`` to ``none_failed`` in ``join`` task, +By setting ``trigger_rule`` to ``none_failed_or_skipped`` in ``join`` task, .. code:: python #dags/branch_with_trigger.py ... - join = DummyOperator(task_id='join', dag=dag, trigger_rule='none_failed') + join = DummyOperator(task_id='join', dag=dag, trigger_rule='none_failed_or_skipped') ... The ``join`` task will be triggered as soon as ``branch_false`` has been skipped (a valid completion state) and ``follow_branch_a`` has succeeded. Because skipped tasks **will not** -cascade through ``none_failed``. +cascade through ``none_failed_or_skipped``. .. image:: img/branch_with_trigger.png diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py index 9604136a581596..039c408a5e8f23 100644 --- a/tests/ti_deps/deps/test_trigger_rule_dep.py +++ b/tests/ti_deps/deps/test_trigger_rule_dep.py @@ -17,7 +17,6 @@ # under the License. import unittest -from unittest.mock import Mock from datetime import datetime from unittest.mock import Mock @@ -267,6 +266,63 @@ def test_none_failed_tr_failure(self): self.assertEqual(len(dep_statuses), 1) self.assertFalse(dep_statuses[0].passed) + def test_none_failed_or_skipped_tr_success(self): + """ + All success including skip trigger rule success + """ + ti = self._get_task_instance(TriggerRule.NONE_FAILED_OR_SKIPPED, + upstream_task_ids=["FakeTaskID", + "OtherFakeTaskID"]) + dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=1, + skipped=1, + failed=0, + upstream_failed=0, + done=2, + flag_upstream_failed=False, + session="Fake Session")) + self.assertEqual(len(dep_statuses), 0) + + def test_none_failed_or_skipped_tr_skipped(self): + """ + All success including all upstream skips trigger rule success + """ + ti = self._get_task_instance(TriggerRule.NONE_FAILED_OR_SKIPPED, + upstream_task_ids=["FakeTaskID", + "OtherFakeTaskID"]) + dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=0, + skipped=2, + failed=0, + upstream_failed=0, + done=2, + flag_upstream_failed=True, + session=Mock())) + self.assertEqual(len(dep_statuses), 0) + self.assertEqual(ti.state, State.SKIPPED) + + def test_none_failed_or_skipped_tr_failure(self): + """ + All success including skip trigger rule failure + """ + ti = self._get_task_instance(TriggerRule.NONE_FAILED_OR_SKIPPED, + upstream_task_ids=["FakeTaskID", + "OtherFakeTaskID", + "FailedFakeTaskID"]) + dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=1, + skipped=1, + failed=1, + upstream_failed=0, + done=3, + flag_upstream_failed=False, + session="Fake Session")) + self.assertEqual(len(dep_statuses), 1) + self.assertFalse(dep_statuses[0].passed) + def test_all_failed_tr_success(self): """ All-failed trigger rule success diff --git a/tests/utils/test_trigger_rule.py b/tests/utils/test_trigger_rule.py index 9321c29fc955db..1ea399ebde3b75 100644 --- a/tests/utils/test_trigger_rule.py +++ b/tests/utils/test_trigger_rule.py @@ -30,6 +30,7 @@ def test_valid_trigger_rules(self): self.assertTrue(TriggerRule.is_valid(TriggerRule.ONE_SUCCESS)) self.assertTrue(TriggerRule.is_valid(TriggerRule.ONE_FAILED)) self.assertTrue(TriggerRule.is_valid(TriggerRule.NONE_FAILED)) + self.assertTrue(TriggerRule.is_valid(TriggerRule.NONE_FAILED_OR_SKIPPED)) self.assertTrue(TriggerRule.is_valid(TriggerRule.NONE_SKIPPED)) self.assertTrue(TriggerRule.is_valid(TriggerRule.DUMMY)) - self.assertEqual(len(TriggerRule.all_triggers()), 8) + self.assertEqual(len(TriggerRule.all_triggers()), 9)