diff --git a/UPDATING.md b/UPDATING.md
index ad8a1c03f97cb8..b52981ba428945 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -213,14 +213,13 @@ The following methods were moved:
 | airflow.providers.google.cloud.hooks.bigquery.BigQueryBaseCursor.update_dataset | airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.update_dataset |
 
 ### Make behavior of `none_failed` trigger rule consistent with documentation
 
-The behavior of the `none_failed` trigger rule is documented as "all parents have not failed
-(`failed` or `upstream_failed`) i.e. all parents have succeeded or been skipped." As previously implemented, the actual
-behavior would skip if all parents of a task had also skipped.
-
-This fix may break workflows that depend on the previous behavior. If you really need the old behavior, you can make the task
-with ``none_failed`` trigger rule explicitly check the status of its upstream tasks and skip itself if all upstream tasks are
-skipped. As an example, look at ``airflow.operators.python.create_branch_join_task()``.
+The behavior of the `none_failed` trigger rule is documented as "all parents have not failed (`failed` or
+`upstream_failed`) i.e. all parents have succeeded or been skipped." As previously implemented, the actual behavior
+would skip a task if all of its parents had also skipped.
+
+This fix may break workflows that depend on the previous behavior. If you really need the old behavior, you can have
+the task check the state of its upstream tasks explicitly and skip itself if all of them were skipped.
 
 ### Standardize handling http exception in BigQuery
 
diff --git a/airflow/example_dags/example_branch_operator.py b/airflow/example_dags/example_branch_operator.py
index ce20de7efa7db7..61c9ac12d12103 100644
--- a/airflow/example_dags/example_branch_operator.py
+++ b/airflow/example_dags/example_branch_operator.py
@@ -22,7 +22,7 @@
 
 from airflow import DAG
 from airflow.operators.dummy_operator import DummyOperator
-from airflow.operators.python import BranchPythonOperator, create_branch_join_task
+from airflow.operators.python import BranchPythonOperator
 from airflow.utils.dates import days_ago
 
 args = {
@@ -51,7 +51,11 @@
 )
 run_this_first >> branching
 
-join = create_branch_join_task(task_id='join', branch_operator=branching, dag=dag)
+join = DummyOperator(
+    task_id='join',
+    trigger_rule='one_success',
+    dag=dag,
+)
 
 for option in options:
     t = DummyOperator(
diff --git a/airflow/example_dags/example_nested_branch_dag.py b/airflow/example_dags/example_nested_branch_dag.py
deleted file mode 100644
index d98b338c9dde77..00000000000000
--- a/airflow/example_dags/example_nested_branch_dag.py
+++ /dev/null
@@ -1,42 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""
-Example DAG demonstrating a workflow with nested branching. The join tasks are created with
-``create_branch_join_task()`` which returns a ``PythonOperator`` with ``none_failed``
-trigger rule that skips itself whenever its corresponding ``BranchPythonOperator`` is skipped.
-"""
-
-from airflow.models import DAG
-from airflow.operators.dummy_operator import DummyOperator
-from airflow.operators.python import BranchPythonOperator, create_branch_join_task
-from airflow.utils.dates import days_ago
-
-with DAG(dag_id="example_nested_branch_dag", start_date=days_ago(2), schedule_interval="@daily") as dag:
-    branch_1 = BranchPythonOperator(task_id="branch_1", python_callable=lambda: "true_1")
-    join_1 = create_branch_join_task(task_id="join_1", branch_operator=branch_1)
-    true_1 = DummyOperator(task_id="true_1")
-    false_1 = DummyOperator(task_id="false_1")
-    branch_2 = BranchPythonOperator(task_id="branch_2", python_callable=lambda: "true_2")
-    join_2 = create_branch_join_task(task_id="join_2", branch_operator=branch_2)
-    true_2 = DummyOperator(task_id="true_2")
-    false_2 = DummyOperator(task_id="false_2")
-    false_3 = DummyOperator(task_id="false_3")
-
-    branch_1 >> true_1 >> join_1
-    branch_1 >> false_1 >> branch_2 >> [true_2, false_2] >> join_2 >> false_3 >> join_1
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index e619d223f922ee..0fac9f087dcc82 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -165,34 +165,6 @@ def execute(self, context: Dict):
         return branch
 
 
-def create_branch_join_task(branch_operator, *args, **kwargs):
-    """
-    Create a join task for a branching logic. This join task is always executed regardless
-    of which branches are followed. It is only skipped if the ``branch_operator`` is skipped.
-    """
-    def python_callable(ti, **_):
-        from airflow.utils.session import create_session
-        from airflow.exceptions import AirflowSkipException
-        from airflow.utils.state import State
-        from airflow.models import TaskInstance
-
-        with create_session() as session:
-            branch_ti = session.query(TaskInstance).filter(
-                TaskInstance.dag_id == ti.dag_id,
-                TaskInstance.task_id == branch_operator.task_id,
-                TaskInstance.execution_date == ti.execution_date
-            ).one_or_none()
-
-            if not branch_ti:
-                return
-
-            if branch_ti.state == State.SKIPPED:
-                raise AirflowSkipException(f"Skipping because parent task {branch_operator.task_id} "
-                                           "is skipped.")
-
-    return PythonOperator(trigger_rule="none_failed", python_callable=python_callable, *args, **kwargs)
-
-
 class ShortCircuitOperator(PythonOperator, SkipMixin):
     """
     Allows a workflow to continue only if a condition is met. Otherwise, the
diff --git a/docs/concepts.rst b/docs/concepts.rst
index b65a5e855fc3ea..04167f46359da7 100644
--- a/docs/concepts.rst
+++ b/docs/concepts.rst
@@ -696,8 +696,8 @@ based on an arbitrary condition which is typically related to something that
 happened in an upstream task. One way to do this is by using the
 ``BranchPythonOperator``.
 
-The ``BranchPythonOperator`` is a ``PythonOperator`` that expects a
-``python_callable`` that returns a task_id (or list of task_ids). The
+The ``BranchPythonOperator`` is much like the PythonOperator except that it
+expects a ``python_callable`` that returns a task_id (or list of task_ids). The
 task_id returned is followed, and all of the other paths are skipped. The
 task_id returned by the Python function has to reference a task directly
 downstream from the BranchPythonOperator task.
@@ -760,34 +760,6 @@ or a list of task IDs, which will be run, and all others will be skipped.
         return 'daily_task_id'
 
 
-Nested Branching
-================
-
-If you have more complicated workflows with multiple levels of branching,
-be careful about how the branches are joined together. The join tasks are usually
-given ``none_failed`` trigger rule, which are executed even if all its parents are
-skipped. But the join task in nested branching should logically be skipped whenever
-its parent ``BranchPythonOperator`` is skipped. Use ``create_branch_join_task()``
-to create the join task that does this.
-
-.. code:: python
-
-    branch_1 = BranchPythonOperator(task_id="branch_1", python_callable=lambda: "true_1")
-    join_1 = create_branch_join_task(task_id="join_1", branch_operator=branch_1)
-    true_1 = DummyOperator(task_id="true_1")
-    false_1 = DummyOperator(task_id="false_1")
-    branch_2 = BranchPythonOperator(task_id="branch_2", python_callable=lambda: "true_2")
-    join_2 = create_branch_join_task(task_id="join_2", branch_operator=branch_2)
-    true_2 = DummyOperator(task_id="true_2")
-    false_2 = DummyOperator(task_id="false_2")
-    false_3 = DummyOperator(task_id="false_3")
-
-    branch_1 >> true_1 >> join_1
-    branch_1 >> false_1 >> branch_2 >> [true_2, false_2] >> join_2 >> false_3 >> join_1
-
-.. image:: img/nested_branching.png
-
-
 SubDAGs
 =======
 
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index 1303bebe271ab7..ab2313cfc1eae8 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -28,8 +28,6 @@
 from typing import List
 
 import funcsigs
-import pendulum
-import pytest
 
 from airflow.exceptions import AirflowException
 from airflow.models import DAG, DagRun, TaskInstance as TI
@@ -37,7 +35,6 @@
 from airflow.operators.dummy_operator import DummyOperator
 from airflow.operators.python import (
     BranchPythonOperator, PythonOperator, PythonVirtualenvOperator, ShortCircuitOperator,
-    create_branch_join_task,
 )
 from airflow.utils import timezone
 from airflow.utils.session import create_session
@@ -926,69 +923,3 @@ def test_context(self):
         def f(templates_dict):
             return templates_dict['ds']
         self._run_as_operator(f, templates_dict={'ds': '{{ ds }}'})
-
-
-def test_create_branch_join_task():
-    """
-    Test create_branch_join_task creates a PythonOperator with none_failed trigger_rule and skip
-    itslef when its parent branch operator is skipped.
-    """
-    from airflow.exceptions import AirflowSkipException
-    from airflow.utils.trigger_rule import TriggerRule
-
-    execution_date = pendulum.datetime(2020, 1, 1)
-    dag = DAG(dag_id="test_create_branch_join_task", start_date=execution_date)
-    branch = BranchPythonOperator(task_id="branch", python_callable=lambda: "not_exist", dag=dag)
-    join = create_branch_join_task(branch, task_id="join", dag=dag)
-
-    assert join.trigger_rule == TriggerRule.NONE_FAILED
-
-    with create_session() as session:
-        branch_ti = TI(branch, execution_date)
-        session.add(branch_ti)
-        branch_ti.state = State.SUCCESS
-        context = {"ti": TI(join, execution_date=execution_date)}
-        # If branch is SUCCESS, join should not raise errors.
-        join.execute(context)
-
-        branch_ti.state = State.SKIPPED
-        session.merge(branch_ti)
-
-        # If branch is SKIPPED, join should skip itself.
-        with pytest.raises(AirflowSkipException):
-            join.execute(context)
-
-        session.rollback()
-
-
-def test_nested_branch_python_operator():
-    """
-    Test nested branching logic in example_nested_branch_dag.py.
-    When branch_1 skips false_1, false_3 should skip itself.
-    """
-    from airflow.models import DagBag
-
-    dag = DagBag().get_dag("example_nested_branch_dag")
-
-    assert dag
-
-    execution_date = dag.start_date
-
-    dag_run = dag.create_dagrun(
-        run_id=f'manual__{execution_date.isoformat()}',
-        execution_date=execution_date,
-        state=State.RUNNING
-    )
-    for task in dag.topological_sort():
-        TI(task, execution_date=execution_date).run()
-    dag_run.update_state()
-
-    with create_session() as session:
-        def get_state(task_id):
-            return session.query(TI).filter(TI.dag_id == dag.dag_id,
-                                            TI.execution_date == execution_date,
-                                            TI.task_id == task_id).one().state
-
-        assert get_state("true_1") == State.SUCCESS
-        assert get_state("false_3") == State.SKIPPED
-        assert get_state("join_1") == State.SUCCESS
diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py
index 912df6916f0eea..9604136a581596 100644
--- a/tests/ti_deps/deps/test_trigger_rule_dep.py
+++ b/tests/ti_deps/deps/test_trigger_rule_dep.py
@@ -17,6 +17,7 @@
 # under the License.
 
 import unittest
+from unittest.mock import Mock
 from datetime import datetime
 from unittest.mock import Mock
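For a workflow that depended on the removed ``create_branch_join_task()`` helper, the UPDATING.md note above suggests checking the state of upstream tasks explicitly. Below is a minimal sketch of that approach, adapted from the helper implementation deleted from ``airflow/operators/python.py`` in this change; the ``make_join_task`` and ``_skip_if_branch_skipped`` names are illustrative only and are not part of Airflow's API.

.. code:: python

    from airflow.exceptions import AirflowSkipException
    from airflow.models import TaskInstance
    from airflow.operators.python import PythonOperator
    from airflow.utils.session import create_session
    from airflow.utils.state import State


    def make_join_task(branch_operator, **kwargs):
        """Create a join task that skips itself whenever ``branch_operator`` was skipped."""

        def _skip_if_branch_skipped(ti, **_):
            # ``ti`` is the join task's own TaskInstance, supplied via the task context.
            # Look up the parent branch operator's TaskInstance for the same execution date.
            with create_session() as session:
                branch_ti = session.query(TaskInstance).filter(
                    TaskInstance.dag_id == ti.dag_id,
                    TaskInstance.task_id == branch_operator.task_id,
                    TaskInstance.execution_date == ti.execution_date,
                ).one_or_none()

                if branch_ti and branch_ti.state == State.SKIPPED:
                    # With this fix, ``none_failed`` alone runs the join even when every
                    # parent was skipped, so propagate the skip explicitly.
                    raise AirflowSkipException(f"Parent task {branch_operator.task_id} is skipped.")

        return PythonOperator(trigger_rule="none_failed", python_callable=_skip_if_branch_skipped, **kwargs)

A DAG would then wire up its join tasks much as the removed example DAG did, e.g. ``join_1 = make_join_task(branch_1, task_id="join_1", dag=dag)``.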