Skip to content

Commit

Permalink
AIRFLOW-[3823] Exclude branch's downstream tasks from the tasks to sk…
Browse files Browse the repository at this point in the history
…ip (#4666)
  • Loading branch information
BasPH authored and ashb committed Feb 11, 2019
1 parent 23662db commit 3edc91c
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 3 deletions.
17 changes: 14 additions & 3 deletions airflow/operators/python_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@
import subprocess
import sys
import types
from builtins import str
from textwrap import dedent

import dill
from builtins import str
import six

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator, SkipMixin
Expand Down Expand Up @@ -129,16 +130,26 @@ class BranchPythonOperator(PythonOperator, SkipMixin):
"""
def execute(self, context):
branch = super(BranchPythonOperator, self).execute(context)
if isinstance(branch, str):
if isinstance(branch, six.string_types):
branch = [branch]
self.log.info("Following branch %s", branch)
self.log.info("Marking other directly downstream tasks as skipped")

downstream_tasks = context['task'].downstream_list
self.log.debug("Downstream task_ids %s", downstream_tasks)

skip_tasks = [t for t in downstream_tasks if t.task_id not in branch]
if downstream_tasks:
# Also check downstream tasks of the branch task. In case the task to skip
# is a downstream task of the branch task, we exclude it from skipping.
branch_downstream_task_ids = set()
for b in branch:
branch_downstream_task_ids.update(context["dag"].
get_task(b).
get_flat_relative_ids(upstream=False))
skip_tasks = [t
for t in downstream_tasks
if t.task_id not in branch and
t.task_id not in branch_downstream_task_ids]
self.skip(context['dag_run'], context['ti'].execution_date, skip_tasks)

self.log.info("Done.")
Expand Down
58 changes: 58 additions & 0 deletions tests/operators/test_python_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,64 @@ def test_with_dag_run(self):
else:
raise

def test_with_skip_in_branch_downstream_dependencies(self):
self.branch_op = BranchPythonOperator(task_id='make_choice',
dag=self.dag,
python_callable=lambda: 'branch_1')

self.branch_op >> self.branch_1 >> self.branch_2
self.branch_op >> self.branch_2
self.dag.clear()

dr = self.dag.create_dagrun(
run_id="manual__",
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
state=State.RUNNING
)

self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

tis = dr.get_task_instances()
for ti in tis:
if ti.task_id == 'make_choice':
self.assertEqual(ti.state, State.SUCCESS)
elif ti.task_id == 'branch_1':
self.assertEqual(ti.state, State.NONE)
elif ti.task_id == 'branch_2':
self.assertEqual(ti.state, State.NONE)
else:
raise Exception

def test_with_skip_in_branch_downstream_dependencies2(self):
self.branch_op = BranchPythonOperator(task_id='make_choice',
dag=self.dag,
python_callable=lambda: 'branch_2')

self.branch_op >> self.branch_1 >> self.branch_2
self.branch_op >> self.branch_2
self.dag.clear()

dr = self.dag.create_dagrun(
run_id="manual__",
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
state=State.RUNNING
)

self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

tis = dr.get_task_instances()
for ti in tis:
if ti.task_id == 'make_choice':
self.assertEqual(ti.state, State.SUCCESS)
elif ti.task_id == 'branch_1':
self.assertEqual(ti.state, State.SKIPPED)
elif ti.task_id == 'branch_2':
self.assertEqual(ti.state, State.NONE)
else:
raise Exception


class ShortCircuitOperatorTest(unittest.TestCase):
@classmethod
Expand Down

0 comments on commit 3edc91c

Please sign in to comment.