Skip to content

Commit

Permalink
Revert changes related to create_branch_join_task()
Browse files Browse the repository at this point in the history
  • Loading branch information
yuqian90 committed Mar 25, 2020
1 parent f70d643 commit 0b051da
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 178 deletions.
13 changes: 6 additions & 7 deletions UPDATING.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,14 +213,13 @@ The following methods were moved:
| airflow.providers.google.cloud.hooks.bigquery.BigQueryBaseCursor.update_dataset | airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.update_dataset |

### Make behavior of `none_failed` trigger rule consistent with documentation
The behavior of the `none_failed` trigger rule is documented as "all parents have not failed
(`failed` or `upstream_failed`) i.e. all parents have succeeded or been skipped." As previously implemented, the actual
behavior would skip if all parents of a task had also skipped.

This fix may break workflows that depend on the previous behavior. If you really need the old behavior, you can make the task
with ``none_failed`` trigger rule explicitly check the status of its upstream tasks and skip itself if all upstream tasks are
skipped. As an example, look at ``airflow.operators.python.create_branch_join_task()``.
The behavior of the `none_failed` trigger rule is documented as "all parents have not failed (`failed` or
`upstream_failed`) i.e. all parents have succeeded or been skipped." As previously implemented, the actual behavior
would skip if all parents of a task had also skipped.

This may break workflows that depend on the previous behavior.
If you really need the old behavior, you can have your workflow manually check whether any upstream tasks were
not skipped and respond appropriately.

### Standardize handling http exception in BigQuery

Expand Down
8 changes: 6 additions & 2 deletions airflow/example_dags/example_branch_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python import BranchPythonOperator, create_branch_join_task
from airflow.operators.python import BranchPythonOperator
from airflow.utils.dates import days_ago

args = {
Expand Down Expand Up @@ -51,7 +51,11 @@
)
run_this_first >> branching

join = create_branch_join_task(task_id='join', branch_operator=branching, dag=dag)
join = DummyOperator(
task_id='join',
trigger_rule='one_success',
dag=dag,
)

for option in options:
t = DummyOperator(
Expand Down
42 changes: 0 additions & 42 deletions airflow/example_dags/example_nested_branch_dag.py

This file was deleted.

28 changes: 0 additions & 28 deletions airflow/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,34 +165,6 @@ def execute(self, context: Dict):
return branch


def create_branch_join_task(branch_operator, *args, **kwargs):
    """
    Build the join task that closes a branching section.

    The task is a ``PythonOperator`` with the ``none_failed`` trigger rule, so it runs
    no matter which branches were followed. When it runs, it looks up the task
    instance of ``branch_operator`` for the same DAG and execution date and skips
    itself if that task instance is in the skipped state.

    :param branch_operator: operator whose ``task_id`` identifies the parent branch task
    :param args: extra positional arguments passed through to ``PythonOperator``
    :param kwargs: extra keyword arguments passed through to ``PythonOperator``
    :return: the ``PythonOperator`` acting as the join task
    """
    def python_callable(ti, **_):
        from airflow.utils.session import create_session
        from airflow.exceptions import AirflowSkipException
        from airflow.utils.state import State
        from airflow.models import TaskInstance

        with create_session() as session:
            parent_query = session.query(TaskInstance).filter(
                TaskInstance.dag_id == ti.dag_id,
                TaskInstance.task_id == branch_operator.task_id,
                TaskInstance.execution_date == ti.execution_date
            )
            parent_ti = parent_query.one_or_none()

            # A missing parent task instance means there is nothing to check,
            # so the join simply completes; only a skipped parent skips the join.
            if parent_ti and parent_ti.state == State.SKIPPED:
                raise AirflowSkipException(f"Skipping because parent task {branch_operator.task_id} "
                                           "is skipped.")

    return PythonOperator(
        *args,
        trigger_rule="none_failed",
        python_callable=python_callable,
        **kwargs
    )


class ShortCircuitOperator(PythonOperator, SkipMixin):
"""
Allows a workflow to continue only if a condition is met. Otherwise, the
Expand Down
32 changes: 2 additions & 30 deletions docs/concepts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -696,8 +696,8 @@ based on an arbitrary condition which is typically related to something
that happened in an upstream task. One way to do this is by using the
``BranchPythonOperator``.

The ``BranchPythonOperator`` is a ``PythonOperator`` that expects a
``python_callable`` that returns a task_id (or list of task_ids). The
The ``BranchPythonOperator`` is much like the PythonOperator except that it
expects a ``python_callable`` that returns a task_id (or list of task_ids). The
task_id returned is followed, and all of the other paths are skipped.
The task_id returned by the Python function has to reference a task
directly downstream from the BranchPythonOperator task.
Expand Down Expand Up @@ -760,34 +760,6 @@ or a list of task IDs, which will be run, and all others will be skipped.
return 'daily_task_id'
Nested Branching
================

If you have more complicated workflows with multiple levels of branching,
be careful about how the branches are joined together. The join tasks are usually
given the ``none_failed`` trigger rule, so they are executed even if all of their parents are
skipped. But the join task in nested branching should logically be skipped whenever
its parent ``BranchPythonOperator`` is skipped. Use ``create_branch_join_task()``
to create the join task that does this.

.. code:: python
branch_1 = BranchPythonOperator(task_id="branch_1", python_callable=lambda: "true_1")
join_1 = create_branch_join_task(task_id="join_1", branch_operator=branch_1)
true_1 = DummyOperator(task_id="true_1")
false_1 = DummyOperator(task_id="false_1")
branch_2 = BranchPythonOperator(task_id="branch_2", python_callable=lambda: "true_2")
join_2 = create_branch_join_task(task_id="join_2", branch_operator=branch_2)
true_2 = DummyOperator(task_id="true_2")
false_2 = DummyOperator(task_id="false_2")
false_3 = DummyOperator(task_id="false_3")
branch_1 >> true_1 >> join_1
branch_1 >> false_1 >> branch_2 >> [true_2, false_2] >> join_2 >> false_3 >> join_1
.. image:: img/nested_branching.png


SubDAGs
=======

Expand Down
69 changes: 0 additions & 69 deletions tests/operators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,13 @@
from typing import List

import funcsigs
import pendulum
import pytest

from airflow.exceptions import AirflowException
from airflow.models import DAG, DagRun, TaskInstance as TI
from airflow.models.taskinstance import clear_task_instances
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python import (
BranchPythonOperator, PythonOperator, PythonVirtualenvOperator, ShortCircuitOperator,
create_branch_join_task,
)
from airflow.utils import timezone
from airflow.utils.session import create_session
Expand Down Expand Up @@ -926,69 +923,3 @@ def test_context(self):
def f(templates_dict):
return templates_dict['ds']
self._run_as_operator(f, templates_dict={'ds': '{{ ds }}'})


def test_create_branch_join_task():
"""
Test that create_branch_join_task creates a PythonOperator with the none_failed trigger_rule
and that the created task skips itself when its parent branch operator is skipped.
"""
from airflow.exceptions import AirflowSkipException
from airflow.utils.trigger_rule import TriggerRule

execution_date = pendulum.datetime(2020, 1, 1)
dag = DAG(dag_id="test_create_branch_join_task", start_date=execution_date)
# `branch` is never executed in this test; only the state of its task instance is set by hand below.
branch = BranchPythonOperator(task_id="branch", python_callable=lambda: "not_exist", dag=dag)
join = create_branch_join_task(branch, task_id="join", dag=dag)

assert join.trigger_rule == TriggerRule.NONE_FAILED

# NOTE(review): indentation was lost in this copy; presumably everything down to
# session.rollback() belongs to the `with` body -- confirm against the repository.
with create_session() as session:
branch_ti = TI(branch, execution_date)
session.add(branch_ti)
branch_ti.state = State.SUCCESS
# The same context (a task instance for `join`) is reused for both execute() calls.
context = {"ti": TI(join, execution_date=execution_date)}
# If branch is SUCCESS, join should not raise errors.
join.execute(context)

branch_ti.state = State.SKIPPED
session.merge(branch_ti)

# If branch is SKIPPED, join should skip itself.
with pytest.raises(AirflowSkipException):
join.execute(context)

# Presumably discards the task instance added above so the test leaves no rows behind.
session.rollback()


def test_nested_branch_python_operator():
"""
Test nested branching logic in example_nested_branch_dag.py.
When branch_1 skips false_1, false_3 should skip itself.
"""
from airflow.models import DagBag

# Load the example DAG by id; the assert below fails if it was not found.
dag = DagBag().get_dag("example_nested_branch_dag")

assert dag

execution_date = dag.start_date

dag_run = dag.create_dagrun(
run_id=f'manual__{execution_date.isoformat()}',
execution_date=execution_date,
state=State.RUNNING
)
# Run one task instance per task, in the DAG's topological order.
# NOTE(review): indentation was lost in this copy; confirm whether update_state()
# is called after every task or once after the loop.
for task in dag.topological_sort():
TI(task, execution_date=execution_date).run()
dag_run.update_state()

with create_session() as session:
# Helper: state of the single task instance of `task_id` for this DAG and execution date.
def get_state(task_id):
return session.query(TI).filter(TI.dag_id == dag.dag_id,
TI.execution_date == execution_date,
TI.task_id == task_id).one().state

assert get_state("true_1") == State.SUCCESS
assert get_state("false_3") == State.SKIPPED
assert get_state("join_1") == State.SUCCESS
assert get_state("join_1") == State.SUCCESS
1 change: 1 addition & 0 deletions tests/ti_deps/deps/test_trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.

import unittest
from unittest.mock import Mock
from datetime import datetime
from unittest.mock import Mock

Expand Down

0 comments on commit 0b051da

Please sign in to comment.